1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6# pylint: disable=C0302
7from enum import Enum
8import calendar
9from datetime import datetime
10from dateutil.parser import parse
11
12from azure.cli.core.util import (
13    CLIError,
14    sdk_no_wait,
15)
16
17from azure.mgmt.sql.models import (
18    AdministratorName,
19    AdministratorType,
20    AuthenticationName,
21    BlobAuditingPolicyState,
22    CapabilityGroup,
23    CapabilityStatus,
24    ConnectionPolicyName,
25    CreateMode,
26    EncryptionProtector,
27    EncryptionProtectorName,
28    FailoverGroup,
29    FailoverGroupReadOnlyEndpoint,
30    FailoverGroupReadWriteEndpoint,
31    FailoverGroupReplicationRole,
32    FirewallRule,
33    InstanceFailoverGroup,
34    InstanceFailoverGroupReadOnlyEndpoint,
35    InstanceFailoverGroupReadWriteEndpoint,
36    LedgerDigestUploadsName,
37    LongTermRetentionPolicyName,
38    ManagedInstanceAzureADOnlyAuthentication,
39    ManagedInstanceEncryptionProtector,
40    ManagedInstanceExternalAdministrator,
41    ManagedInstanceKey,
42    ManagedInstanceLongTermRetentionPolicyName,
43    ManagedInstancePairInfo,
44    ManagedShortTermRetentionPolicyName,
45    OutboundFirewallRule,
46    PartnerInfo,
47    PartnerRegionInfo,
48    PerformanceLevelUnit,
49    ResourceIdentity,
50    RestoreDetailsName,
51    SecurityAlertPolicyName,
52    SecurityAlertPolicyState,
53    SensitivityLabel,
54    SensitivityLabelSource,
55    ServerAzureADOnlyAuthentication,
56    ServerConnectionPolicy,
57    ServerExternalAdministrator,
58    ServerInfo,
59    ServerKey,
60    ServerKeyType,
61    ServerNetworkAccessFlag,
62    ServiceObjectiveName,
63    ServerTrustGroup,
64    ShortTermRetentionPolicyName,
65    Sku,
66    StorageKeyType,
67    TransparentDataEncryptionName,
68    UserIdentity,
69    VirtualNetworkRule
70)
71
72from azure.cli.core.profiles import ResourceType
73from azure.cli.core.commands.client_factory import get_mgmt_service_client
74from azure.cli.command_modules.monitor._client_factory import cf_monitor
75from azure.cli.command_modules.monitor.operations.diagnostics_settings import create_diagnostics_settings
76
77from knack.log import get_logger
78from knack.prompting import prompt_y_n
79
80from ._util import (
81    get_sql_capabilities_operations,
82    get_sql_servers_operations,
83    get_sql_managed_instances_operations,
84    get_sql_restorable_dropped_database_managed_backup_short_term_retention_policies_operations,
85)
86
87
88logger = get_logger(__name__)
89
90###############################################
91#                Common funcs                 #
92###############################################
93
94
def _get_server_location(cli_ctx, server_name, resource_group_name):
    '''
    Looks up the specified server and returns its location (i.e. Azure region).
    '''

    # pylint: disable=no-member
    server = get_sql_servers_operations(cli_ctx, None).get(
        resource_group_name=resource_group_name,
        server_name=server_name)
    return server.location
105
106
def _get_managed_restorable_dropped_database_backup_short_term_retention_client(cli_ctx):
    '''
    Returns the operations client used for managed restorable dropped databases
    (backup short term retention policies).
    '''

    # pylint: disable=no-member
    return get_sql_restorable_dropped_database_managed_backup_short_term_retention_policies_operations(cli_ctx, None)
117
118
def _get_managed_instance_location(cli_ctx, managed_instance_name, resource_group_name):
    '''
    Looks up the specified managed instance and returns its location
    (i.e. Azure region).
    '''

    # pylint: disable=no-member
    managed_instance = get_sql_managed_instances_operations(cli_ctx, None).get(
        resource_group_name=resource_group_name,
        managed_instance_name=managed_instance_name)
    return managed_instance.location
129
130
def _get_location_capability(cli_ctx, location, group):
    '''
    Fetches the capabilities of a location for the given capability group.

    Raises CLIError (via _assert_capability_available) if the location
    capability is not available.
    '''

    location_capability = get_sql_capabilities_operations(cli_ctx, None).list_by_location(location, group)

    # Fail early when the region itself is not usable
    _assert_capability_available(location_capability)

    return location_capability
140
141
142def _any_sku_values_specified(sku):
143    '''
144    Returns True if the sku object has any properties that are specified
145    (i.e. not None).
146    '''
147
148    return any(val for key, val in sku.__dict__.items())
149
150
def _compute_model_matches(sku_name, compute_model):
    '''
    Returns True if the sku name is consistent with the compute model: a
    serverless sku name goes with the serverless compute model, and any other
    sku name goes with any other compute model (including None).

    Please update this function if compute_model has more than 2 enums.
    '''

    wants_serverless = compute_model == ComputeModelType.serverless
    return _is_serverless_slo(sku_name) == wants_serverless
162
163
164def _is_serverless_slo(sku_name):
165    '''
166    Returns True if the sku name is a serverless sku.
167    '''
168
169    return "_S_" in sku_name
170
171
def _get_default_server_version(location_capabilities):
    '''
    Picks the default server version capability out of the full location
    capabilities response.

    Preference order: the first version with 'default' status, then the first
    version with 'available' status, and finally server version 12.0, which is
    kept as a fallback in order to maintain compatibility with older
    Azure CLI releases (2.0.25 and earlier).
    '''

    return _get_default_capability(
        location_capabilities.supported_server_versions,
        fallback_predicate=lambda capability: capability.name == "12.0")
190
191
def _get_default_capability(capabilities, fallback_predicate=None):
    '''
    Selects one capability from the collection, trying in order:
      1. the first capability with 'default' status,
      2. the first capability with 'available' status,
      3. the first capability accepted by fallback_predicate (if one is given).

    Raises CLIError when none of the above yields a capability.
    '''
    logger.debug('_get_default_capability: %s', capabilities)

    def first_matching(predicate):
        # Returns the first capability accepted by the predicate, or None
        for candidate in capabilities:
            if predicate(candidate):
                return candidate
        return None

    # Preferred: capability explicitly marked as default
    chosen = first_matching(lambda c: c.status == CapabilityStatus.DEFAULT)
    if chosen:
        logger.debug('_get_default_capability found default: %s', chosen)
        return chosen

    # Next best: any capability that is available
    chosen = first_matching(lambda c: c.status == CapabilityStatus.AVAILABLE)
    if chosen:
        logger.debug('_get_default_capability found available: %s', chosen)
        return chosen

    # Last resort: caller supplied fallback rule
    if fallback_predicate:
        logger.debug('_get_default_capability using fallback')
        chosen = first_matching(fallback_predicate)
        if chosen:
            logger.debug('_get_default_capability found fallback: %s', chosen)
            return chosen

    # Nothing usable was found, so the only option left is to fail
    logger.debug('_get_default_capability failed')
    raise CLIError('Provisioning is restricted in this region. Please choose a different region.')
222
223
def _assert_capability_available(capability):
    '''
    Verifies that the capability can be used (status available or default).
    Raises CLIError carrying the capability's reason otherwise.
    '''
    logger.debug('_assert_capability_available: %s', capability)

    if is_available(capability.status):
        return

    raise CLIError(capability.reason)
233
234
def is_available(status):
    '''
    Returns True if the capability status is available (including default).
    There are three capability statuses:
        VISIBLE: customer can see the slo but cannot use it
        AVAILABLE: customer can see the slo and can use it
        DEFAULT: customer can see the slo and can use it
    Thus a status is treated as available exactly when it is not VISIBLE.

    status: capability status, either a CapabilityStatus member or its
        string value.
    '''

    # Compare for equality. CapabilityStatus members are strings, so the former
    # "not in" test was a substring search inside 'Visible' (e.g. an empty
    # status counted as visible, and a None status raised TypeError).
    return status != CapabilityStatus.VISIBLE
246
247
def _filter_available(capabilities):
    '''
    Returns a new list holding only the capabilities whose status is
    available, in their original order.
    '''

    return list(filter(lambda capability: is_available(capability.status), capabilities))
254
255
def _find_edition_capability(sku, supported_editions):
    '''
    Finds the DB edition capability in the collection of supported editions
    that matches the requested sku.

    If the sku has no edition specified, returns the default edition.

    (Note: tier and edition mean the same thing.)

    sku: requested sku; only its tier is inspected.
    supported_editions: edition capabilities to search, matched by name.

    Raises CLIError if a tier is requested but no supported edition has
    that name.
    '''
    logger.debug('_find_edition_capability: %s; %s', sku, supported_editions)

    if not sku.tier:
        # Find default edition capability
        return _get_default_capability(supported_editions)

    # Find requested edition capability
    for edition in supported_editions:
        if edition.name == sku.tier:
            return edition

    # The tier is wrapped in real quotes here. The previous ''{}'' spelling was
    # adjacent-literal concatenation, so the quotes never reached the message.
    raise CLIError("Could not find tier '{}'. Supported tiers are: {}".format(
        sku.tier, [e.name for e in supported_editions]
    ))
279
280
def _find_family_capability(sku, supported_families):
    '''
    Finds the family capability in the collection of supported families
    that matches the requested sku.

    If the sku has no family specified, returns the default family.

    sku: requested sku; only its family is inspected.
    supported_families: family capabilities to search, matched by name.

    Raises CLIError if a family is requested but no supported family has
    that name.
    '''
    logger.debug('_find_family_capability: %s; %s', sku, supported_families)

    if not sku.family:
        # Find default family capability
        return _get_default_capability(supported_families)

    # Find requested family capability
    for family in supported_families:
        if family.name == sku.family:
            return family

    # The family is wrapped in real quotes here. The previous ''{}'' spelling was
    # adjacent-literal concatenation, so the quotes never reached the message.
    raise CLIError("Could not find family '{}'. Supported families are: {}".format(
        sku.family, [f.name for f in supported_families]
    ))
302
303
def _find_performance_level_capability(sku, supported_service_level_objectives, allow_reset_family, compute_model=None):
    '''
    Finds the DB or elastic pool performance level (i.e. service objective)
    among the supported service objectives.

    With a capacity requested, the first service objective is returned whose
    capacity equals the requested one, whose sku name fits the compute model,
    and whose family either equals the requested family or is None while
    allow_reset_family is set. CLIError is raised if there is no such
    service objective.

    With neither capacity nor family requested, the default service objective
    is returned. A family without a capacity is rejected with CLIError.
    '''

    logger.debug('_find_performance_level_capability: %s, %s, allow_reset_family: %s, compute_model: %s',
                 sku, supported_service_level_objectives, allow_reset_family, compute_model)

    if not sku.capacity:
        if sku.family:
            # Error - cannot find based on family alone.
            raise CLIError('If --family is specified, --capacity must also be specified.')

        # Nothing requested, so go with the default service objective
        return _get_default_capability(supported_service_level_objectives)

    def _is_requested(slo):
        # Note that for non-vcore editions, family is None.
        family_matches = (slo.sku.family == sku.family) or (slo.sku.family is None and allow_reset_family)
        return (family_matches and
                int(slo.sku.capacity) == int(sku.capacity) and
                _compute_model_matches(slo.sku.name, compute_model))

    try:
        return next(slo for slo in supported_service_level_objectives if _is_requested(slo))
    except StopIteration:
        if allow_reset_family:
            # Family was not part of the search, so only report capacities
            message = (
                "Could not find sku in tier '{tier}' with capacity {capacity}."
                " Supported capacities for '{tier}' are: {capacities}."
                " Please specify one of these supported values for capacity.".format(
                    tier=sku.tier,
                    capacity=sku.capacity,
                    capacities=[slo.sku.capacity for slo in supported_service_level_objectives]
                ))
        else:
            message = (
                "Could not find sku in tier '{tier}' with family '{family}', capacity {capacity}."
                " Supported families & capacities for '{tier}' are: {skus}. Please specify one of these"
                " supported combinations of family and capacity.".format(
                    tier=sku.tier,
                    family=sku.family,
                    capacity=sku.capacity,
                    skus=[(slo.sku.family, slo.sku.capacity)
                          for slo in supported_service_level_objectives]
                ))
        raise CLIError(message)
352
353
def _db_elastic_pool_update_sku(
        cmd,
        instance,
        service_objective,
        tier,
        family,
        capacity,
        find_sku_from_capabilities_func,
        compute_model=None):
    '''
    Applies the requested service objective, tier, family and capacity to the
    sku of a DB or elastic pool (instance.sku is modified in place).

    When the resulting sku has no name, the sku is resolved by calling
    find_sku_from_capabilities_func for the instance's location.
    '''

    # An explicit service objective replaces the whole sku and fixes its name
    if service_objective:
        instance.sku = Sku(name=service_objective)

    target_sku = instance.sku

    # Without an explicit service objective, the old sku name would conflict
    # with a new tier or a new family, so it is wiped out.
    if (tier or family) and not service_objective:
        target_sku.name = None

    allow_reset_family = False
    if tier:
        target_sku.tier = tier

        # If we are changing tier and old sku has family but new family is
        # unspecified, allow sku search to wipe out family.
        #
        # This is needed so that tier can be successfully changed from
        # a tier that has family (e.g. GeneralPurpose) to a tier that has
        # no family (e.g. Standard).
        allow_reset_family = bool(target_sku.family) and not family

    if family:
        target_sku.family = family

    if capacity:
        target_sku.capacity = capacity

    # A name that survived so far is still dropped when a compute model was
    # requested and the name belongs to the other (serverless vs provisioned) offering.
    name_conflicts_with_compute_model = (
        target_sku.name and
        compute_model and
        not _compute_model_matches(target_sku.name, compute_model))
    if name_conflicts_with_compute_model:
        target_sku.name = None

    if target_sku.name:
        # Name is known, nothing left to resolve
        return

    # Sku name was wiped out by one of the steps above, resolve it using capabilities
    instance.sku = find_sku_from_capabilities_func(
        cmd.cli_ctx, instance.location, target_sku,
        allow_reset_family=allow_reset_family, compute_model=compute_model)
412
413
def _get_tenant_id():
    '''
    Returns the 'tenantId' entry of the subscription reported by the
    CLI profile.
    '''
    from azure.cli.core._profile import Profile

    subscription = Profile().get_subscription()
    return subscription['tenantId']
423
424
def _get_identity_object_from_type(
        assignIdentityIsPresent,
        resourceIdentityType,
        userAssignedIdentities,
        existingResourceIdentity):
    '''
    Builds the resource identity object for a create or update request.

    assignIdentityIsPresent: whether assigning an identity was requested.
    resourceIdentityType: requested identity type (one of the ResourceIdType
        values), or None.
    userAssignedIdentities: list of user assigned identity ids; required for
        the 'UserAssigned' and 'SystemAssigned,UserAssigned' types.
    existingResourceIdentity: identity currently set on the resource, or None.

    Returns:
      - an identity of type 'None' when that type is requested,
      - for the two user assigned types: the result of merging the ids into
        the existing identity when it already has user assigned identities,
        otherwise a new identity listing the given ids,
      - a 'SystemAssigned' identity when an identity is requested without a type,
      - the existing identity when assignIdentityIsPresent is False,
      - None otherwise.

    Raises CLIError if a user assigned type is requested without identity ids.
    '''
    if resourceIdentityType is not None and resourceIdentityType == ResourceIdType.none.value:
        return ResourceIdentity(type=ResourceIdType.none.value)

    identityResult = None
    hasExistingUserIdentities = (existingResourceIdentity is not None and
                                 existingResourceIdentity.user_assigned_identities is not None)
    missingIdentitiesMessage = ('The list of user assigned identity ids needs to be passed if the '
                                'IdentityType is UserAssigned or SystemAssignedUserAssigned.')

    if assignIdentityIsPresent and resourceIdentityType is not None:
        # When UMI is of type SystemAssigned,UserAssigned
        if resourceIdentityType == ResourceIdType.system_assigned_user_assigned.value:
            if userAssignedIdentities is None:
                raise CLIError(missingIdentitiesMessage)

            if hasExistingUserIdentities:
                identityResult = _get_sys_assigned_user_assigned_identity(userAssignedIdentities,
                                                                          existingResourceIdentity)
            else:
                # Create scenarios. An empty id list yields None, not an empty dict.
                umiDict = {identity: UserIdentity() for identity in userAssignedIdentities} or None
                identityResult = ResourceIdentity(type=ResourceIdType.system_assigned_user_assigned.value,
                                                  user_assigned_identities=umiDict)

        # When UMI is of type UserAssigned
        elif resourceIdentityType == ResourceIdType.user_assigned.value:
            if userAssignedIdentities is None:
                raise CLIError(missingIdentitiesMessage)

            if hasExistingUserIdentities:
                identityResult = _get__user_assigned_identity(userAssignedIdentities, existingResourceIdentity)
            else:
                # Create scenarios. An empty id list yields None, not an empty dict.
                umiDict = {identity: UserIdentity() for identity in userAssignedIdentities} or None
                identityResult = ResourceIdentity(type=ResourceIdType.user_assigned.value,
                                                  user_assigned_identities=umiDict)
    elif assignIdentityIsPresent:
        identityResult = ResourceIdentity(type=ResourceIdType.system_assigned.value)

    if assignIdentityIsPresent is False and existingResourceIdentity is not None:
        identityResult = existingResourceIdentity

    # Trace through the logger only; printing would pollute the command output.
    logger.debug('_get_identity_object_from_type: %s', identityResult)
    return identityResult
490
491
def _get_sys_assigned_user_assigned_identity(
        userAssignedIdentities,
        existingResourceIdentity):
    '''
    Registers each given user assigned identity id on the existing resource
    identity and returns a new 'SystemAssigned,UserAssigned' resource identity.
    '''
    # NOTE(review): the returned identity is created without user_assigned_identities;
    # only existingResourceIdentity is updated in place - confirm this is intended.

    for identity_id in userAssignedIdentities:
        existingResourceIdentity.user_assigned_identities[identity_id] = UserIdentity()

    return ResourceIdentity(type=ResourceIdType.system_assigned_user_assigned.value)
502
503
def _get__user_assigned_identity(
        userAssignedIdentities,
        existingResourceIdentity):
    '''
    Registers each given user assigned identity id on the existing resource
    identity and returns a new 'UserAssigned' resource identity.
    '''
    # NOTE(review): the returned identity is created without user_assigned_identities;
    # only existingResourceIdentity is updated in place - confirm this is intended.

    for identity_id in userAssignedIdentities:
        existingResourceIdentity.user_assigned_identities[identity_id] = UserIdentity()

    return ResourceIdentity(type=ResourceIdType.user_assigned.value)
514
515
516_DEFAULT_SERVER_VERSION = "12.0"
517
518
def _failover_group_update_common(
        instance,
        failover_policy=None,
        grace_period=None):
    '''
    Updates failover policy and grace period on the read-write endpoint of a
    failover group. Shared by Sterling and Managed Instance failover groups.

    With the manual failover policy in effect the grace period is cleared.
    Otherwise a given grace period is multiplied by 60 before it is stored
    as minutes.
    '''

    endpoint = instance.read_write_endpoint

    if failover_policy is not None:
        endpoint.failover_policy = failover_policy

    if endpoint.failover_policy == FailoverPolicyType.manual.value:
        # Any requested grace period is ignored for manual failover
        endpoint.failover_with_data_loss_grace_period_minutes = None
    elif grace_period is not None:
        endpoint.failover_with_data_loss_grace_period_minutes = int(grace_period) * 60
537
538
def _complete_maintenance_configuration_id(cli_ctx, argument_value=None):
    '''
    Expands a short maintenance configuration name into a full resource id of
    a public maintenance configuration in the current subscription.

    Empty values and values that already are valid resource ids are returned
    unchanged.
    '''

    from azure.cli.core.commands.client_factory import get_subscription_id
    from msrestazure.tools import is_valid_resource_id, resource_id

    if not argument_value or is_valid_resource_id(argument_value):
        return argument_value

    return resource_id(
        subscription=get_subscription_id(cli_ctx),
        namespace='Microsoft.Maintenance',
        type='publicMaintenanceConfigurations',
        name=argument_value)
555
556###############################################
557#                sql db                       #
558###############################################
559
560
561# pylint: disable=too-few-public-methods
class ClientType(Enum):
    '''
    SQL client providers for which a connection string can be generated.
    The member values are the keys used to select a provider's formats.
    '''

    ado_net = "ado.net"
    sqlcmd = "sqlcmd"
    jdbc = "jdbc"
    php_pdo = "php_pdo"
    php = "php"
    odbc = "odbc"
573
574
class ClientAuthenticationType(Enum):
    '''
    Authentication mechanisms of SQL clients for which a connection string
    can be generated.
    '''

    sql_password = "SqlPassword"
    active_directory_password = "ADPassword"
    active_directory_integrated = "ADIntegrated"
584
585
class FailoverPolicyType(Enum):
    '''
    Failover policies of the read-write endpoint of a failover group.
    '''

    automatic = "Automatic"
    manual = "Manual"
589
590
class SqlServerMinimalTlsVersionType(Enum):
    '''
    Version strings for the minimal TLS version of a SQL server
    (going by the class name).
    '''

    tls_1_0 = '1.0'
    tls_1_1 = '1.1'
    tls_1_2 = '1.2'
595
596
class ResourceIdType(Enum):
    '''
    Types of resource identity. The member values are the identity type
    strings placed into resource identity objects.
    '''

    system_assigned = "SystemAssigned"
    user_assigned = "UserAssigned"
    system_assigned_user_assigned = "SystemAssigned,UserAssigned"
    none = "None"
605
606
class SqlManagedInstanceMinimalTlsVersionType(Enum):
    '''
    Version strings for the minimal TLS version of a SQL managed instance
    (going by the class name), including a 'None' entry.
    '''

    no_tls = 'None'
    tls_1_0 = '1.0'
    tls_1_1 = '1.1'
    tls_1_2 = '1.2'
612
613
class ComputeModelType(str, Enum):
    '''
    Compute models of a database sku. Members are strings, so they compare
    equal to their plain string values.
    '''

    provisioned = 'Provisioned'
    serverless = 'Serverless'
618
619
class DatabaseEdition(str, Enum):
    '''
    Names of database editions. Members are strings, so they compare equal
    to their plain string values.
    '''

    web = 'Web'
    business = 'Business'
    basic = 'Basic'
    standard = 'Standard'
    premium = 'Premium'
    premium_rs = 'PremiumRS'
    free = 'Free'
    stretch = 'Stretch'
    data_warehouse = 'DataWarehouse'
    system = 'System'
    system2 = 'System2'
    general_purpose = 'GeneralPurpose'
    business_critical = 'BusinessCritical'
    hyperscale = 'Hyperscale'
636
637
class AuthenticationType(str, Enum):
    '''
    Authentication type names. Members are strings, so they compare equal
    to their plain string values.
    '''

    sql = 'SQL'
    ad_password = 'ADPassword'
642
643
644def _get_server_dns_suffx(cli_ctx):
645    '''
646    Gets the DNS suffix for servers in this Azure environment.
647    '''
648
649    # Allow dns suffix to be overridden by environment variable for testing purposes
650    from os import getenv
651    return getenv('_AZURE_CLI_SQL_DNS_SUFFIX', default=cli_ctx.cloud.suffixes.sql_server_hostname)
652
653
def _get_managed_db_resource_id(cli_ctx, resource_group_name, managed_instance_name, database_name):
    '''
    Builds the resource id of a database that belongs to a managed instance
    in the current subscription.
    '''

    from msrestazure.tools import resource_id
    from azure.cli.core.commands.client_factory import get_subscription_id

    subscription_id = get_subscription_id(cli_ctx)

    return resource_id(
        subscription=subscription_id,
        resource_group=resource_group_name,
        namespace='Microsoft.Sql',
        type='managedInstances',
        name=managed_instance_name,
        child_type_1='databases',
        child_name_1=database_name)
670
671
def _to_filetimeutc(dateTime):
    '''
    Converts a date/time string to a FILETIME value in UTC, i.e. the number of
    100-nanosecond intervals elapsed since 1601-01-01 00:00:00.

    dateTime: date/time string accepted by the parser. A value carrying a UTC
        offset is normalized to UTC first; a value without offset is taken
        as UTC.

    Returns the FILETIME value as an int.
    '''

    NET_epoch = datetime(1601, 1, 1)
    UNIX_epoch = datetime(1970, 1, 1)

    epoch_delta = (UNIX_epoch - NET_epoch)

    log_time = parse(dateTime)

    # Shifting by the distance between the two epochs makes timegm (which counts
    # from 1970) return the seconds since 1601. utctimetuple() applies the UTC
    # offset of an aware value; timetuple() would silently drop it.
    net_ts = calendar.timegm((log_time + epoch_delta).utctimetuple())

    # Seconds -> 100 ns units, plus the sub-second part (1 microsecond = 10 units)
    filetime_utc_ts = net_ts * (10**7) + log_time.microsecond * 10

    return filetime_utc_ts
690
691
def _get_managed_dropped_db_resource_id(
        cli_ctx,
        resource_group_name,
        managed_instance_name,
        database_name,
        deleted_time):
    '''
    Builds the resource id of a restorable dropped database of a managed
    instance in the current subscription.

    The last segment of the id is the url-quoted database name and the
    deletion time converted by _to_filetimeutc, joined by a comma.
    '''

    # url parse package has different names in Python 2 and 3. 'six' package works cross-version.
    from six.moves.urllib.parse import quote  # pylint: disable=import-error
    from msrestazure.tools import resource_id
    from azure.cli.core.commands.client_factory import get_subscription_id

    subscription_id = get_subscription_id(cli_ctx)
    dropped_db_name = '{},{}'.format(quote(database_name), _to_filetimeutc(deleted_time))

    return resource_id(
        subscription=subscription_id,
        resource_group=resource_group_name,
        namespace='Microsoft.Sql',
        type='managedInstances',
        name=managed_instance_name,
        child_type_1='restorableDroppedDatabases',
        child_name_1=dropped_db_name)
716
717
def db_show_conn_str(
        cmd,
        client_provider,
        database_name='<databasename>',
        server_name='<servername>',
        auth_type=ClientAuthenticationType.sql_password.value):
    '''
    Builds a SQL connection string for a specified client provider.

    cmd: command context; cmd.cli_ctx is used to determine the server DNS suffix.
    client_provider: one of the ClientType values.
    database_name: database name placed into the connection string.
    server_name: server name (without DNS suffix) placed into the connection string.
    auth_type: one of the ClientAuthenticationType values.

    Returns the connection string with user name and password left as
    <username> / <password> placeholders.

    Raises CLIError if the client provider does not support the requested
    authentication type.
    '''

    server_suffix = _get_server_dns_suffx(cmd.cli_ctx)

    conn_str_props = {
        'server': server_name,
        'server_fqdn': '{}{}'.format(server_name, server_suffix),
        'server_suffix': server_suffix,
        'db': database_name
    }

    # Connection string template per client provider and authentication type.
    # An unsupported combination maps to the error that is raised for it.
    formats = {
        ClientType.ado_net.value: {
            ClientAuthenticationType.sql_password.value:
                'Server=tcp:{server_fqdn},1433;Database={db};User ID=<username>;'
                'Password=<password>;Encrypt=true;Connection Timeout=30;',
            ClientAuthenticationType.active_directory_password.value:
                'Server=tcp:{server_fqdn},1433;Database={db};User ID=<username>;'
                'Password=<password>;Encrypt=true;Connection Timeout=30;'
                'Authentication="Active Directory Password"',
            ClientAuthenticationType.active_directory_integrated.value:
                'Server=tcp:{server_fqdn},1433;Database={db};Encrypt=true;'
                'Connection Timeout=30;Authentication="Active Directory Integrated"'
        },
        ClientType.sqlcmd.value: {
            ClientAuthenticationType.sql_password.value:
                'sqlcmd -S tcp:{server_fqdn},1433 -d {db} -U <username> -P <password> -N -l 30',
            ClientAuthenticationType.active_directory_password.value:
                'sqlcmd -S tcp:{server_fqdn},1433 -d {db} -U <username> -P <password> -G -N -l 30',
            ClientAuthenticationType.active_directory_integrated.value:
                'sqlcmd -S tcp:{server_fqdn},1433 -d {db} -G -N -l 30',
        },
        ClientType.jdbc.value: {
            ClientAuthenticationType.sql_password.value:
                'jdbc:sqlserver://{server_fqdn}:1433;database={db};user=<username>@{server};'
                'password=<password>;encrypt=true;trustServerCertificate=false;'
                'hostNameInCertificate=*{server_suffix};loginTimeout=30',
            ClientAuthenticationType.active_directory_password.value:
                'jdbc:sqlserver://{server_fqdn}:1433;database={db};user=<username>;'
                'password=<password>;encrypt=true;trustServerCertificate=false;'
                'hostNameInCertificate=*{server_suffix};loginTimeout=30;'
                'authentication=ActiveDirectoryPassword',
            ClientAuthenticationType.active_directory_integrated.value:
                'jdbc:sqlserver://{server_fqdn}:1433;database={db};'
                'encrypt=true;trustServerCertificate=false;'
                'hostNameInCertificate=*{server_suffix};loginTimeout=30;'
                'authentication=ActiveDirectoryIntegrated',
        },
        ClientType.php_pdo.value: {
            # pylint: disable=line-too-long
            ClientAuthenticationType.sql_password.value:
                '$conn = new PDO("sqlsrv:server = tcp:{server_fqdn},1433; Database = {db}; LoginTimeout = 30; Encrypt = 1; TrustServerCertificate = 0;", "<username>", "<password>");',
            ClientAuthenticationType.active_directory_password.value:
                CLIError('PHP Data Object (PDO) driver only supports SQL Password authentication.'),
            ClientAuthenticationType.active_directory_integrated.value:
                CLIError('PHP Data Object (PDO) driver only supports SQL Password authentication.'),
        },
        ClientType.php.value: {
            # The database name is emitted as a quoted PHP string, like the other
            # option values; unquoted it would not be valid PHP.
            # pylint: disable=line-too-long
            ClientAuthenticationType.sql_password.value:
                '$connectionOptions = array("UID"=>"<username>@{server}", "PWD"=>"<password>", "Database"=>"{db}", "LoginTimeout" => 30, "Encrypt" => 1, "TrustServerCertificate" => 0); $serverName = "tcp:{server_fqdn},1433"; $conn = sqlsrv_connect($serverName, $connectionOptions);',
            ClientAuthenticationType.active_directory_password.value:
                CLIError('PHP sqlsrv driver only supports SQL Password authentication.'),
            ClientAuthenticationType.active_directory_integrated.value:
                CLIError('PHP sqlsrv driver only supports SQL Password authentication.'),
        },
        ClientType.odbc.value: {
            # Doubled braces are literal braces in the formatted output
            ClientAuthenticationType.sql_password.value:
                'Driver={{ODBC Driver 13 for SQL Server}};Server=tcp:{server_fqdn},1433;'
                'Database={db};Uid=<username>@{server};Pwd=<password>;Encrypt=yes;'
                'TrustServerCertificate=no;',
            ClientAuthenticationType.active_directory_password.value:
                'Driver={{ODBC Driver 13 for SQL Server}};Server=tcp:{server_fqdn},1433;'
                'Database={db};Uid=<username>@{server};Pwd=<password>;Encrypt=yes;'
                'TrustServerCertificate=no;Authentication=ActiveDirectoryPassword',
            ClientAuthenticationType.active_directory_integrated.value:
                'Driver={{ODBC Driver 13 for SQL Server}};Server=tcp:{server_fqdn},1433;'
                'Database={db};Encrypt=yes;TrustServerCertificate=no;'
                'Authentication=ActiveDirectoryIntegrated',
        }
    }

    f = formats[client_provider][auth_type]

    if isinstance(f, Exception):
        # Error
        raise f

    # Success
    return f.format(**conn_str_props)
816
817
class DatabaseIdentity():  # pylint: disable=too-few-public-methods
    '''
    Bundles the coordinates of a database (database name, server name and
    resource group name) together with the CLI context, and can render
    them as a database resource id.
    '''

    def __init__(self, cli_ctx, database_name, server_name, resource_group_name):

        self.cli_ctx = cli_ctx
        self.resource_group_name = resource_group_name
        self.server_name = server_name
        self.database_name = database_name

    def id(self):
        '''
        Returns the database resource id string. The subscription id is looked
        up from the CLI context and every variable segment is url-quoted.
        '''
        # url parse package has different names in Python 2 and 3. 'six' package works cross-version.
        from six.moves.urllib.parse import quote  # pylint: disable=import-error
        from azure.cli.core.commands.client_factory import get_subscription_id

        id_template = '/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Sql/servers/{}/databases/{}'
        raw_segments = (
            get_subscription_id(self.cli_ctx),
            self.resource_group_name,
            self.server_name,
            self.database_name)

        return id_template.format(*[quote(segment) for segment in raw_segments])
841
842
def _find_db_sku_from_capabilities(cli_ctx, location, sku, allow_reset_family=False, compute_model=None):
    '''
    Resolves a possibly partial sku request (e.g. only tier and capacity)
    into a sku that names a concrete service objective, using the
    capabilities reported for the given location.

    Returns:
    - the sku unchanged, if it already has a name;
    - None, if no sku properties were requested (the server then picks a default);
    - otherwise a new Sku whose name is the matching service objective name.
    '''

    logger.debug('_find_db_sku_from_capabilities input: %s', sku)

    # An explicit sku name wins; there is nothing left to resolve.
    if sku.name:
        logger.debug('_find_db_sku_from_capabilities return sku as is')
        return sku

    # Nothing at all was requested, so drop the sku and let the service default it.
    if not _any_sku_values_specified(sku):
        logger.debug('_find_db_sku_from_capabilities return None')
        return None

    # Only some properties were given: walk the capabilities tree
    # (location -> default server version -> edition -> performance level)
    # to find the entry that matches them.
    default_version = _get_default_server_version(
        _get_location_capability(cli_ctx, location, CapabilityGroup.SUPPORTED_EDITIONS))

    matching_edition = _find_edition_capability(sku, default_version.supported_editions)

    matching_slo = _find_performance_level_capability(
        sku,
        matching_edition.supported_service_level_objectives,
        allow_reset_family=allow_reset_family,
        compute_model=compute_model)

    # Returning the capability's own sku object would be the natural choice, but
    # not every db create mode can resolve the slo from `capacity`. Carrying the
    # slo name in the sku name property works for all of them.
    resolved_sku = Sku(name=matching_slo.name)
    logger.debug('_find_db_sku_from_capabilities return: %s', resolved_sku)
    return resolved_sku
888
889
def _validate_elastic_pool_id(
        cli_ctx,
        elastic_pool_id,
        server_name,
        resource_group_name):
    '''
    Normalizes elastic_pool_id to either None/empty or a full resource id.

    A value that is empty, or that already is a valid resource id, is
    returned untouched. Any other value is treated as an elastic pool
    *name* and expanded into the resource id of that pool under the
    given server and resource group, in the current subscription.
    '''

    from msrestazure.tools import resource_id, is_valid_resource_id
    from azure.cli.core.commands.client_factory import get_subscription_id

    # Nothing to expand: no pool requested, or the caller already passed an id.
    if not elastic_pool_id or is_valid_resource_id(elastic_pool_id):
        return elastic_pool_id

    pool_name = elastic_pool_id
    return resource_id(
        subscription=get_subscription_id(cli_ctx),
        resource_group=resource_group_name,
        namespace='Microsoft.Sql',
        type='servers',
        name=server_name,
        child_type_1='elasticPools',
        child_name_1=pool_name)
920
921
def _db_dw_create(
        cli_ctx,
        client,
        source_db,
        dest_db,
        no_wait,
        sku=None,
        secondary_type=None,
        **kwargs):
    '''
    Creates a DB (with any create mode) or DW.
    Handles common concerns such as setting location and sku properties.

    source_db may be None; when given, its id() becomes the source database id.
    dest_db supplies the server, resource group and database name to create.
    kwargs must contain 'compute_model', 'elastic_pool_id' and
    'maintenance_configuration_id'; it is completed in place and sent as the
    request parameters.

    Raises CLIError when a serverless database is requested without
    edition, family and capacity.
    '''

    # This check needs to be here, because server side logic of
    # finding a default sku for Serverless is not yet implemented.
    if kwargs['compute_model'] == ComputeModelType.serverless:
        if not sku or not sku.tier or not sku.family or not sku.capacity:
            raise CLIError('When creating a serverless database, please pass in edition, '
                           'family, and capacity parameters through -e -f -c')

    # Determine server location
    kwargs['location'] = _get_server_location(
        cli_ctx,
        server_name=dest_db.server_name,
        resource_group_name=dest_db.resource_group_name)

    # Set create mode properties
    if source_db:
        kwargs['source_database_id'] = source_db.id()

    if secondary_type:
        kwargs['secondary_type'] = secondary_type

    # If sku.name is not specified, resolve the requested sku name
    # using capabilities.
    kwargs['sku'] = _find_db_sku_from_capabilities(
        cli_ctx,
        kwargs['location'],
        sku,
        compute_model=kwargs['compute_model'])

    # Validate elastic pool id (a bare pool name is expanded to a resource id)
    kwargs['elastic_pool_id'] = _validate_elastic_pool_id(
        cli_ctx,
        kwargs['elastic_pool_id'],
        dest_db.server_name,
        dest_db.resource_group_name)

    # Expand maintenance configuration id if needed
    kwargs['maintenance_configuration_id'] = _complete_maintenance_configuration_id(
        cli_ctx,
        kwargs['maintenance_configuration_id'])

    # Create
    return sdk_no_wait(no_wait, client.begin_create_or_update,
                       server_name=dest_db.server_name,
                       resource_group_name=dest_db.resource_group_name,
                       database_name=dest_db.database_name,
                       parameters=kwargs)
982
983
984def _should_show_backup_storage_redundancy_warnings(target_db_location):
985    if target_db_location.lower() in ['southeastasia', 'brazilsouth', 'eastasia']:
986        return True
987    return False
988
989
990def _backup_storage_redundancy_specify_geo_warning():
991    print("""Selected value for backup storage redundancy is geo-redundant storage.
992    Note that database backups will be geo-replicated to the paired region.
993    To learn more about Azure Paired Regions visit https://aka.ms/azure-ragrs-regions.""")
994
995
def _confirm_backup_storage_redundancy_take_geo_warning():
    '''
    Asks the user to confirm proceeding without an explicit backup storage
    redundancy value (which defaults to geo-redundant storage).

    Returns True if the user confirmed, False otherwise.
    '''
    prompt_text = """You have not specified the value for backup storage redundancy
    which will default to geo-redundant storage. Note that database backups will be geo-replicated
    to the paired region. To learn more about Azure Paired Regions visit https://aka.ms/azure-ragrs-regions.
    Do you want to proceed?"""
    return bool(prompt_y_n(prompt_text))
1005
1006
1007def _backup_storage_redundancy_take_source_warning():
1008    print("""You have not specified the value for backup storage redundancy
1009    which will default to source's backup storage redundancy.
1010    To learn more about Azure Paired Regions visit https://aka.ms/azure-ragrs-regions.""")
1011
1012
def db_create(
        cmd,
        client,
        database_name,
        server_name,
        resource_group_name,
        no_wait=False,
        yes=None,
        **kwargs):
    '''
    Creates a DB (with 'Default' create mode.)

    Unless `yes` is set, servers in the regions that show backup storage
    redundancy warnings prompt for confirmation when no redundancy was
    requested, and print a notice when 'Geo' was requested.
    Returns None without creating anything if the user declines the prompt.
    '''

    # The backup storage redundancy checks depend on where the server lives
    server_location = _get_server_location(
        cmd.cli_ctx,
        server_name=server_name,
        resource_group_name=resource_group_name)

    if not yes and _should_show_backup_storage_redundancy_warnings(server_location):
        requested_redundancy = kwargs['requested_backup_storage_redundancy']
        if not requested_redundancy:
            # Defaulting to geo-redundant storage needs explicit consent
            if not _confirm_backup_storage_redundancy_take_geo_warning():
                return None
        elif requested_redundancy == 'Geo':
            _backup_storage_redundancy_specify_geo_warning()

    # 'Default' create mode has no source database
    new_db = DatabaseIdentity(cmd.cli_ctx, database_name, server_name, resource_group_name)
    return _db_dw_create(
        cmd.cli_ctx,
        client,
        None,
        new_db,
        no_wait,
        **kwargs)
1046
1047
def _use_source_db_tier(
        client,
        database_name,
        server_name,
        resource_group_name,
        kwargs):
    '''
    Overwrites the tier of kwargs['sku'] with the sku tier of the specified
    (source) db. The source db is only fetched, and the tier only copied,
    when some sku value was specified in kwargs['sku'].
    '''

    requested_sku = kwargs['sku']
    if not _any_sku_values_specified(requested_sku):
        return

    source_db = client.get(resource_group_name, server_name, database_name)
    requested_sku.tier = source_db.sku.tier
1061
1062
def db_copy(
        cmd,
        client,
        database_name,
        server_name,
        resource_group_name,
        dest_name,
        dest_server_name=None,
        dest_resource_group_name=None,
        no_wait=False,
        **kwargs):
    '''
    Copies a DB (i.e. create with 'Copy' create mode.)

    The destination server and resource group default to those of the source.
    '''

    # Fall back to the source's server / resource group when no destination is given
    target_server = dest_server_name or server_name
    target_group = dest_resource_group_name or resource_group_name

    kwargs['create_mode'] = 'Copy'

    # Sku properties may come from the command line, but the tier has to match
    # the source db, so it is taken from the source instead.
    _use_source_db_tier(client, database_name, server_name, resource_group_name, kwargs)

    # Backup storage redundancy notices depend on the destination server's location
    target_location = _get_server_location(
        cmd.cli_ctx,
        server_name=target_server,
        resource_group_name=target_group)
    if _should_show_backup_storage_redundancy_warnings(target_location):
        requested_redundancy = kwargs['requested_backup_storage_redundancy']
        if not requested_redundancy:
            _backup_storage_redundancy_take_source_warning()
        elif requested_redundancy == 'Geo':
            _backup_storage_redundancy_specify_geo_warning()

    source_identity = DatabaseIdentity(cmd.cli_ctx, database_name, server_name, resource_group_name)
    target_identity = DatabaseIdentity(cmd.cli_ctx, dest_name, target_server, target_group)

    return _db_dw_create(
        cmd.cli_ctx,
        client,
        source_identity,
        target_identity,
        no_wait,
        **kwargs)
1112
1113
def db_create_replica(
        cmd,
        client,
        database_name,
        server_name,
        resource_group_name,
        partner_server_name,
        partner_database_name=None,
        partner_resource_group_name=None,
        secondary_type=None,
        no_wait=False,
        **kwargs):
    '''
    Creates a secondary replica DB (i.e. create with 'Secondary' create mode.)

    Custom function makes create mode more convenient. The partner database
    name and resource group default to those of the source database.
    '''

    # Fall back to the source's values for anything not given for the partner
    replica_group = partner_resource_group_name or resource_group_name
    replica_name = partner_database_name or database_name

    kwargs['create_mode'] = CreateMode.SECONDARY

    # Sku properties may come from the command line, but the tier has to match
    # the source db, so it is taken from the source instead.
    _use_source_db_tier(client, database_name, server_name, resource_group_name, kwargs)

    # Backup storage redundancy notices depend on the partner server's location
    replica_location = _get_server_location(
        cmd.cli_ctx,
        server_name=partner_server_name,
        resource_group_name=replica_group)
    if _should_show_backup_storage_redundancy_warnings(replica_location):
        requested_redundancy = kwargs['requested_backup_storage_redundancy']
        if not requested_redundancy:
            _backup_storage_redundancy_take_source_warning()
        elif requested_redundancy == 'Geo':
            _backup_storage_redundancy_specify_geo_warning()

    source_identity = DatabaseIdentity(cmd.cli_ctx, database_name, server_name, resource_group_name)
    replica_identity = DatabaseIdentity(cmd.cli_ctx, replica_name, partner_server_name, replica_group)

    return _db_dw_create(
        cmd.cli_ctx,
        client,
        source_identity,
        replica_identity,
        no_wait,
        secondary_type=secondary_type,
        **kwargs)
1167
1168
1169# Renames a database.
def db_rename(
        cmd,
        client,
        database_name,
        server_name,
        resource_group_name,
        new_name,
        **kwargs):
    '''
    Renames a DB and returns the database fetched under its new name.
    '''
    # The rename request carries the resource id the database should end up with
    renamed_identity = DatabaseIdentity(cmd.cli_ctx, new_name, server_name, resource_group_name)
    kwargs['id'] = renamed_identity.id()

    client.rename(resource_group_name, server_name, database_name, parameters=kwargs)

    return client.get(resource_group_name, server_name, new_name)
1198
1199
def db_restore(
        cmd,
        client,
        database_name,
        server_name,
        resource_group_name,
        dest_name,
        restore_point_in_time=None,
        source_database_deletion_date=None,
        no_wait=False,
        **kwargs):
    '''
    Restores an existing or deleted DB (i.e. create with 'Restore'
    or 'PointInTimeRestore' create mode.)

    Custom function makes create mode more convenient: 'Restore' is used when
    a deletion date is given, 'PointInTimeRestore' otherwise.

    Raises CLIError if neither a restore point nor a deletion date is given.
    '''

    if not restore_point_in_time and not source_database_deletion_date:
        raise CLIError('Either --time or --deleted-time must be specified.')

    # Set create mode properties
    kwargs['restore_point_in_time'] = restore_point_in_time
    kwargs['source_database_deletion_date'] = source_database_deletion_date
    if source_database_deletion_date is not None:
        # A deletion date means the source database has been deleted
        kwargs['create_mode'] = CreateMode.RESTORE
    else:
        kwargs['create_mode'] = CreateMode.POINT_IN_TIME_RESTORE

    # Check backup storage redundancy configurations
    server_location = _get_server_location(
        cmd.cli_ctx,
        server_name=server_name,
        resource_group_name=resource_group_name)
    if _should_show_backup_storage_redundancy_warnings(server_location):
        requested_redundancy = kwargs['requested_backup_storage_redundancy']
        if not requested_redundancy:
            _backup_storage_redundancy_take_source_warning()
        elif requested_redundancy == 'Geo':
            _backup_storage_redundancy_specify_geo_warning()

    source_identity = DatabaseIdentity(cmd.cli_ctx, database_name, server_name, resource_group_name)
    # Cross-server restore is not supported. So dest server/group must be the same as source.
    restored_identity = DatabaseIdentity(cmd.cli_ctx, dest_name, server_name, resource_group_name)

    return _db_dw_create(
        cmd.cli_ctx,
        client,
        source_identity,
        restored_identity,
        no_wait,
        **kwargs)
1244
1245
1246# pylint: disable=inconsistent-return-statements
def db_failover(
        client,
        database_name,
        server_name,
        resource_group_name,
        allow_data_loss=False):
    '''
    Fails over a database by setting the specified database as the new primary.

    Wrapper function which looks up the replication link itself so that the
    user doesn't need to specify replication link id.

    Returns None without doing anything if the database has no link to a
    primary. Raises CLIError if the database has no replication links at all.
    '''

    replication_links = list(client.list_by_database(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name))

    if not replication_links:
        raise CLIError('The specified database has no replication links.')

    # A primary has 1 or more links (to its secondaries), while a secondary
    # has exactly 1 link (to its primary). So look for a link whose partner
    # is a primary.
    link_to_primary = None
    for candidate in replication_links:
        if candidate.partner_role == FailoverGroupReplicationRole.PRIMARY:
            link_to_primary = candidate
            break

    if not link_to_primary:
        # No link to a primary, so this must already be a primary. Do nothing.
        return None

    # Choose which failover method to use
    failover_func = client.begin_failover_allow_data_loss if allow_data_loss else client.begin_failover

    # Execute failover from the primary to this database
    return failover_func(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        link_id=link_to_primary.name)
1288
1289
class DatabaseCapabilitiesAdditionalDetails(Enum):  # pylint: disable=too-few-public-methods
    '''
    Additional details that may be optionally included when getting db capabilities.

    Each member's value is the string that `db_list_capabilities` looks for
    in its `show_details` list.
    '''

    # When requested, supported max sizes are kept in the output instead of being blanked.
    max_size = 'max-size'
1296
1297
def db_list_capabilities(
        client,
        location,
        edition=None,
        service_objective=None,
        dtu=None,
        vcores=None,
        show_details=None,
        available=False):
    '''
    Gets database capabilities and optionally applies the specified filters.

    Returns the editions of the default server version for the location,
    narrowed by edition name, service objective name, dtu, vcores and
    availability as requested. Editions left without any service objective
    are dropped. Supported max sizes are blanked unless 'max-size' is
    listed in show_details. The capability objects are modified in place.
    '''

    def _keep_matching_slos(candidates, is_match):
        # Narrows each edition's service objective list to the matching entries.
        for candidate in candidates:
            candidate.supported_service_level_objectives = [
                slo for slo in candidate.supported_service_level_objectives
                if is_match(slo)]

    requested_details = show_details if show_details else []

    # Get capabilities tree from server and pick the subtree related to databases
    location_capabilities = client.list_by_location(location, CapabilityGroup.SUPPORTED_EDITIONS)
    editions = _get_default_server_version(location_capabilities).supported_editions

    # Filter by edition
    if edition:
        editions = [e for e in editions if e.name.lower() == edition.lower()]

    # Filter by service objective
    if service_objective:
        _keep_matching_slos(
            editions,
            lambda slo: slo.name.lower() == service_objective.lower())

    # Filter by dtu
    if dtu:
        _keep_matching_slos(
            editions,
            lambda slo: (slo.performance_level.value == int(dtu) and
                         slo.performance_level.unit == PerformanceLevelUnit.DTU))

    # Filter by vcores
    if vcores:
        _keep_matching_slos(
            editions,
            lambda slo: (slo.performance_level.value == int(vcores) and
                         slo.performance_level.unit == PerformanceLevelUnit.V_CORES))

    # Filter by availability, at every level of the tree
    if available:
        editions = _filter_available(editions)
        for e in editions:
            available_slos = _filter_available(e.supported_service_level_objectives)
            e.supported_service_level_objectives = available_slos
            for slo in available_slos:
                if slo.supported_max_sizes:
                    slo.supported_max_sizes = _filter_available(slo.supported_max_sizes)

    # Remove editions with no service objectives (due to filters)
    editions = [e for e in editions if e.supported_service_level_objectives]

    # Optionally hide supported max sizes
    if DatabaseCapabilitiesAdditionalDetails.max_size.value not in requested_details:
        for e in editions:
            for slo in e.supported_service_level_objectives:
                if slo.supported_max_sizes:
                    slo.supported_max_sizes = []

    return editions
1368
1369
1370# pylint: disable=inconsistent-return-statements
def db_delete_replica_link(
        client,
        database_name,
        server_name,
        resource_group_name,
        # Partner dbs must have the same name as one another
        partner_server_name,
        partner_resource_group_name=None,
        # Base command code handles confirmation, but it passes '--yes' parameter to us if
        # provided. We don't care about this parameter and it gets handled weirdly if we
        # explicitly specify it with default value here (e.g. `yes=None` or `yes=True`), receiving
        # it in kwargs seems to work.
        **kwargs):  # pylint: disable=unused-argument
    '''
    Deletes the replication link between the database and its partner on
    the given partner server.

    Returns None without deleting anything when no such link exists.
    '''

    # Determine optional values
    partner_resource_group_name = partner_resource_group_name or resource_group_name

    replication_links = list(client.list_by_database(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name))

    # A link does not expose the partner's resource group name, so the partner
    # server name alone has to identify the link.
    matching_link = None
    for candidate in replication_links:
        if candidate.partner_server == partner_server_name:
            matching_link = candidate
            break

    if not matching_link:
        # No link exists, nothing to be done
        return None

    return client.delete(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        link_id=matching_link.name)
1409
1410
def db_export(
        client,
        database_name,
        server_name,
        resource_group_name,
        storage_key_type,
        storage_key,
        **kwargs):
    '''
    Exports a database to a bacpac file.

    The storage key is passed through `_pad_sas_key` before being sent.
    '''

    padded_key = _pad_sas_key(storage_key_type, storage_key)

    # The storage key type and key travel as part of the request parameters
    kwargs.update(
        storage_key_type=storage_key_type,
        storage_key=padded_key)

    return client.begin_export(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        parameters=kwargs)
1433
1434
def db_import(
        client,
        database_name,
        server_name,
        resource_group_name,
        storage_key_type,
        storage_key,
        **kwargs):
    '''
    Imports a bacpac file into an existing database.

    The storage key is passed through `_pad_sas_key` before being sent.
    '''

    padded_key = _pad_sas_key(storage_key_type, storage_key)

    # The storage key type and key travel as part of the request parameters
    kwargs.update(
        storage_key_type=storage_key_type,
        storage_key=padded_key)

    return client.begin_import_method(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        parameters=kwargs)
1457
1458
def _pad_sas_key(
        storage_key_type,
        storage_key):
    '''
    Import/Export API requires that "?" precede SAS key as an argument.
    Adds ? prefix if it wasn't included.

    Keys of any other storage key type are returned unchanged. An empty or
    missing key is also returned unchanged rather than raising, so that it
    is rejected by validation further along instead of failing here.
    '''

    is_sas_key = (
        storage_key_type.lower() ==
        StorageKeyType.SHARED_ACCESS_KEY.value.lower())  # pylint: disable=no-member

    # `startswith` (unlike indexing `[0]`) is safe on an empty string.
    if is_sas_key and storage_key and not storage_key.startswith('?'):
        return '?' + storage_key
    return storage_key
1471
1472
def db_list(
        client,
        server_name,
        resource_group_name,
        elastic_pool_name=None):
    '''
    Lists databases in a server, or only those in one of its elastic pools
    when an elastic pool name is given.
    '''

    if not elastic_pool_name:
        # No pool requested: list all databases in the server
        return client.list_by_server(
            resource_group_name=resource_group_name,
            server_name=server_name)

    # List all databases in the elastic pool
    return client.list_by_elastic_pool(
        server_name=server_name,
        resource_group_name=resource_group_name,
        elastic_pool_name=elastic_pool_name)
1491
1492
def db_update(
        cmd,
        instance,
        server_name,
        resource_group_name,
        elastic_pool_id=None,
        max_size_bytes=None,
        service_objective=None,
        zone_redundant=None,
        tier=None,
        family=None,
        capacity=None,
        read_scale=None,
        high_availability_replica_count=None,
        min_capacity=None,
        auto_pause_delay=None,
        compute_model=None,
        requested_backup_storage_redundancy=None,
        maintenance_configuration_id=None):
    '''
    Applies requested parameters to a db resource instance for a DB update.

    `instance` is the existing database resource; it is modified in place
    and returned.

    Notes on how parameters are applied:
    - elastic_pool_id, requested_backup_storage_redundancy and
      maintenance_configuration_id are always assigned, even when None.
    - zone_redundant, read_scale and high_availability_replica_count are
      assigned only when they are not None.
    - max_size_bytes, min_capacity and auto_pause_delay are assigned only
      when truthy.
    - compute_model, when not given, is derived from the instance's current
      sku name.

    Raises CLIError if the instance is a data warehouse, or if an elastic
    pool is combined with a service objective other than the elastic pool one.
    '''

    # Verify edition
    if instance.sku.tier.lower() == DatabaseEdition.data_warehouse.value.lower():  # pylint: disable=no-member
        raise CLIError('Azure SQL Data Warehouse can be updated with the command'
                       ' `az sql dw update`.')

    # Check backup storage redundancy configuration
    location = _get_server_location(cmd.cli_ctx, server_name=server_name, resource_group_name=resource_group_name)
    if _should_show_backup_storage_redundancy_warnings(location):
        if requested_backup_storage_redundancy == 'Geo':
            _backup_storage_redundancy_specify_geo_warning()

    #####
    # Set sku-related properties
    #####

    # Verify that elastic_pool_id and service_objective param values are not
    # totally inconsistent. If elastic pool and service objective name are both specified, and
    # they are inconsistent (i.e. service objective is not 'ElasticPool'), then the service
    # actually ignores the value of service objective name (!!). We are trying to protect the CLI
    # user from this unintuitive behavior.
    if (elastic_pool_id and service_objective and
            service_objective != ServiceObjectiveName.ELASTIC_POOL):
        raise CLIError('If elastic pool is specified, service objective must be'
                       ' unspecified or equal \'{}\'.'.format(
                           ServiceObjectiveName.ELASTIC_POOL))

    # Update both elastic pool and sku. The service treats elastic pool and sku properties like PATCH,
    # so if either of these properties is null then the service will keep the property unchanged -
    # except if pool is null/empty and service objective is a standalone SLO value (e.g. 'S0',
    # 'S1', etc), in which case the pool being null/empty is meaningful - it means remove from
    # pool.

    # Validate elastic pool id (a bare pool name is expanded into a resource id)
    instance.elastic_pool_id = _validate_elastic_pool_id(
        cmd.cli_ctx,
        elastic_pool_id,
        server_name,
        resource_group_name)

    # Find out the requested compute_model; when not given, infer it from the
    # instance's current sku name
    if not compute_model:
        compute_model = (
            ComputeModelType.serverless if _is_serverless_slo(instance.sku.name)
            else ComputeModelType.provisioned)

    # Update sku
    _db_elastic_pool_update_sku(
        cmd,
        instance,
        service_objective,
        tier,
        family,
        capacity,
        find_sku_from_capabilities_func=_find_db_sku_from_capabilities,
        compute_model=compute_model)

    # TODO Temporary workaround for elastic pool sku name issue
    if instance.elastic_pool_id:
        instance.sku = None

    #####
    # Set other (non-sku related) properties
    #####

    if max_size_bytes:
        instance.max_size_bytes = max_size_bytes

    if zone_redundant is not None:
        instance.zone_redundant = zone_redundant

    if read_scale is not None:
        instance.read_scale = read_scale

    if high_availability_replica_count is not None:
        instance.high_availability_replica_count = high_availability_replica_count

    # Set requested_backup_storage_redundancy even if it is None.
    # Otherwise, empty value defaults to the current backup storage redundancy
    # and will potentially conflict with a previously requested update
    instance.requested_backup_storage_redundancy = requested_backup_storage_redundancy

    # Always assigned, so a None id is still passed through the completion helper
    instance.maintenance_configuration_id = _complete_maintenance_configuration_id(
        cmd.cli_ctx,
        maintenance_configuration_id)

    #####
    # Set other (serverless related) properties
    #####
    if min_capacity:
        instance.min_capacity = min_capacity

    if auto_pause_delay:
        instance.auto_pause_delay = auto_pause_delay

    return instance
1611
1612
1613#####
1614#           sql db audit-policy & threat-policy
1615#####
1616
1617
def _find_storage_account_resource_group(cli_ctx, name):
    '''
    Looks up the resource group of the storage account with the given name
    by listing ARM resources filtered on name and storage resource types.

    The resource group is needed to later query the storage API for the
    account's keys and endpoint. It is deliberately not a command line
    parameter: otherwise the customer would have to specify the storage
    resource group just to update some unrelated property, which is
    annoying and makes no sense to the customer.

    Raises CLIError when no account matches, when several match, or when the
    single match is a classic storage account.
    '''

    arm_storage_type = 'Microsoft.Storage/storageAccounts'
    classic_storage_type = 'Microsoft.ClassicStorage/storageAccounts'

    resource_filter = "name eq '{}' and (resourceType eq '{}' or resourceType eq '{}')".format(
        name, arm_storage_type, classic_storage_type)

    resource_client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES)
    matches = list(resource_client.resources.list(filter=resource_filter))

    if not matches:
        raise CLIError("No storage account with name '{}' was found.".format(name))

    if len(matches) > 1:
        raise CLIError("Multiple storage accounts with name '{}' were found.".format(name))

    account = matches[0]
    if account.type == classic_storage_type:
        raise CLIError("The storage account with name '{}' is a classic storage account which is"
                       " not supported by this command. Use a non-classic storage account or"
                       " specify storage endpoint and key instead.".format(name))

    # Assumes the id looks like '/subscriptions/<sub>/resourceGroups/<group>/...',
    # which puts the resource group at index 4 after splitting on '/'.
    id_segments = account.id.split('/')
    return id_segments[4]
1651
1652
1653def _get_storage_account_name(storage_endpoint):
1654    '''
1655    Determines storage account name from endpoint url string.
1656    e.g. 'https://mystorage.blob.core.windows.net' -> 'mystorage'
1657    '''
1658    # url parse package has different names in Python 2 and 3. 'six' package works cross-version.
1659    from six.moves.urllib.parse import urlparse  # pylint: disable=import-error
1660
1661    return urlparse(storage_endpoint).netloc.split('.')[0]
1662
1663
def _get_storage_endpoint(
        cli_ctx,
        storage_account,
        resource_group_name):
    '''
    Gets storage account blob endpoint by querying storage ARM API.

    Returns the blob endpoint reported in the account's primary endpoints.

    Raises CLIError when the account has no primary endpoints at all or when its blob endpoint
    is missing or empty, so that callers never receive None as an endpoint.
    '''
    from azure.mgmt.storage import StorageManagementClient

    # Get storage account
    client = get_mgmt_service_client(cli_ctx, StorageManagementClient)
    account = client.storage_accounts.get_properties(
        resource_group_name=resource_group_name,
        account_name=storage_account)

    # Get endpoint
    # pylint: disable=no-member
    # getattr with a default covers both 'primary_endpoints is None' and a missing attribute;
    # the truthiness check below additionally rejects a blob endpoint that is None or empty.
    blob_endpoint = getattr(account.primary_endpoints, 'blob', None)

    if not blob_endpoint:
        raise CLIError("The storage account with name '{}' (id '{}') has no blob endpoint. Use a"
                       " different storage account.".format(account.name, account.id))

    return blob_endpoint
1687
1688
def _get_storage_key(
        cli_ctx,
        storage_account,
        resource_group_name,
        use_secondary_key):
    '''
    Retrieves one of the storage account's access keys through the storage ARM API.

    The second listed key is returned when use_secondary_key is truthy, otherwise the first.
    '''
    from azure.mgmt.storage import StorageManagementClient

    # Ask the storage resource provider for the account's keys
    storage_client = get_mgmt_service_client(cli_ctx, StorageManagementClient)
    key_listing = storage_client.storage_accounts.list_keys(
        resource_group_name=resource_group_name,
        account_name=storage_account)

    # Pick the requested key
    # pylint: disable=no-member
    if use_secondary_key:
        selected_key = key_listing.keys[1]
    else:
        selected_key = key_listing.keys[0]

    return selected_key.value
1708
1709
1710def _db_security_policy_update(
1711        cli_ctx,
1712        instance,
1713        enabled,
1714        storage_account,
1715        storage_endpoint,
1716        storage_account_access_key,
1717        use_secondary_key):
1718    '''
1719    Common code for updating audit and threat detection policy.
1720    '''
1721
1722    # Validate storage endpoint arguments
1723    if storage_endpoint and storage_account:
1724        raise CLIError('--storage-endpoint and --storage-account cannot both be specified.')
1725
1726    # Set storage endpoint
1727    if storage_endpoint:
1728        instance.storage_endpoint = storage_endpoint
1729    if storage_account:
1730        storage_resource_group = _find_storage_account_resource_group(cli_ctx, storage_account)
1731        instance.storage_endpoint = _get_storage_endpoint(cli_ctx, storage_account, storage_resource_group)
1732
1733    # Set storage access key
1734    if storage_account_access_key:
1735        # Access key is specified
1736        instance.storage_account_access_key = storage_account_access_key
1737    elif enabled:
1738        # Access key is not specified, but state is Enabled.
1739        # If state is Enabled, then access key property is required in PUT. However access key is
1740        # readonly (GET returns empty string for access key), so we need to determine the value
1741        # and then PUT it back. (We don't want the user to be force to specify this, because that
1742        # would be very annoying when updating non-storage-related properties).
1743        # This doesn't work if the user used generic update args, i.e. `--set state=Enabled`
1744        # instead of `--state Enabled`, since the generic update args are applied after this custom
1745        # function, but at least we tried.
1746        if not storage_account:
1747            storage_account = _get_storage_account_name(instance.storage_endpoint)
1748            storage_resource_group = _find_storage_account_resource_group(cli_ctx, storage_account)
1749
1750        instance.storage_account_access_key = _get_storage_key(
1751            cli_ctx,
1752            storage_account,
1753            storage_resource_group,
1754            use_secondary_key)
1755
1756
1757def _check_audit_policy_state(
1758        state,
1759        value):
1760    return state is not None and state.lower() == value.lower()
1761
1762
def _is_audit_policy_state_enabled(state):
    '''
    Tells whether state matches BlobAuditingPolicyState.ENABLED, ignoring case.
    '''
    expected_state = BlobAuditingPolicyState.ENABLED
    return _check_audit_policy_state(state, expected_state)
1765
1766
def _is_audit_policy_state_disabled(state):
    '''
    Tells whether state matches BlobAuditingPolicyState.DISABLED, ignoring case.
    '''
    expected_state = BlobAuditingPolicyState.DISABLED
    return _check_audit_policy_state(state, expected_state)
1769
1770
def _is_audit_policy_state_none_or_disabled(state):
    '''
    Tells whether state is None or matches BlobAuditingPolicyState.DISABLED, ignoring case.
    '''
    if state is None:
        return True

    return _check_audit_policy_state(state, BlobAuditingPolicyState.DISABLED)
1773
1774
def _get_diagnostic_settings_url(
        cmd,
        resource_group_name,
        server_name,
        database_name=None):
    '''
    Builds the resource id of the database whose diagnostic settings are addressed.

    When database_name is None the id refers to the "master" database of the server.
    '''

    from azure.cli.core.commands.client_factory import get_subscription_id

    url_template = '/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Sql/servers/{}/databases/{}'
    subscription_id = get_subscription_id(cmd.cli_ctx)
    target_database = "master" if database_name is None else database_name

    return url_template.format(subscription_id, resource_group_name, server_name, target_database)
1787
1788
def _get_diagnostic_settings(
        cmd,
        resource_group_name,
        server_name,
        database_name=None):
    '''
    Shared helper that lists the diagnostic settings of a server or database
    '''

    resource_uri = _get_diagnostic_settings_url(
        cmd=cmd,
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name)

    settings_operations = cf_monitor(cmd.cli_ctx).diagnostic_settings
    return settings_operations.list(resource_uri)
1804
1805
1806def _fetch_first_audit_diagnostic_setting(diagnostic_settings, category_name):
1807    return next((ds for ds in diagnostic_settings if hasattr(ds, 'logs') and
1808                 next((log for log in ds.logs if log.enabled and
1809                       log.category == category_name), None) is not None), None)
1810
1811
1812def _fetch_all_audit_diagnostic_settings(diagnostic_settings, category_name):
1813    return [ds for ds in diagnostic_settings if hasattr(ds, 'logs') and
1814            next((log for log in ds.logs if log.enabled and
1815                  log.category == category_name), None) is not None]
1816
1817
def server_ms_support_audit_policy_get(
        client,
        server_name,
        resource_group_name):
    '''
    Fetches the server's Microsoft support operations audit policy.

    The settings name passed to the client is always 'default'.
    '''

    request_arguments = {
        'resource_group_name': resource_group_name,
        'server_name': server_name,
        'dev_ops_auditing_settings_name': 'default',
    }

    return client.get(**request_arguments)
1830
1831
def server_ms_support_audit_policy_set(
        client,
        server_name,
        resource_group_name,
        parameters):
    '''
    Stores the server's Microsoft support operations audit policy.

    The settings name passed to the client is always 'default'; the result of the client's
    begin_create_or_update call is returned as is.
    '''

    request_arguments = {
        'resource_group_name': resource_group_name,
        'server_name': server_name,
        'dev_ops_auditing_settings_name': 'default',
        'parameters': parameters,
    }

    return client.begin_create_or_update(**request_arguments)
1846
1847
def _audit_policy_show(
        cmd,
        client,
        resource_group_name,
        server_name,
        database_name=None,
        category_name=None):
    '''
    Common code to get server (DevOps) or database audit policy including diagnostic settings

    The policy is fetched through `client` and extended with three attributes:
    blob_storage_target_state, event_hub_target_state and log_analytics_target_state.
    All three start out as disabled. When the policy itself is not disabled:
      - blob storage is reported as enabled if the policy has a storage endpoint;
      - if the policy's is_azure_monitor_target_enabled is set, the diagnostic settings are
        requested and the first one (by name) with an enabled `category_name` log determines
        the log analytics and event hub states; its workspace id, event hub authorization
        rule id and event hub name are copied onto the policy.

    Returns the extended audit policy object.
    '''

    # Request audit policy
    if database_name is None:
        if category_name == 'DevOpsOperationsAudit':
            audit_policy = server_ms_support_audit_policy_get(
                client=client,
                resource_group_name=resource_group_name,
                server_name=server_name)
        else:
            audit_policy = client.get(
                resource_group_name=resource_group_name,
                server_name=server_name)
    else:
        audit_policy = client.get(
            resource_group_name=resource_group_name,
            server_name=server_name,
            database_name=database_name)

    # Every target is considered disabled until proven otherwise
    audit_policy.blob_storage_target_state = BlobAuditingPolicyState.DISABLED
    audit_policy.event_hub_target_state = BlobAuditingPolicyState.DISABLED
    audit_policy.log_analytics_target_state = BlobAuditingPolicyState.DISABLED

    # If audit policy's state is disabled there is nothing to do
    if _is_audit_policy_state_disabled(audit_policy.state):
        return audit_policy

    # A non-empty storage endpoint means the blob storage target is in use
    if audit_policy.storage_endpoint:
        audit_policy.blob_storage_target_state = BlobAuditingPolicyState.ENABLED

    # If 'is_azure_monitor_target_enabled' is false there is no reason to request diagnostic settings
    if not audit_policy.is_azure_monitor_target_enabled:
        return audit_policy

    # Request diagnostic settings
    diagnostic_settings = _get_diagnostic_settings(
        cmd=cmd, resource_group_name=resource_group_name,
        server_name=server_name, database_name=database_name)

    # Sort received diagnostic settings by name and get first element to ensure consistency between command executions
    diagnostic_settings.value.sort(key=lambda d: d.name)
    audit_diagnostic_setting = _fetch_first_audit_diagnostic_setting(diagnostic_settings.value, category_name)

    # Initialize azure monitor properties
    if audit_diagnostic_setting is not None:
        if audit_diagnostic_setting.workspace_id is not None:
            audit_policy.log_analytics_target_state = BlobAuditingPolicyState.ENABLED
            audit_policy.log_analytics_workspace_resource_id = audit_diagnostic_setting.workspace_id

        if audit_diagnostic_setting.event_hub_authorization_rule_id is not None:
            # Use the same upper-case member name as everywhere else in this module rather than
            # relying on case-insensitive enum member lookup.
            audit_policy.event_hub_target_state = BlobAuditingPolicyState.ENABLED
            audit_policy.event_hub_authorization_rule_id = audit_diagnostic_setting.event_hub_authorization_rule_id
            audit_policy.event_hub_name = audit_diagnostic_setting.event_hub_name

    return audit_policy
1914
1915
def server_audit_policy_show(
        cmd,
        client,
        server_name,
        resource_group_name):
    '''
    Shows the audit policy of a server (diagnostic category 'SQLSecurityAuditEvents')
    '''

    audit_category = 'SQLSecurityAuditEvents'

    return _audit_policy_show(
        cmd=cmd,
        client=client,
        server_name=server_name,
        resource_group_name=resource_group_name,
        category_name=audit_category)
1931
1932
def db_audit_policy_show(
        cmd,
        client,
        server_name,
        resource_group_name,
        database_name):
    '''
    Shows the audit policy of a database (diagnostic category 'SQLSecurityAuditEvents')
    '''

    audit_category = 'SQLSecurityAuditEvents'

    return _audit_policy_show(
        cmd=cmd,
        client=client,
        server_name=server_name,
        resource_group_name=resource_group_name,
        database_name=database_name,
        category_name=audit_category)
1950
1951
def server_ms_support_audit_policy_show(
        cmd,
        client,
        server_name,
        resource_group_name):
    '''
    Shows the Microsoft support operations audit policy of a server
    (diagnostic category 'DevOpsOperationsAudit')
    '''

    audit_category = 'DevOpsOperationsAudit'

    return _audit_policy_show(
        cmd=cmd,
        client=client,
        server_name=server_name,
        resource_group_name=resource_group_name,
        category_name=audit_category)
1967
1968
def _audit_policy_validate_arguments(
        state=None,
        blob_storage_target_state=None,
        storage_account=None,
        storage_endpoint=None,
        storage_account_access_key=None,
        retention_days=None,
        log_analytics_target_state=None,
        log_analytics_workspace_resource_id=None,
        event_hub_target_state=None,
        event_hub_authorization_rule_id=None,
        event_hub_name=None):
    '''
    Validate input arguments

    Raises CLIError when the combination of arguments is inconsistent:
      - nothing at all was provided;
      - state is enabled but no target state was provided;
      - state is disabled but any blob storage, log analytics or event hub argument was provided;
      - storage / log analytics / event hub details were provided without the matching target
        state being enabled, or an enabled target state lacks its required details;
      - retention_days is not a whole number greater than zero and lower than the upper limit.

    Returns None when all checks pass.
    '''

    blob_storage_arguments_provided = blob_storage_target_state is not None or\
        storage_account is not None or storage_endpoint is not None or\
        storage_account_access_key is not None or\
        retention_days is not None

    log_analytics_arguments_provided = log_analytics_target_state is not None or\
        log_analytics_workspace_resource_id is not None

    event_hub_arguments_provided = event_hub_target_state is not None or\
        event_hub_authorization_rule_id is not None or\
        event_hub_name is not None

    if not state and not blob_storage_arguments_provided and\
            not log_analytics_arguments_provided and not event_hub_arguments_provided:
        raise CLIError('Either state or blob storage or log analytics or event hub arguments are missing')

    if _is_audit_policy_state_enabled(state) and\
            blob_storage_target_state is None and log_analytics_target_state is None and event_hub_target_state is None:
        raise CLIError('One of the following arguments must be enabled:'
                       ' blob-storage-target-state, log-analytics-target-state, event-hub-target-state')

    # All three argument groups are checked here. Looking at event_hub_name alone would let
    # event-hub-target-state / event-hub-authorization-rule-id through together with a disabled state.
    if _is_audit_policy_state_disabled(state) and\
            (blob_storage_arguments_provided or
             log_analytics_arguments_provided or
             event_hub_arguments_provided):
        raise CLIError('No additional arguments should be provided once state is disabled')

    if (_is_audit_policy_state_none_or_disabled(blob_storage_target_state)) and\
            (storage_account is not None or storage_endpoint is not None or
             storage_account_access_key is not None):
        raise CLIError('Blob storage account arguments cannot be specified'
                       ' if blob-storage-target-state is not provided or disabled')

    if _is_audit_policy_state_enabled(blob_storage_target_state):
        if storage_account is not None and storage_endpoint is not None:
            raise CLIError('storage-account and storage-endpoint cannot be provided at the same time')

        if storage_account is None and storage_endpoint is None:
            raise CLIError('Either storage-account or storage-endpoint must be provided')

    # Server upper limit
    max_retention_days = 3285

    if retention_days is not None:
        # Going through str() lets the check accept an int as well as the usual string
        retention_text = str(retention_days)

        if not retention_text.isdigit() or int(retention_text) <= 0 or\
                int(retention_text) >= max_retention_days:
            raise CLIError('retention-days must be a positive number greater than zero and lower than {}'
                           .format(max_retention_days))

    if _is_audit_policy_state_none_or_disabled(log_analytics_target_state) and\
            log_analytics_workspace_resource_id is not None:
        raise CLIError('Log analytics workspace resource id cannot be specified'
                       ' if log-analytics-target-state is not provided or disabled')

    if _is_audit_policy_state_enabled(log_analytics_target_state) and\
            log_analytics_workspace_resource_id is None:
        raise CLIError('Log analytics workspace resource id must be specified'
                       ' if log-analytics-target-state is enabled')

    if _is_audit_policy_state_none_or_disabled(event_hub_target_state) and\
            (event_hub_authorization_rule_id is not None or event_hub_name is not None):
        raise CLIError('Event hub arguments cannot be specified if event-hub-target-state is not provided or disabled')

    if _is_audit_policy_state_enabled(event_hub_target_state) and event_hub_authorization_rule_id is None:
        raise CLIError('event-hub-authorization-rule-id must be specified if event-hub-target-state is enabled')
2049
2050
def _audit_policy_create_diagnostic_setting(
        cmd,
        resource_group_name,
        server_name,
        database_name=None,
        category_name=None,
        log_analytics_target_state=None,
        log_analytics_workspace_resource_id=None,
        event_hub_target_state=None,
        event_hub_authorization_rule_id=None,
        event_hub_name=None):
    '''
    Creates an audit diagnostic setting, i.e. one that holds a single enabled category
    (SQLSecurityAuditEvents or DevOpsOperationsAudit, as passed in category_name).

    The setting's name is the category name followed by a suffix: a random uuid normally, or a
    fixed suffix when called from one of the known recorded test methods.
    '''

    import inspect

    # Recorded tests need a constant name, i.e. one matching the name written in the recorded
    # yaml file, so detect whether one of them is somewhere up the call stack.
    recorded_tests = ("test_sql_db_security_mgmt", "test_sql_server_security_mgmt", "test_sql_server_ms_support_mgmt")
    called_from_recorded_test = any(frame.function in recorded_tests for frame in inspect.stack())

    # Work out the suffix of the diagnostic settings name to be created
    if called_from_recorded_test:
        name_suffix = ''
        if log_analytics_target_state is not None:
            name_suffix += '_LogAnalytics'
        if event_hub_target_state is not None:
            name_suffix += '_EventHub'
    else:
        import uuid
        name_suffix = '_' + str(uuid.uuid4())

    setting_name = category_name + name_suffix

    resource_uri = _get_diagnostic_settings_url(
        cmd=cmd,
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name)

    monitor_client = cf_monitor(cmd.cli_ctx)

    model_lookup = {'resource_type': ResourceType.MGMT_MONITOR, 'operation_group': 'diagnostic_settings'}
    LogSettings = cmd.get_models('LogSettings', **model_lookup)
    RetentionPolicy = cmd.get_models('RetentionPolicy', **model_lookup)

    # Single enabled log entry for the audit category, with retention switched off
    audit_log = LogSettings(
        category=category_name,
        enabled=True,
        retention_policy=RetentionPolicy(enabled=False, days=0))

    return create_diagnostics_settings(
        client=monitor_client.diagnostic_settings,
        name=setting_name,
        resource_uri=resource_uri,
        logs=[audit_log],
        metrics=None,
        event_hub=event_hub_name,
        event_hub_rule=event_hub_authorization_rule_id,
        storage_account=None,
        workspace=log_analytics_workspace_resource_id)
2110
2111
def _audit_policy_update_diagnostic_settings(
        cmd,
        server_name,
        resource_group_name,
        database_name=None,
        diagnostic_settings=None,
        category_name=None,
        log_analytics_target_state=None,
        log_analytics_workspace_resource_id=None,
        event_hub_target_state=None,
        event_hub_authorization_rule_id=None,
        event_hub_name=None):
    '''
    Update audit policy's diagnostic settings

    "Audit diagnostic settings" are the entries of diagnostic_settings.value that have an enabled
    log of category_name. Depending on how many of them exist:
      - more than one: CLIError is raised;
      - none: a new one is created if the log analytics or event hub target state is enabled,
        otherwise nothing happens;
      - exactly one without other enabled categories: it is updated in place, or deleted when
        neither a workspace nor an event hub authorization rule remains;
      - exactly one that also has other enabled categories: the audit category is disabled in
        it and, if an azure monitor target remains, a separate audit diagnostic setting is
        created.

    Returns None when nothing was changed, otherwise a list of (action, diagnostic setting)
    tuples with action being "delete", "update" or "create". The code calls this "rollback
    data"; presumably a caller uses it to undo the changes - the consumer is not in this block.

    Note that diagnostic_settings defaults to None but its 'value' attribute is read
    unconditionally, so callers have to pass it.
    '''

    # Fetch all audit diagnostic settings
    audit_diagnostic_settings = _fetch_all_audit_diagnostic_settings(diagnostic_settings.value, category_name)
    num_of_audit_diagnostic_settings = len(audit_diagnostic_settings)

    # If more than 1 audit diagnostic settings found then throw error
    if num_of_audit_diagnostic_settings > 1:
        raise CLIError('Multiple audit diagnostics settings are already enabled')

    diagnostic_settings_url = _get_diagnostic_settings_url(
        cmd=cmd,
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name)

    azure_monitor_client = cf_monitor(cmd.cli_ctx)

    # If no audit diagnostic settings found then create one if azure monitor is enabled
    if num_of_audit_diagnostic_settings == 0:
        if _is_audit_policy_state_enabled(log_analytics_target_state) or\
                _is_audit_policy_state_enabled(event_hub_target_state):
            created_diagnostic_setting = _audit_policy_create_diagnostic_setting(
                cmd=cmd,
                resource_group_name=resource_group_name,
                server_name=server_name,
                database_name=database_name,
                category_name=category_name,
                log_analytics_target_state=log_analytics_target_state,
                log_analytics_workspace_resource_id=log_analytics_workspace_resource_id,
                event_hub_target_state=event_hub_target_state,
                event_hub_authorization_rule_id=event_hub_authorization_rule_id,
                event_hub_name=event_hub_name)

            # Return rollback data tuple: the newly created setting would have to be deleted
            return [("delete", created_diagnostic_setting)]

        # azure monitor is disabled - there is nothing to do
        return None

    # This leaves us with case when num_of_audit_diagnostic_settings is 1
    audit_diagnostic_setting = audit_diagnostic_settings[0]

    # Initialize actually updated azure monitor fields:
    #   target state not provided -> keep the value of the existing diagnostic setting
    #   target state disabled     -> clear the value
    #   target state enabled      -> keep the value passed by the caller
    if log_analytics_target_state is None:
        log_analytics_workspace_resource_id = audit_diagnostic_setting.workspace_id
    elif _is_audit_policy_state_disabled(log_analytics_target_state):
        log_analytics_workspace_resource_id = None

    if event_hub_target_state is None:
        event_hub_authorization_rule_id = audit_diagnostic_setting.event_hub_authorization_rule_id
        event_hub_name = audit_diagnostic_setting.event_hub_name
    elif _is_audit_policy_state_disabled(event_hub_target_state):
        event_hub_authorization_rule_id = None
        event_hub_name = None

    # Azure monitor stays in use as long as a workspace or an event hub rule remains
    is_azure_monitor_target_enabled = log_analytics_workspace_resource_id is not None or\
        event_hub_authorization_rule_id is not None

    # True when the existing setting also has an enabled log of some other category
    has_other_categories = next((log for log in audit_diagnostic_setting.logs
                                 if log.enabled and log.category != category_name), None) is not None

    # If there is no other categories except SQLSecurityAuditEvents/DevOpsOperationsAudit update or delete
    # the existing single diagnostic settings
    if not has_other_categories:
        # If azure monitor is enabled then update existing single audit diagnostic setting
        if is_azure_monitor_target_enabled:
            create_diagnostics_settings(
                client=azure_monitor_client.diagnostic_settings,
                name=audit_diagnostic_setting.name,
                resource_uri=diagnostic_settings_url,
                logs=audit_diagnostic_setting.logs,
                metrics=audit_diagnostic_setting.metrics,
                event_hub=event_hub_name,
                event_hub_rule=event_hub_authorization_rule_id,
                storage_account=audit_diagnostic_setting.storage_account_id,
                workspace=log_analytics_workspace_resource_id)

            # Return rollback data tuple: carries the setting as it was before the update
            return [("update", audit_diagnostic_setting)]

        # Azure monitor is disabled, delete existing single audit diagnostic setting
        azure_monitor_client.diagnostic_settings.delete(diagnostic_settings_url, audit_diagnostic_setting.name)

        # Return rollback data tuple: the deleted setting would have to be created again
        return [("create", audit_diagnostic_setting)]

    # In case there are other categories in the existing single audit diagnostic setting a "split" must be performed:
    #   1. Disable SQLSecurityAuditEvents/DevOpsOperationsAudit category in found audit diagnostic setting
    #   2. Create new diagnostic setting with SQLSecurityAuditEvents/DevOpsOperationsAudit category,
    #      i.e. audit diagnostic setting

    # Build updated logs list with disabled SQLSecurityAuditEvents/DevOpsOperationsAudit category
    updated_logs = []

    LogSettings = cmd.get_models(
        'LogSettings',
        resource_type=ResourceType.MGMT_MONITOR,
        operation_group='diagnostic_settings')

    RetentionPolicy = cmd.get_models(
        'RetentionPolicy',
        resource_type=ResourceType.MGMT_MONITOR,
        operation_group='diagnostic_settings')

    # The audit category entry is replaced by a disabled one; all other log entries are kept as they are
    for log in audit_diagnostic_setting.logs:
        if log.category == category_name:
            updated_logs.append(LogSettings(category=log.category, enabled=False,
                                            retention_policy=RetentionPolicy(enabled=False, days=0)))
        else:
            updated_logs.append(log)

    # Update existing diagnostic settings; its own targets (event hub, storage, workspace) stay untouched
    create_diagnostics_settings(
        client=azure_monitor_client.diagnostic_settings,
        name=audit_diagnostic_setting.name,
        resource_uri=diagnostic_settings_url,
        logs=updated_logs,
        metrics=audit_diagnostic_setting.metrics,
        event_hub=audit_diagnostic_setting.event_hub_name,
        event_hub_rule=audit_diagnostic_setting.event_hub_authorization_rule_id,
        storage_account=audit_diagnostic_setting.storage_account_id,
        workspace=audit_diagnostic_setting.workspace_id)

    # Add original 'audit_diagnostic_settings' to rollback_data list
    rollback_data = [("update", audit_diagnostic_setting)]

    # Create new diagnostic settings with enabled SQLSecurityAuditEvents/DevOpsOperationsAudit category
    # only if azure monitor is enabled
    if is_azure_monitor_target_enabled:
        created_diagnostic_setting = _audit_policy_create_diagnostic_setting(
            cmd=cmd,
            resource_group_name=resource_group_name,
            server_name=server_name,
            database_name=database_name,
            category_name=category_name,
            log_analytics_target_state=log_analytics_target_state,
            log_analytics_workspace_resource_id=log_analytics_workspace_resource_id,
            event_hub_target_state=event_hub_target_state,
            event_hub_authorization_rule_id=event_hub_authorization_rule_id,
            event_hub_name=event_hub_name)

        # Add 'created_diagnostic_settings' to rollback_data list in reverse order,
        # i.e. in front of the "update" entry
        rollback_data.insert(0, ("delete", created_diagnostic_setting))

    return rollback_data
2272
2273
2274def _audit_policy_update_apply_blob_storage_details(
2275        cmd,
2276        instance,
2277        blob_storage_target_state,
2278        storage_account,
2279        storage_endpoint,
2280        storage_account_access_key,
2281        retention_days):
2282    '''
2283    Apply blob storage details on policy update
2284    '''
2285    if hasattr(instance, 'is_storage_secondary_key_in_use'):
2286        is_storage_secondary_key_in_use = instance.is_storage_secondary_key_in_use
2287    else:
2288        is_storage_secondary_key_in_use = False
2289
2290    if blob_storage_target_state is None:
2291        # Original audit policy has no storage_endpoint
2292        if not instance.storage_endpoint:
2293            instance.storage_endpoint = None
2294            instance.storage_account_access_key = None
2295        else:
2296            # Resolve storage_account_access_key based on original storage_endpoint
2297            storage_account = _get_storage_account_name(instance.storage_endpoint)
2298            storage_resource_group = _find_storage_account_resource_group(cmd.cli_ctx, storage_account)
2299
2300            instance.storage_account_access_key = _get_storage_key(
2301                cli_ctx=cmd.cli_ctx,
2302                storage_account=storage_account,
2303                resource_group_name=storage_resource_group,
2304                use_secondary_key=is_storage_secondary_key_in_use)
2305    elif _is_audit_policy_state_enabled(blob_storage_target_state):
2306        # Resolve storage_endpoint using provided storage_account
2307        if storage_account is not None:
2308            storage_resource_group = _find_storage_account_resource_group(cmd.cli_ctx, storage_account)
2309            storage_endpoint = _get_storage_endpoint(cmd.cli_ctx, storage_account, storage_resource_group)
2310
2311        if storage_endpoint is not None:
2312            instance.storage_endpoint = storage_endpoint
2313
2314        if storage_account_access_key is not None:
2315            instance.storage_account_access_key = storage_account_access_key
2316        elif storage_endpoint is not None:
2317            # Resolve storage_account if not provided
2318            if storage_account is None:
2319                storage_account = _get_storage_account_name(storage_endpoint)
2320                storage_resource_group = _find_storage_account_resource_group(cmd.cli_ctx, storage_account)
2321
2322            # Resolve storage_account_access_key based on storage_account
2323            instance.storage_account_access_key = _get_storage_key(
2324                cli_ctx=cmd.cli_ctx,
2325                storage_account=storage_account,
2326                resource_group_name=storage_resource_group,
2327                use_secondary_key=is_storage_secondary_key_in_use)
2328
2329        # Apply retenation days
2330        if hasattr(instance, 'retention_days') and retention_days is not None:
2331            instance.retention_days = retention_days
2332    else:
2333        instance.storage_endpoint = None
2334        instance.storage_account_access_key = None
2335
2336
2337def _audit_policy_update_apply_azure_monitor_target_enabled(
2338        instance,
2339        diagnostic_settings,
2340        category_name,
2341        log_analytics_target_state,
2342        event_hub_target_state):
2343    '''
2344    Apply value of is_azure_monitor_target_enabled on policy update
2345    '''
2346
2347    # If log_analytics_target_state and event_hub_target_state are None there is nothing to do
2348    if log_analytics_target_state is None and event_hub_target_state is None:
2349        return
2350
2351    if _is_audit_policy_state_enabled(log_analytics_target_state) or\
2352            _is_audit_policy_state_enabled(event_hub_target_state):
2353        instance.is_azure_monitor_target_enabled = True
2354    else:
2355        # Sort received diagnostic settings by name and get first element to ensure consistency
2356        # between command executions
2357        diagnostic_settings.value.sort(key=lambda d: d.name)
2358        audit_diagnostic_setting = _fetch_first_audit_diagnostic_setting(diagnostic_settings.value, category_name)
2359
2360        # Determine value of is_azure_monitor_target_enabled
2361        if audit_diagnostic_setting is None:
2362            updated_log_analytics_workspace_id = None
2363            updated_event_hub_authorization_rule_id = None
2364        else:
2365            updated_log_analytics_workspace_id = audit_diagnostic_setting.workspace_id
2366            updated_event_hub_authorization_rule_id = audit_diagnostic_setting.event_hub_authorization_rule_id
2367
2368        if _is_audit_policy_state_disabled(log_analytics_target_state):
2369            updated_log_analytics_workspace_id = None
2370
2371        if _is_audit_policy_state_disabled(event_hub_target_state):
2372            updated_event_hub_authorization_rule_id = None
2373
2374        instance.is_azure_monitor_target_enabled = updated_log_analytics_workspace_id is not None or\
2375            updated_event_hub_authorization_rule_id is not None
2376
2377
def _audit_policy_update_global_settings(
        cmd,
        instance,
        diagnostic_settings=None,
        category_name=None,
        state=None,
        blob_storage_target_state=None,
        storage_account=None,
        storage_endpoint=None,
        storage_account_access_key=None,
        audit_actions_and_groups=None,
        retention_days=None,
        log_analytics_target_state=None,
        event_hub_target_state=None):
    '''
    Apply the global (non diagnostic-settings) part of an audit policy update to instance.

    The state is applied first when provided. All remaining arguments are only
    applied when the resulting policy state is enabled: blob storage details,
    audit actions and groups (falling back to a default set when the instance
    ends up with none), and the is_azure_monitor_target_enabled flag.
    '''

    # Apply state
    if state is not None:
        instance.state = BlobAuditingPolicyState[state.lower()]

    # Every other argument is ignored unless the policy is enabled
    if not _is_audit_policy_state_enabled(instance.state):
        return

    # Apply blob_storage_target_state and all storage account details
    _audit_policy_update_apply_blob_storage_details(
        cmd=cmd,
        instance=instance,
        blob_storage_target_state=blob_storage_target_state,
        storage_account=storage_account,
        storage_endpoint=storage_endpoint,
        storage_account_access_key=storage_account_access_key,
        retention_days=retention_days)

    # Apply audit_actions_and_groups, only for policy types that carry this attribute
    if hasattr(instance, 'audit_actions_and_groups'):
        if audit_actions_and_groups is not None:
            instance.audit_actions_and_groups = audit_actions_and_groups

        effective_actions = instance.audit_actions_and_groups
        if not effective_actions or effective_actions == []:
            # Nothing configured: fall back to the default action groups
            instance.audit_actions_and_groups = [
                "SUCCESSFUL_DATABASE_AUTHENTICATION_GROUP",
                "FAILED_DATABASE_AUTHENTICATION_GROUP",
                "BATCH_COMPLETED_GROUP"]

    # Apply is_azure_monitor_target_enabled
    _audit_policy_update_apply_azure_monitor_target_enabled(
        instance=instance,
        diagnostic_settings=diagnostic_settings,
        category_name=category_name,
        log_analytics_target_state=log_analytics_target_state,
        event_hub_target_state=event_hub_target_state)
2430
2431
def _audit_policy_update_rollback(
        cmd,
        server_name,
        resource_group_name,
        database_name,
        rollback_data):
    '''
    Undo previously applied diagnostic settings changes.

    Each entry of rollback_data is indexed as (action, diagnostic setting).
    For the "create" and "update" actions the diagnostic setting is written
    back through create_diagnostics_settings; for any other action it is
    deleted by name.
    '''

    settings_url = _get_diagnostic_settings_url(
        cmd=cmd,
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name)

    monitor_client = cf_monitor(cmd.cli_ctx)

    for entry in rollback_data:
        setting = entry[1]
        action = entry[0]

        if action in ("create", "update"):
            # Re-submit the recorded diagnostic setting as it was captured
            create_diagnostics_settings(
                client=monitor_client.diagnostic_settings,
                name=setting.name,
                resource_uri=settings_url,
                logs=setting.logs,
                metrics=setting.metrics,
                event_hub=setting.event_hub_name,
                event_hub_rule=setting.event_hub_authorization_rule_id,
                storage_account=setting.storage_account_id,
                workspace=setting.workspace_id)
            continue

        # Any other action is treated as a delete
        monitor_client.diagnostic_settings.delete(settings_url, setting.name)
2466
2467
def _audit_policy_update(
        cmd,
        instance,
        server_name,
        resource_group_name,
        database_name=None,
        state=None,
        blob_storage_target_state=None,
        storage_account=None,
        storage_endpoint=None,
        storage_account_access_key=None,
        audit_actions_and_groups=None,
        retention_days=None,
        category_name=None,
        log_analytics_target_state=None,
        log_analytics_workspace_resource_id=None,
        event_hub_target_state=None,
        event_hub_authorization_rule_id=None,
        event_hub_name=None):
    '''
    Shared implementation behind the audit policy update commands.

    Validates the arguments, updates the diagnostic settings when a log
    analytics or event hub target state was given, then applies the global
    settings to instance and returns it. If applying the global settings
    raises, the diagnostic settings changes are rolled back (when there are
    any) and the exception is re-raised.
    '''

    # Arguments validation
    _audit_policy_validate_arguments(
        state=state,
        blob_storage_target_state=blob_storage_target_state,
        storage_account=storage_account,
        storage_endpoint=storage_endpoint,
        storage_account_access_key=storage_account_access_key,
        retention_days=retention_days,
        log_analytics_target_state=log_analytics_target_state,
        log_analytics_workspace_resource_id=log_analytics_workspace_resource_id,
        event_hub_target_state=event_hub_target_state,
        event_hub_authorization_rule_id=event_hub_authorization_rule_id,
        event_hub_name=event_hub_name)

    diagnostic_settings = None
    rollback_data = None

    # Diagnostic settings are only fetched and updated when at least one of
    # log_analytics_target_state / event_hub_target_state is provided
    touches_diagnostic_settings = (
        log_analytics_target_state is not None or event_hub_target_state is not None)

    if touches_diagnostic_settings:
        diagnostic_settings = _get_diagnostic_settings(
            cmd=cmd,
            resource_group_name=resource_group_name,
            server_name=server_name,
            database_name=database_name)

        # Update diagnostic settings, keeping what is needed to undo the change
        rollback_data = _audit_policy_update_diagnostic_settings(
            cmd=cmd,
            server_name=server_name,
            resource_group_name=resource_group_name,
            database_name=database_name,
            diagnostic_settings=diagnostic_settings,
            category_name=category_name,
            log_analytics_target_state=log_analytics_target_state,
            log_analytics_workspace_resource_id=log_analytics_workspace_resource_id,
            event_hub_target_state=event_hub_target_state,
            event_hub_authorization_rule_id=event_hub_authorization_rule_id,
            event_hub_name=event_hub_name)

    # Update auditing global settings
    try:
        _audit_policy_update_global_settings(
            cmd=cmd,
            instance=instance,
            diagnostic_settings=diagnostic_settings,
            category_name=category_name,
            state=state,
            blob_storage_target_state=blob_storage_target_state,
            storage_account=storage_account,
            storage_endpoint=storage_endpoint,
            storage_account_access_key=storage_account_access_key,
            audit_actions_and_groups=audit_actions_and_groups,
            retention_days=retention_days,
            log_analytics_target_state=log_analytics_target_state,
            event_hub_target_state=event_hub_target_state)
    except Exception as err:
        logger.debug(err)

        # Undo the diagnostic settings change made above, if there was one
        if rollback_data is not None:
            _audit_policy_update_rollback(
                cmd=cmd,
                server_name=server_name,
                resource_group_name=resource_group_name,
                database_name=database_name,
                rollback_data=rollback_data)

        # Reraise the original exception
        raise err

    return instance
2558
2559
def server_audit_policy_update(
        cmd,
        instance,
        server_name,
        resource_group_name,
        state=None,
        blob_storage_target_state=None,
        storage_account=None,
        storage_endpoint=None,
        storage_account_access_key=None,
        audit_actions_and_groups=None,
        retention_days=None,
        log_analytics_target_state=None,
        log_analytics_workspace_resource_id=None,
        event_hub_target_state=None,
        event_hub_authorization_rule_id=None,
        event_hub=None):
    '''
    Update server audit policy.

    Delegates to _audit_policy_update with no database name and the
    'SQLSecurityAuditEvents' diagnostic category; event_hub is forwarded
    as event_hub_name.
    '''

    return _audit_policy_update(
        # Target resource (server level, hence no database)
        cmd=cmd,
        instance=instance,
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=None,
        category_name='SQLSecurityAuditEvents',
        # Policy-wide settings
        state=state,
        audit_actions_and_groups=audit_actions_and_groups,
        retention_days=retention_days,
        # Blob storage target
        blob_storage_target_state=blob_storage_target_state,
        storage_account=storage_account,
        storage_endpoint=storage_endpoint,
        storage_account_access_key=storage_account_access_key,
        # Log analytics target
        log_analytics_target_state=log_analytics_target_state,
        log_analytics_workspace_resource_id=log_analytics_workspace_resource_id,
        # Event hub target
        event_hub_target_state=event_hub_target_state,
        event_hub_authorization_rule_id=event_hub_authorization_rule_id,
        event_hub_name=event_hub)
2600
2601
def db_audit_policy_update(
        cmd,
        instance,
        server_name,
        resource_group_name,
        database_name,
        state=None,
        blob_storage_target_state=None,
        storage_account=None,
        storage_endpoint=None,
        storage_account_access_key=None,
        audit_actions_and_groups=None,
        retention_days=None,
        log_analytics_target_state=None,
        log_analytics_workspace_resource_id=None,
        event_hub_target_state=None,
        event_hub_authorization_rule_id=None,
        event_hub=None):
    '''
    Update database audit policy.

    Delegates to _audit_policy_update for the given database with the
    'SQLSecurityAuditEvents' diagnostic category; event_hub is forwarded
    as event_hub_name.
    '''

    return _audit_policy_update(
        # Target resource (database level)
        cmd=cmd,
        instance=instance,
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name,
        category_name='SQLSecurityAuditEvents',
        # Policy-wide settings
        state=state,
        audit_actions_and_groups=audit_actions_and_groups,
        retention_days=retention_days,
        # Blob storage target
        blob_storage_target_state=blob_storage_target_state,
        storage_account=storage_account,
        storage_endpoint=storage_endpoint,
        storage_account_access_key=storage_account_access_key,
        # Log analytics target
        log_analytics_target_state=log_analytics_target_state,
        log_analytics_workspace_resource_id=log_analytics_workspace_resource_id,
        # Event hub target
        event_hub_target_state=event_hub_target_state,
        event_hub_authorization_rule_id=event_hub_authorization_rule_id,
        event_hub_name=event_hub)
2643
2644
def server_ms_support_audit_policy_update(
        cmd,
        instance,
        server_name,
        resource_group_name,
        state=None,
        blob_storage_target_state=None,
        storage_account=None,
        storage_endpoint=None,
        storage_account_access_key=None,
        log_analytics_target_state=None,
        log_analytics_workspace_resource_id=None,
        event_hub_target_state=None,
        event_hub_authorization_rule_id=None,
        event_hub=None):
    '''
    Update server Microsoft support operations audit policy.

    Delegates to _audit_policy_update with no database name and the
    'DevOpsOperationsAudit' diagnostic category. Audit actions/groups and
    retention days are always passed as None; event_hub is forwarded as
    event_hub_name.
    '''

    return _audit_policy_update(
        # Target resource (server level, hence no database)
        cmd=cmd,
        instance=instance,
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=None,
        category_name='DevOpsOperationsAudit',
        # Policy-wide settings (actions/groups and retention are not exposed here)
        state=state,
        audit_actions_and_groups=None,
        retention_days=None,
        # Blob storage target
        blob_storage_target_state=blob_storage_target_state,
        storage_account=storage_account,
        storage_endpoint=storage_endpoint,
        storage_account_access_key=storage_account_access_key,
        # Log analytics target
        log_analytics_target_state=log_analytics_target_state,
        log_analytics_workspace_resource_id=log_analytics_workspace_resource_id,
        # Event hub target
        event_hub_target_state=event_hub_target_state,
        event_hub_authorization_rule_id=event_hub_authorization_rule_id,
        event_hub_name=event_hub)
2683
2684
def update_long_term_retention(
        client,
        database_name,
        server_name,
        resource_group_name,
        weekly_retention=None,
        monthly_retention=None,
        yearly_retention=None,
        week_of_year=None,
        **kwargs):
    '''
    Updates the long term retention policy of a database.

    At least one of weekly, monthly or yearly retention must be given, and a
    yearly retention additionally requires week_of_year; otherwise a CLIError
    is raised. The four values are added to kwargs, which is sent as the
    policy parameters. Returns the result of client.begin_create_or_update.
    '''
    if not any((weekly_retention, monthly_retention, yearly_retention)):
        raise CLIError("Please specify retention setting(s).  See '--help' for more details.")

    if yearly_retention and not week_of_year:
        raise CLIError('Please specify week of year for yearly retention.')

    kwargs.update(
        weekly_retention=weekly_retention,
        monthly_retention=monthly_retention,
        yearly_retention=yearly_retention,
        week_of_year=week_of_year)

    return client.begin_create_or_update(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        policy_name=LongTermRetentionPolicyName.DEFAULT,
        parameters=kwargs)
2720
2721
def get_long_term_retention(
        client,
        resource_group_name,
        database_name,
        server_name):
    '''
    Gets the long term retention policy of a database.

    Always requests the policy named LongTermRetentionPolicyName.DEFAULT.
    '''

    policy_name = LongTermRetentionPolicyName.DEFAULT

    return client.get(
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name,
        policy_name=policy_name)
2736
2737
def update_short_term_retention(
        client,
        database_name,
        server_name,
        resource_group_name,
        retention_days,
        diffbackup_hours,
        no_wait=False,
        **kwargs):
    '''
    Updates short term retention for live database.

    retention_days and diffbackup_hours (sent as diff_backup_interval_in_hours)
    are added to kwargs, which becomes the policy parameters. The call to
    client.begin_create_or_update goes through sdk_no_wait, honoring no_wait.
    '''

    kwargs.update(
        retention_days=retention_days,
        diff_backup_interval_in_hours=diffbackup_hours)

    return sdk_no_wait(
        no_wait,
        client.begin_create_or_update,
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name,
        policy_name=ShortTermRetentionPolicyName.DEFAULT,
        parameters=kwargs)
2762
2763
def get_short_term_retention(
        client,
        database_name,
        server_name,
        resource_group_name):
    '''
    Gets short term retention for live database.

    Always requests the policy named ShortTermRetentionPolicyName.DEFAULT.
    '''

    policy_name = ShortTermRetentionPolicyName.DEFAULT

    return client.get(
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name,
        policy_name=policy_name)
2778
2779
2780def _list_by_database_long_term_retention_backups(
2781        client,
2782        location_name,
2783        long_term_retention_server_name,
2784        long_term_retention_database_name,
2785        resource_group_name=None,
2786        only_latest_per_database=None,
2787        database_state=None):
2788    '''
2789    Gets the long term retention backups for a Managed Database
2790    '''
2791
2792    if resource_group_name:
2793        backups = client.list_by_resource_group_database(
2794            resource_group_name=resource_group_name,
2795            location_name=location_name,
2796            long_term_retention_server_name=long_term_retention_server_name,
2797            long_term_retention_database_name=long_term_retention_database_name,
2798            only_latest_per_database=only_latest_per_database,
2799            database_state=database_state)
2800    else:
2801        backups = client.list_by_database(
2802            location_name=location_name,
2803            long_term_retention_server_name=long_term_retention_server_name,
2804            long_term_retention_database_name=long_term_retention_database_name,
2805            only_latest_per_database=only_latest_per_database,
2806            database_state=database_state)
2807
2808    return backups
2809
2810
2811def _list_by_server_long_term_retention_backups(
2812        client,
2813        location_name,
2814        long_term_retention_server_name,
2815        resource_group_name=None,
2816        only_latest_per_database=None,
2817        database_state=None):
2818    '''
2819    Gets the long term retention backups within a Managed Instance
2820    '''
2821
2822    if resource_group_name:
2823        backups = client.list_by_resource_group_server(
2824            resource_group_name=resource_group_name,
2825            location_name=location_name,
2826            long_term_retention_server_name=long_term_retention_server_name,
2827            only_latest_per_database=only_latest_per_database,
2828            database_state=database_state)
2829    else:
2830        backups = client.list_by_server(
2831            location_name=location_name,
2832            long_term_retention_server_name=long_term_retention_server_name,
2833            only_latest_per_database=only_latest_per_database,
2834            database_state=database_state)
2835
2836    return backups
2837
2838
2839def _list_by_location_long_term_retention_backups(
2840        client,
2841        location_name,
2842        resource_group_name=None,
2843        only_latest_per_database=None,
2844        database_state=None):
2845    '''
2846    Gets the long term retention backups within a specified region.
2847    '''
2848
2849    if resource_group_name:
2850        backups = client.list_by_resource_group_location(
2851            resource_group_name=resource_group_name,
2852            location_name=location_name,
2853            only_latest_per_database=only_latest_per_database,
2854            database_state=database_state)
2855    else:
2856        backups = client.list_by_location(
2857            location_name=location_name,
2858            only_latest_per_database=only_latest_per_database,
2859            database_state=database_state)
2860
2861    return backups
2862
2863
def list_long_term_retention_backups(
        client,
        location_name,
        long_term_retention_server_name=None,
        long_term_retention_database_name=None,
        resource_group_name=None,
        only_latest_per_database=None,
        database_state=None):
    '''
    Lists the long term retention backups for a specified location, server, or database.

    The scope is chosen from the most specific name provided: no server name
    lists by location (a database name alone is ignored), a server name alone
    lists by server, and a server name plus a database name lists by database.
    '''

    # No server name: the widest scope, everything in the location
    if not long_term_retention_server_name:
        return _list_by_location_long_term_retention_backups(
            client,
            location_name,
            resource_group_name,
            only_latest_per_database,
            database_state)

    # Server name only: everything under that server
    if not long_term_retention_database_name:
        return _list_by_server_long_term_retention_backups(
            client,
            location_name,
            long_term_retention_server_name,
            resource_group_name,
            only_latest_per_database,
            database_state)

    # Server and database name: a single database
    return _list_by_database_long_term_retention_backups(
        client,
        location_name,
        long_term_retention_server_name,
        long_term_retention_database_name,
        resource_group_name,
        only_latest_per_database,
        database_state)
2904
2905
def restore_long_term_retention_backup(
        cmd,
        client,
        long_term_retention_backup_resource_id,
        target_database_name,
        target_server_name,
        target_resource_group_name,
        requested_backup_storage_redundancy,
        **kwargs):
    '''
    Restores a long term retention backup into a target database
    (i.e. create with 'RestoreLongTermRetentionBackup' create mode.)

    Raises CLIError when any of the target names or the backup resource id is
    missing. The location is taken from the target server. Returns the result
    of client.begin_create_or_update.
    '''

    if not all((target_resource_group_name, target_server_name, target_database_name)):
        raise CLIError(
            'Please specify target resource(s). '
            'Target resource group, target server, and target database are all required '
            'to restore LTR backup.')

    if not long_term_retention_backup_resource_id:
        raise CLIError('Please specify a long term retention backup.')

    # The restored database is created in the target server's location
    location = _get_server_location(
        cmd.cli_ctx,
        server_name=target_server_name,
        resource_group_name=target_resource_group_name)

    kwargs.update(
        location=location,
        create_mode=CreateMode.RESTORE_LONG_TERM_RETENTION_BACKUP,
        long_term_retention_backup_resource_id=long_term_retention_backup_resource_id,
        requested_backup_storage_redundancy=requested_backup_storage_redundancy)

    # Check backup storage redundancy configurations
    if _should_show_backup_storage_redundancy_warnings(location):
        if not requested_backup_storage_redundancy:
            _backup_storage_redundancy_take_source_warning()
        if requested_backup_storage_redundancy == 'Geo':
            _backup_storage_redundancy_specify_geo_warning()

    return client.begin_create_or_update(
        resource_group_name=target_resource_group_name,
        server_name=target_server_name,
        database_name=target_database_name,
        parameters=kwargs)
2948
2949
def db_threat_detection_policy_get(
        client,
        resource_group_name,
        server_name,
        database_name):
    '''
    Gets a threat detection policy.

    Always requests the policy named SecurityAlertPolicyName.DEFAULT.
    '''

    policy_name = SecurityAlertPolicyName.DEFAULT

    return client.get(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        security_alert_policy_name=policy_name)
2964
2965
def db_threat_detection_policy_update(
        cmd,
        instance,
        state=None,
        storage_account=None,
        storage_endpoint=None,
        storage_account_access_key=None,
        retention_days=None,
        email_addresses=None,
        disabled_alerts=None,
        email_account_admins=None):
    '''
    Updates a threat detection policy. Custom update function to apply parameters to instance.

    state, when given, is applied first and decides whether the policy is
    treated as enabled for the storage-related update. Every other argument
    is applied only when it is not None, so an omitted argument leaves the
    corresponding property of instance untouched. Returns instance.
    '''

    # Apply state
    if state:
        instance.state = SecurityAlertPolicyState[state.lower()]
    enabled = instance.state.lower() == SecurityAlertPolicyState.ENABLED.value.lower()  # pylint: disable=no-member

    # Set storage-related properties
    _db_security_policy_update(
        cmd.cli_ctx,
        instance,
        enabled,
        storage_account,
        storage_endpoint,
        storage_account_access_key,
        False)

    # Set other properties.
    # "Not provided" is None; compare against None rather than relying on
    # truthiness so that explicit falsy values (retention_days=0,
    # email_account_admins=False, an empty list to clear addresses or
    # disabled alerts) are applied instead of being silently dropped.
    if retention_days is not None:
        instance.retention_days = retention_days

    if email_addresses is not None:
        instance.email_addresses = email_addresses

    if disabled_alerts is not None:
        instance.disabled_alerts = disabled_alerts

    if email_account_admins is not None:
        instance.email_account_admins = email_account_admins

    return instance
3010
3011
def db_threat_detection_policy_update_setter(
        client,
        resource_group_name,
        server_name,
        database_name,
        parameters):
    '''
    Writes a threat detection policy back through client.create_or_update.

    The policy is always stored under SecurityAlertPolicyName.DEFAULT.
    '''

    policy_name = SecurityAlertPolicyName.DEFAULT

    return client.create_or_update(
        database_name=database_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        security_alert_policy_name=policy_name,
        parameters=parameters)
3025
3026
def db_sensitivity_label_show(
        client,
        database_name,
        server_name,
        schema_name,
        table_name,
        column_name,
        resource_group_name):
    '''
    Gets the current sensitivity label of a column.

    The label is requested with SensitivityLabelSource.CURRENT.
    '''

    # Positional path to the column, in the order client.get is called with
    column_path = (
        resource_group_name,
        server_name,
        database_name,
        schema_name,
        table_name,
        column_name)

    return client.get(*column_path, SensitivityLabelSource.CURRENT)
3044
3045
def db_sensitivity_label_update(
        cmd,
        client,
        database_name,
        server_name,
        schema_name,
        table_name,
        column_name,
        resource_group_name,
        label_name=None,
        information_type=None):
    '''
    Updates a sensitivity label. Custom update function to apply parameters to instance.

    The new label starts from the column's current label when one exists.
    label_name and information_type, when given, are matched case-insensitively
    against the display names in the effective information protection policy to
    find their ids; a CLIError is raised when no match is found. Returns the
    result of client.create_or_update.
    '''

    # Get the information protection policy
    from azure.mgmt.security import SecurityCenter
    from azure.core.exceptions import ResourceNotFoundError

    security_center_client = get_mgmt_service_client(cmd.cli_ctx, SecurityCenter, asc_location="centralus")

    policy_scope = '/providers/Microsoft.Management/managementGroups/{}'.format(_get_tenant_id())
    policy = security_center_client.information_protection_policies.get(
        scope=policy_scope,
        information_protection_policy_name="effective")

    updated_label = SensitivityLabel()

    # Get the current label
    try:
        existing_label = client.get(
            resource_group_name,
            server_name,
            database_name,
            schema_name,
            table_name,
            column_name,
            SensitivityLabelSource.CURRENT)
        # Initialize with existing values
        updated_label.label_name = existing_label.label_name
        updated_label.label_id = existing_label.label_id
        updated_label.information_type = existing_label.information_type
        updated_label.information_type_id = existing_label.information_type_id

    except ResourceNotFoundError as ex:
        # A column without a label yet is fine; anything else is propagated
        if not ex or 'SensitivityLabelsLabelNotFound' not in str(ex):
            raise ex

    # Find the label id in the policy by the label name provided
    if label_name:
        wanted_label = label_name.lower()
        matched_label_id = None
        for candidate_id in policy.labels:
            if policy.labels[candidate_id].display_name.lower() == wanted_label:
                matched_label_id = candidate_id
                break
        if matched_label_id is None:
            raise CLIError('The provided label name was not found in the information protection policy.')
        updated_label.label_id = matched_label_id
        updated_label.label_name = label_name

    # Find the information type id in the policy by the information type provided
    if information_type:
        wanted_type = information_type.lower()
        matched_type_id = None
        for candidate_id in policy.information_types:
            if policy.information_types[candidate_id].display_name.lower() == wanted_type:
                matched_type_id = candidate_id
                break
        if matched_type_id is None:
            raise CLIError('The provided information type was not found in the information protection policy.')
        updated_label.information_type_id = matched_type_id
        updated_label.information_type = information_type

    return client.create_or_update(
        resource_group_name,
        server_name,
        database_name,
        schema_name,
        table_name,
        column_name,
        updated_label)
3117
3118
3119###############################################
3120#                sql dw                       #
3121###############################################
3122
3123
def dw_create(
        cmd,
        client,
        database_name,
        server_name,
        resource_group_name,
        no_wait=False,
        **kwargs):
    '''
    Creates a datawarehouse.

    The tier of kwargs['sku'] is forced to the data warehouse edition before
    the creation is handed over to _db_dw_create.
    '''

    # Set edition
    requested_sku = kwargs['sku']
    requested_sku.tier = DatabaseEdition.data_warehouse.value

    target = DatabaseIdentity(cmd.cli_ctx, database_name, server_name, resource_group_name)

    # Create
    return _db_dw_create(cmd.cli_ctx, client, None, target, no_wait, **kwargs)
3147
3148
def dw_list(
        client,
        server_name,
        resource_group_name):
    '''
    Lists the data warehouses in a server.

    All databases of the server are listed and only those whose sku tier is
    the data warehouse edition are returned, as a list.
    '''

    warehouses = []

    for db in client.list_by_server(
            resource_group_name=resource_group_name,
            server_name=server_name):
        # Include only DW's
        if db.sku.tier == DatabaseEdition.data_warehouse.value:
            warehouses.append(db)

    return warehouses
3163
3164
def dw_update(
        instance,
        max_size_bytes=None,
        service_objective=None):
    '''
    Updates a data warehouse. Custom update function to apply parameters to instance.

    Only truthy arguments are applied: max_size_bytes is copied onto the
    instance, and service_objective replaces the instance sku with a new Sku
    carrying that name. Returns instance.
    '''

    # Apply param values to instance
    if service_objective:
        instance.sku = Sku(name=service_objective)

    if max_size_bytes:
        instance.max_size_bytes = max_size_bytes

    return instance
3181
3182
def dw_pause(
        client,
        database_name,
        server_name,
        resource_group_name):
    '''
    Pauses a datawarehouse.

    Blocks until the pause operation completes and returns nothing.
    '''

    pause_operation = client.begin_pause(
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name)

    # Wait for completion, but DO NOT return the result. Long-running POST
    # operation results are not returned correctly by SDK.
    pause_operation.wait()
3198
3199
def dw_resume(
        client,
        database_name,
        server_name,
        resource_group_name):
    '''
    Resumes a datawarehouse.

    Blocks until the resume operation completes and returns nothing.
    '''

    resume_operation = client.begin_resume(
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name)

    # Wait for completion, but DO NOT return the result. Long-running POST
    # operation results are not returned correctly by SDK.
    resume_operation.wait()
3215
3216
3217###############################################
3218#                sql elastic-pool             #
3219###############################################
3220
3221
def _find_elastic_pool_sku_from_capabilities(cli_ctx, location, sku, allow_reset_family=False, compute_model=None):
    '''
    Resolves a partially specified elastic pool sku (e.g. only tier and
    capacity) to the matching sku advertised in the location's capabilities.

    Returns sku unchanged when it already has a name, None when no sku values
    were specified at all, and otherwise the sku of the matching performance
    level capability.
    '''

    logger.debug('_find_elastic_pool_sku_from_capabilities input: %s', sku)

    # An explicit sku.name means nothing else needs to be resolved
    if sku.name:
        logger.debug('_find_elastic_pool_sku_from_capabilities return sku as is')
        return sku

    # No sku properties requested at all: wipe the sku out so that the
    # server side picks a default
    if not _any_sku_values_specified(sku):
        logger.debug('_find_elastic_pool_sku_from_capabilities return None')
        return None

    # Some properties are set but not the name: walk the capabilities of the
    # location down to the matching performance level and take its sku.
    location_capability = _get_location_capability(
        cli_ctx,
        location,
        CapabilityGroup.SUPPORTED_ELASTIC_POOL_EDITIONS)

    default_version_capability = _get_default_server_version(location_capability)

    matching_edition = _find_edition_capability(
        sku,
        default_version_capability.supported_elastic_pool_editions)

    matching_performance_level = _find_performance_level_capability(
        sku,
        matching_edition.supported_elastic_pool_performance_levels,
        allow_reset_family=allow_reset_family,
        compute_model=compute_model)

    resolved_sku = matching_performance_level.sku
    logger.debug('_find_elastic_pool_sku_from_capabilities return: %s', resolved_sku)
    return resolved_sku
3264
3265
def elastic_pool_create(
        cmd,
        client,
        server_name,
        resource_group_name,
        elastic_pool_name,
        sku=None,
        maintenance_configuration_id=None,
        **kwargs):
    '''
    Creates an elastic pool on the given server.

    The remaining keyword arguments are forwarded as the create parameters,
    after location, sku and maintenance configuration id have been filled in.
    '''

    # The pool is created in the location of its server.
    server_location = _get_server_location(
        cmd.cli_ctx,
        server_name=server_name,
        resource_group_name=resource_group_name)
    kwargs['location'] = server_location

    # Resolve the requested sku via capabilities unless sku.name was given.
    kwargs['sku'] = _find_elastic_pool_sku_from_capabilities(
        cmd.cli_ctx,
        server_location,
        sku)

    # Expand the maintenance configuration id if needed.
    kwargs['maintenance_configuration_id'] = _complete_maintenance_configuration_id(
        cmd.cli_ctx,
        maintenance_configuration_id)

    return client.begin_create_or_update(
        elastic_pool_name=elastic_pool_name,
        resource_group_name=resource_group_name,
        server_name=server_name,
        parameters=kwargs)
3300
3301
def elastic_pool_update(
        cmd,
        instance,
        max_capacity=None,
        min_capacity=None,
        max_size_bytes=None,
        zone_redundant=None,
        tier=None,
        family=None,
        capacity=None,
        maintenance_configuration_id=None):
    '''
    Updates an elastic pool. Custom update function to apply parameters to instance.

    Sku related arguments (tier, family, capacity) are applied through
    _db_elastic_pool_update_sku. The per-database capacity limits and
    zone_redundant are applied whenever they are not None, so an explicit 0
    or False is honoured. Returns the modified instance.
    '''

    #####
    # Set sku-related properties
    #####

    # Update sku
    _db_elastic_pool_update_sku(
        cmd,
        instance,
        None,  # service_objective
        tier,
        family,
        capacity,
        find_sku_from_capabilities_func=_find_elastic_pool_sku_from_capabilities)

    #####
    # Set other properties
    #####

    # Compare against None instead of relying on truthiness: 0 is a
    # meaningful per-database capacity and must not be silently dropped.
    if max_capacity is not None:
        instance.per_database_settings.max_capacity = max_capacity

    if min_capacity is not None:
        instance.per_database_settings.min_capacity = min_capacity

    if max_size_bytes:
        instance.max_size_bytes = max_size_bytes

    if zone_redundant is not None:
        instance.zone_redundant = zone_redundant

    # Expand maintenance configuration id if needed
    instance.maintenance_configuration_id = _complete_maintenance_configuration_id(
        cmd.cli_ctx,
        maintenance_configuration_id)

    return instance
3352
3353
class ElasticPoolCapabilitiesAdditionalDetails(Enum):  # pylint: disable=too-few-public-methods
    '''
    Names of the optional detail sections that can be requested when
    listing elastic pool capabilities.
    '''

    # Supported max sizes of a pool performance level
    max_size = 'max-size'
    # Supported per database min performance levels
    db_min_dtu = 'db-min-dtu'
    # Supported per database max performance levels
    db_max_dtu = 'db-max-dtu'
    # Supported per database max sizes
    db_max_size = 'db-max-size'
3363
3364
def elastic_pool_list_capabilities(
        client,
        location,
        edition=None,
        dtu=None,
        vcores=None,
        show_details=None,
        available=False):
    '''
    Gets elastic pool capabilities and optionally applies the specified filters.

    edition, dtu and vcores narrow down the returned editions and performance
    levels; available keeps only the entries accepted by _filter_available.
    show_details lists the values of ElasticPoolCapabilitiesAdditionalDetails
    whose sections should be kept; every other detail section is emptied.
    Returns the list of remaining edition capabilities.
    '''

    # Fixup parameters
    if not show_details:
        show_details = []
    if dtu:
        dtu = int(dtu)

    # Get capabilities tree from server
    capabilities = client.list_by_location(location, CapabilityGroup.SUPPORTED_ELASTIC_POOL_EDITIONS)

    # Get subtree related to elastic pools
    editions = _get_default_server_version(capabilities).supported_elastic_pool_editions

    # Filter by edition
    if edition:
        editions = [e for e in editions if e.name.lower() == edition.lower()]

    # Filter by dtu (already converted to int above)
    if dtu:
        for e in editions:
            e.supported_elastic_pool_performance_levels = [
                pl for pl in e.supported_elastic_pool_performance_levels
                if pl.performance_level.value == dtu and
                pl.performance_level.unit == PerformanceLevelUnit.DTU]

    # Filter by vcores
    if vcores:
        vcores_value = int(vcores)
        for e in editions:
            e.supported_elastic_pool_performance_levels = [
                pl for pl in e.supported_elastic_pool_performance_levels
                if pl.performance_level.value == vcores_value and
                pl.performance_level.unit == PerformanceLevelUnit.V_CORES]

    # Filter by availability
    if available:
        editions = _filter_available(editions)
        for e in editions:
            e.supported_elastic_pool_performance_levels = _filter_available(e.supported_elastic_pool_performance_levels)
            # Max sizes hang off the pool performance levels (the same
            # attribute that is optionally hidden further down), so filter
            # them there.
            for pl in e.supported_elastic_pool_performance_levels:
                pl.supported_max_sizes = _filter_available(pl.supported_max_sizes)

    # Remove editions with no performance levels (due to filters)
    editions = [e for e in editions if e.supported_elastic_pool_performance_levels]

    for e in editions:
        for d in e.supported_elastic_pool_performance_levels:
            # Optionally hide supported max sizes
            if ElasticPoolCapabilitiesAdditionalDetails.max_size.value not in show_details:
                d.supported_max_sizes = []

            # Optionally hide per database min & max dtus. min dtus are nested inside max dtus,
            # so only hide max dtus if both min and max should be hidden.
            if ElasticPoolCapabilitiesAdditionalDetails.db_min_dtu.value not in show_details:
                if ElasticPoolCapabilitiesAdditionalDetails.db_max_dtu.value not in show_details:
                    d.supported_per_database_max_performance_levels = []

                for md in d.supported_per_database_max_performance_levels:
                    md.supported_per_database_min_performance_levels = []

            # Optionally hide supported per db max sizes
            if ElasticPoolCapabilitiesAdditionalDetails.db_max_size.value not in show_details:
                d.supported_per_database_max_sizes = []

    return editions
3440
3441###############################################
3442#                sql instance-pool            #
3443###############################################
3444
3445
def instance_pool_list(
        client,
        resource_group_name=None):
    '''
    Lists instance pools in a resource group or, when no resource group
    is given, in the whole subscription.
    '''

    if not resource_group_name:
        # No resource group given: list across the subscription
        return client.list()

    # Restrict the listing to the given resource group
    return client.list_by_resource_group(
        resource_group_name=resource_group_name)
3460
3461
def instance_pool_create(
        cmd,
        client,
        instance_pool_name,
        resource_group_name,
        no_wait=False,
        sku=None,
        **kwargs):
    '''
    Creates a new instance pool.

    The sku is resolved against the capabilities of kwargs['location'];
    the remaining keyword arguments are forwarded as the create parameters.
    '''

    location = kwargs['location']
    kwargs['sku'] = _find_instance_pool_sku_from_capabilities(cmd.cli_ctx, location, sku)

    return sdk_no_wait(
        no_wait,
        client.begin_create_or_update,
        instance_pool_name=instance_pool_name,
        resource_group_name=resource_group_name,
        parameters=kwargs)
3481
3482
def instance_pool_update(
        instance,
        tags=None):
    '''
    Updates an instance pool. Custom update function that replaces the
    tags of the instance with the given value and returns the instance.
    '''

    # The tags are always overwritten, also when no value is given (None).
    setattr(instance, 'tags', tags)
    return instance
3493
3494
def _find_instance_pool_sku_from_capabilities(cli_ctx, location, sku):
    '''
    Checks the sku edition (tier) and family requested by the user against
    the capabilities API for the region and builds the instance pool sku
    from them.
    '''

    logger.debug('_find_instance_pool_sku_from_capabilities input: %s', sku)

    # Capabilities of the location, narrowed down to the default
    # managed instance version
    location_capability = _get_location_capability(
        cli_ctx,
        location,
        CapabilityGroup.SUPPORTED_MANAGED_INSTANCE_VERSIONS)
    default_version = _get_default_capability(
        location_capability.supported_managed_instance_versions)

    # Look up the requested edition among the instance pool editions
    matching_edition = _find_edition_capability(
        sku,
        default_version.supported_instance_pool_editions)

    # Look up the requested family; the result itself is not needed
    # (presumably the helper raises when nothing matches -- verify).
    _find_family_capability(sku, matching_edition.supported_families)

    pool_sku = Sku(name="instance-pool", tier=sku.tier, family=sku.family)

    logger.debug('_find_instance_pool_sku_from_capabilities return: %s', pool_sku)
    return pool_sku
3528
3529
3530###############################################
3531#                sql server                   #
3532###############################################
3533
def server_create(
        client,
        resource_group_name,
        server_name,
        assign_identity=False,
        no_wait=False,
        enable_public_network=None,
        restrict_outbound_network_access=None,
        key_id=None,
        user_assigned_identity_id=None,
        primary_user_assigned_identity_id=None,
        identity_type=None,
        enable_ad_only_auth=False,
        external_admin_principal_type=None,
        external_admin_sid=None,
        external_admin_name=None,
        **kwargs):
    '''
    Creates a server.

    Identity, network access flags, key id and external administrator
    settings are added to the remaining keyword arguments, which are then
    sent as the create parameters.
    '''

    # Identity: system assigned only when requested through assign_identity
    kwargs['identity'] = _get_identity_object_from_type(
        bool(assign_identity),
        identity_type,
        user_assigned_identity_id,
        None)

    # Network flags are only sent when given explicitly
    if enable_public_network is not None:
        if enable_public_network:
            kwargs['public_network_access'] = ServerNetworkAccessFlag.ENABLED
        else:
            kwargs['public_network_access'] = ServerNetworkAccessFlag.DISABLED

    if restrict_outbound_network_access is not None:
        if restrict_outbound_network_access:
            kwargs['restrict_outbound_network_access'] = ServerNetworkAccessFlag.ENABLED
        else:
            kwargs['restrict_outbound_network_access'] = ServerNetworkAccessFlag.DISABLED

    kwargs['key_id'] = key_id
    kwargs['primary_user_assigned_identity_id'] = primary_user_assigned_identity_id

    # The tenant id is only looked up when an external admin name is given
    tenant_id = _get_tenant_id() if external_admin_name is not None else None

    # AD-only authentication is either switched on or left unspecified
    kwargs['administrators'] = ServerExternalAdministrator(
        principal_type=external_admin_principal_type,
        login=external_admin_name,
        sid=external_admin_sid,
        azure_ad_only_authentication=True if enable_ad_only_auth else None,
        tenant_id=tenant_id)

    return sdk_no_wait(
        no_wait,
        client.begin_create_or_update,
        server_name=server_name,
        resource_group_name=resource_group_name,
        parameters=kwargs)
3594
3595
def server_list(
        client,
        resource_group_name=None,
        expand_ad_admin=False):
    '''
    Lists servers in a resource group or, when no resource group is
    given, in the whole subscription.

    With expand_ad_admin set, the 'administrators/activedirectory' expand
    value is passed along with the request.
    '''

    expand = 'administrators/activedirectory' if expand_ad_admin else None

    if not resource_group_name:
        # No resource group given: list across the subscription
        return client.list(expand)

    # Restrict the listing to the given resource group
    return client.list_by_resource_group(resource_group_name=resource_group_name, expand=expand)
3614
3615
def server_get(
        client,
        resource_group_name,
        server_name,
        expand_ad_admin=False):
    '''
    Gets a server.

    With expand_ad_admin set, the 'administrators/activedirectory' expand
    value is passed along with the request.
    '''

    expand = 'administrators/activedirectory' if expand_ad_admin else None

    # Fetch the single server
    return client.get(resource_group_name, server_name, expand)
3631
3632
def server_update(
        instance,
        administrator_login_password=None,
        assign_identity=False,
        minimal_tls_version=None,
        enable_public_network=None,
        restrict_outbound_network_access=None,
        primary_user_assigned_identity_id=None,
        key_id=None,
        identity_type=None,
        user_assigned_identity_id=None):
    '''
    Updates a server. Custom update function to apply parameters to instance.

    String valued arguments replace the current value only when they are
    truthy; the two network access flags are applied whenever they are not
    None. Returns the modified instance.
    '''

    # Once assigned, the identity cannot be removed
    # if instance.identity is None and assign_identity:
    #    instance.identity = ResourceIdentity(type=IdentityType.system_assigned.value)

    instance.identity = _get_identity_object_from_type(
        assign_identity,
        identity_type,
        user_assigned_identity_id,
        instance.identity)

    # Apply params to instance
    instance.administrator_login_password = (
        administrator_login_password or instance.administrator_login_password)
    instance.minimal_tls_version = (
        minimal_tls_version or instance.minimal_tls_version)

    if enable_public_network is not None:
        instance.public_network_access = (
            ServerNetworkAccessFlag.ENABLED if enable_public_network
            else ServerNetworkAccessFlag.DISABLED)

    # The outbound restriction has its own property (the same one that
    # server_create fills in); it must not overwrite public_network_access.
    if restrict_outbound_network_access is not None:
        instance.restrict_outbound_network_access = (
            ServerNetworkAccessFlag.ENABLED if restrict_outbound_network_access
            else ServerNetworkAccessFlag.DISABLED)

    instance.primary_user_assigned_identity_id = (
        primary_user_assigned_identity_id or instance.primary_user_assigned_identity_id)

    instance.key_id = (key_id or instance.key_id)

    return instance
3680
3681
3682#####
3683#           sql server ad-admin
3684#####
3685
3686
def server_ad_admin_set(
        client,
        resource_group_name,
        server_name,
        **kwargs):
    '''
    Sets a server's AD admin.

    The tenant id and the Active Directory administrator type are added to
    the keyword arguments, which are then sent as the request parameters.
    '''

    kwargs.update({
        'tenant_id': _get_tenant_id(),
        'administrator_type': AdministratorType.ACTIVE_DIRECTORY,
    })

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        administrator_name=AdministratorName.ACTIVE_DIRECTORY,
        parameters=kwargs)
3704
3705
def server_ad_admin_update(
        instance,
        login=None,
        sid=None,
        tenant_id=None):
    '''
    Updates a server's AD admin. Custom update function that applies the
    given values to the instance and returns it.
    '''

    # A falsy argument keeps the value that is currently on the instance.
    requested = (
        ('login', login),
        ('sid', sid),
        ('tenant_id', tenant_id),
    )
    for attribute, new_value in requested:
        setattr(instance, attribute, new_value or getattr(instance, attribute))

    return instance
3721
3722#####
3723#           sql server firewall-rule
3724#####
3725
3726
def firewall_rule_allow_all_azure_ips(
        client,
        server_name,
        resource_group_name):
    '''
    Creates a firewall rule with special start/end ip address value
    that represents all azure ips.

    The rule is always named 'AllowAllAzureIPs'.
    '''

    # Name of the rule that will be created
    rule_name = 'AllowAllAzureIPs'

    # Special start/end IP that represents allowing all azure ips
    azure_ip_addr = '0.0.0.0'

    # Pass the addresses wrapped in a FirewallRule through 'parameters',
    # the same way firewall_rule_create and firewall_rule_update call
    # create_or_update.
    return client.create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        firewall_rule_name=rule_name,
        parameters=FirewallRule(start_ip_address=azure_ip_addr,
                                end_ip_address=azure_ip_addr))
3748
3749
def firewall_rule_update(
        client,
        firewall_rule_name,
        server_name,
        resource_group_name,
        start_ip_address=None,
        end_ip_address=None):
    '''
    Updates a firewall rule.

    An address that is not given (falsy) keeps the value of the existing rule.
    '''

    # Read the current rule so that omitted addresses can be carried over
    existing_rule = client.get(
        firewall_rule_name=firewall_rule_name,
        server_name=server_name,
        resource_group_name=resource_group_name)

    updated_rule = FirewallRule(
        start_ip_address=start_ip_address or existing_rule.start_ip_address,
        end_ip_address=end_ip_address or existing_rule.end_ip_address)

    # Send update
    return client.create_or_update(
        firewall_rule_name=firewall_rule_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        parameters=updated_rule)
3774
3775
def firewall_rule_create(
        client,
        firewall_rule_name,
        server_name,
        resource_group_name,
        start_ip_address=None,
        end_ip_address=None):
    '''
    Creates a firewall rule covering the given start/end ip address range.
    '''

    new_rule = FirewallRule(
        start_ip_address=start_ip_address,
        end_ip_address=end_ip_address)

    return client.create_or_update(
        firewall_rule_name=firewall_rule_name,
        server_name=server_name,
        resource_group_name=resource_group_name,
        parameters=new_rule)
3792
3793
3794#########################################################
3795#           sql server outbound-firewall-rule           #
3796#########################################################
3797
3798
def outbound_firewall_rule_create(
        client,
        server_name,
        resource_group_name,
        outbound_rule_fqdn):
    '''
    Creates a new outbound firewall rule for the given fully qualified
    domain name.
    '''

    # The rule is identified by its fqdn; the parameters object is
    # passed without any values set.
    empty_rule = OutboundFirewallRule()

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        outbound_rule_fqdn=outbound_rule_fqdn,
        parameters=empty_rule)
3812
3813
3814#########################################################
3815#           sql server key                              #
3816#########################################################
3817
3818
def server_key_create(
        client,
        resource_group_name,
        server_name,
        kid=None):
    '''
    Creates a server key of type Azure Key Vault from the given key id (uri).
    '''

    # The name of the server key is derived from the key uri
    server_key_name = _get_server_key_name_from_uri(kid)

    key_vault_key = ServerKey(
        server_key_type=ServerKeyType.AZURE_KEY_VAULT,
        uri=kid)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        key_name=server_key_name,
        parameters=key_vault_key)
3838
3839
def server_key_get(
        client,
        resource_group_name,
        server_name,
        kid):
    '''
    Gets the server key that belongs to the given key id (uri).
    '''

    # The name of the server key is derived from the key uri
    return client.get(
        resource_group_name=resource_group_name,
        server_name=server_name,
        key_name=_get_server_key_name_from_uri(kid))
3855
3856
def server_key_delete(
        client,
        resource_group_name,
        server_name,
        kid):
    '''
    Deletes the server key that belongs to the given key id (uri).
    '''

    # The name of the server key is derived from the key uri
    return client.begin_delete(
        resource_group_name=resource_group_name,
        server_name=server_name,
        key_name=_get_server_key_name_from_uri(kid))
3872
3873
3874# pylint: disable=line-too-long
3875def _get_server_key_name_from_uri(uri):
3876    '''
3877    Gets the key's name to use as a SQL server key.
3878
3879    The SQL server key API requires that the server key has a specific name
3880    based on the vault, key and key version.
3881    '''
3882    import re
3883
3884    match = re.match(r'https://(.)+\.(managedhsm.azure.net|managedhsm-preview.azure.net|vault.azure.net|vault-int.azure-int.net|vault.azure.cn|managedhsm.azure.cn|vault.usgovcloudapi.net|managedhsm.usgovcloudapi.net|vault.microsoftazure.de|managedhsm.microsoftazure.de|vault.cloudapi.eaglex.ic.gov|vault.cloudapi.microsoft.scloud)(:443)?\/keys/[^\/]+\/[0-9a-zA-Z]+$', uri)
3885
3886    if match is None:
3887        raise CLIError('The provided uri is invalid. Please provide a valid Azure Key Vault key id.  For example: '
3888                       '"https://YourVaultName.vault.azure.net/keys/YourKeyName/01234567890123456789012345678901" '
3889                       'or "https://YourManagedHsmRegion.YourManagedHsmName.managedhsm.azure.net/keys/YourKeyName/01234567890123456789012345678901"')
3890
3891    vault = uri.split('.')[0].split('/')[-1]
3892    key = uri.split('/')[-2]
3893    version = uri.split('/')[-1]
3894    return '{}_{}_{}'.format(vault, key, version)
3895
3896
3897#####
3898#           sql server dns-alias
3899#####
3900
3901
def server_dns_alias_set(
        cmd,
        client,
        resource_group_name,
        server_name,
        dns_alias_name,
        original_server_name,
        original_subscription_id=None,
        original_resource_group_name=None,
        **kwargs):
    '''
    Sets a server DNS alias by acquiring the alias from the original server.

    The original subscription and resource group fall back to the current
    subscription and to resource_group_name when they are not given.
    '''
    # url parse package has different names in Python 2 and 3. 'six' package works cross-version.
    from six.moves.urllib.parse import quote  # pylint: disable=import-error
    from azure.cli.core.commands.client_factory import get_subscription_id

    # Url-quoted parts of the resource id of the alias on the original server
    subscription_part = quote(original_subscription_id or get_subscription_id(cmd.cli_ctx))
    resource_group_part = quote(original_resource_group_name or resource_group_name)
    server_part = quote(original_server_name)
    alias_part = quote(dns_alias_name)

    id_template = "/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Sql/servers/{}/dnsAliases/{}"
    kwargs['old_server_dns_alias_id'] = id_template.format(
        subscription_part,
        resource_group_part,
        server_part,
        alias_part)

    return client.begin_acquire(
        resource_group_name=resource_group_name,
        server_name=server_name,
        dns_alias_name=dns_alias_name,
        parameters=kwargs)
3933
3934#####
3935#           sql server encryption-protector
3936#####
3937
3938
def encryption_protector_get(
        client,
        resource_group_name,
        server_name):
    '''
    Gets the current encryption protector of a server.
    '''

    protector_name = EncryptionProtectorName.CURRENT

    return client.get(
        resource_group_name=resource_group_name,
        server_name=server_name,
        encryption_protector_name=protector_name)
3951
3952
def encryption_protector_update(
        client,
        resource_group_name,
        server_name,
        server_key_type,
        kid=None,
        auto_rotation_enabled=None):
    '''
    Updates the current encryption protector of a server.

    For a service managed key the fixed key name 'ServiceManaged' is used;
    otherwise the key name is derived from kid. Raises CLIError when kid
    is required but missing.
    '''

    if server_key_type == ServerKeyType.SERVICE_MANAGED:
        key_name = 'ServiceManaged'
    elif kid is None:
        raise CLIError('A uri must be provided if the server_key_type is AzureKeyVault.')
    else:
        key_name = _get_server_key_name_from_uri(kid)

    protector = EncryptionProtector(
        server_key_type=server_key_type,
        server_key_name=key_name,
        auto_rotation_enabled=auto_rotation_enabled)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        encryption_protector_name=EncryptionProtectorName.CURRENT,
        parameters=protector)
3978
3979#####
3980#           sql server aad-only
3981#####
3982
3983
def server_aad_only_disable(
        client,
        resource_group_name,
        server_name):
    '''
    Disables the aad-only authentication setting of a server.
    '''

    aad_only_off = ServerAzureADOnlyAuthentication(
        azure_ad_only_authentication=False)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        authentication_name=AuthenticationName.DEFAULT,
        parameters=aad_only_off)
3999
4000
def server_aad_only_enable(
        client,
        resource_group_name,
        server_name):
    '''
    Enables the aad-only authentication setting of a server.
    '''

    aad_only_on = ServerAzureADOnlyAuthentication(
        azure_ad_only_authentication=True)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        authentication_name=AuthenticationName.DEFAULT,
        parameters=aad_only_on)
4016
4017
def server_aad_only_get(
        client,
        resource_group_name,
        server_name):
    '''
    Shows the aad-only authentication setting of a server.
    '''

    setting_name = AuthenticationName.DEFAULT

    return client.get(
        resource_group_name=resource_group_name,
        server_name=server_name,
        authentication_name=setting_name)
4030
4031
4032###############################################
4033#           sql server ledger                 #
4034###############################################
4035
def ledger_digest_uploads_show(
        client,
        resource_group_name,
        server_name,
        database_name):
    '''
    Shows the current ledger digest uploads setting (ledger storage
    target) of a database.
    '''

    uploads_name = LedgerDigestUploadsName.CURRENT

    return client.get(
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name,
        ledger_digest_uploads=uploads_name)
4050
4051
def ledger_digest_uploads_enable(
        client,
        resource_group_name,
        server_name,
        database_name,
        endpoint,
        **kwargs):
    '''
    Enables the ledger storage target of a database, using the given
    endpoint as the digest storage endpoint.
    '''

    # The endpoint travels inside the request parameters
    kwargs.update({'digest_storage_endpoint': endpoint})

    return client.create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name,
        ledger_digest_uploads=LedgerDigestUploadsName.CURRENT,
        parameters=kwargs)
4071
4072
def ledger_digest_uploads_disable(
        client,
        resource_group_name,
        server_name,
        database_name):
    '''
    Disables the ledger storage target of a database.
    '''

    uploads_name = LedgerDigestUploadsName.CURRENT

    return client.disable(
        resource_group_name=resource_group_name,
        server_name=server_name,
        database_name=database_name,
        ledger_digest_uploads=uploads_name)
4087
4088
4089###############################################
4090#           sql server trust groups           #
4091###############################################
4092
4093
def server_trust_group_create(
        client,
        resource_group_name,
        name,
        location,
        group_member,
        trust_scope,
        no_wait=False):
    '''
    Creates a server trust group in the given location.

    Each entry of group_member is used as the server id of one group member.
    '''

    trust_group = ServerTrustGroup(
        group_members=[ServerInfo(server_id=server_id) for server_id in group_member],
        trust_scopes=trust_scope)

    return sdk_no_wait(
        no_wait,
        client.begin_create_or_update,
        resource_group_name=resource_group_name,
        location_name=location,
        server_trust_group_name=name,
        parameters=trust_group)
4112
4113
def server_trust_group_delete(
        client,
        resource_group_name,
        name,
        location,
        no_wait=False):
    '''
    Deletes the server trust group with the given name in the given location.
    '''

    return sdk_no_wait(
        no_wait,
        client.begin_delete,
        resource_group_name=resource_group_name,
        location_name=location,
        server_trust_group_name=name)
4125
4126
def server_trust_group_get(
        client,
        resource_group_name,
        name,
        location):
    '''
    Gets the server trust group with the given name in the given location.
    '''

    lookup = {
        'resource_group_name': resource_group_name,
        'location_name': location,
        'server_trust_group_name': name,
    }
    return client.get(**lookup)
4136
4137
def server_trust_group_list(
        client,
        resource_group_name,
        instance_name=None,
        location=None):
    '''
    Lists server trust groups, either by managed instance (when
    instance_name is given) or by location.
    '''

    if not instance_name:
        return client.list_by_location(
            resource_group_name=resource_group_name,
            location_name=location)

    return client.list_by_instance(
        resource_group_name=resource_group_name,
        managed_instance_name=instance_name)
4146
4147
4148###############################################
4149#                sql managed instance         #
4150###############################################
4151
4152
def _find_managed_instance_sku_from_capabilities(
        cli_ctx,
        location,
        sku):
    '''
    Resolves a partially specified managed instance sku (e.g. only tier
    and family) to the matching sku listed in the capabilities of the
    given location.

    Returns None when no sku property was requested at all.
    '''

    logger.debug('_find_managed_instance_sku_from_capabilities input: %s', sku)

    # Nothing requested: drop the sku so that the server side picks a default.
    if not _any_sku_values_specified(sku):
        logger.debug('_find_managed_instance_sku_from_capabilities return None')
        return None

    # Some sku properties were given. Walk the capabilities tree:
    # location -> default managed instance version -> edition -> family,
    # and name the resulting sku after the family capability's sku.
    location_capability = _get_location_capability(
        cli_ctx,
        location,
        CapabilityGroup.SUPPORTED_MANAGED_INSTANCE_VERSIONS)
    default_version = _get_default_capability(
        location_capability.supported_managed_instance_versions)
    matching_edition = _find_edition_capability(
        sku,
        default_version.supported_editions)
    matching_family = _find_family_capability(
        sku,
        matching_edition.supported_families)

    resolved_sku = Sku(name=matching_family.sku)
    logger.debug('_find_managed_instance_sku_from_capabilities return: %s', resolved_sku)
    return resolved_sku
4189
4190
def managed_instance_create(
        cmd,
        client,
        managed_instance_name,
        resource_group_name,
        location,
        virtual_network_subnet_id,
        assign_identity=False,
        sku=None,
        key_id=None,
        user_assigned_identity_id=None,
        primary_user_assigned_identity_id=None,
        identity_type=None,
        enable_ad_only_auth=False,
        external_admin_principal_type=None,
        external_admin_sid=None,
        external_admin_name=None,
        **kwargs):
    '''
    Creates a managed instance.

    Fills in identity, location, sku, subnet, maintenance configuration, key
    and external administrator settings on kwargs and submits kwargs as the
    create parameters. Returns None without creating anything if the user
    declines the geo-redundant backup storage confirmation prompt.
    '''

    # The helper always receives a strict True/False for the assign flag.
    kwargs['identity'] = _get_identity_object_from_type(
        bool(assign_identity),
        identity_type,
        user_assigned_identity_id,
        None)

    kwargs['location'] = location
    kwargs['sku'] = _find_managed_instance_sku_from_capabilities(cmd.cli_ctx, location, sku)
    kwargs['subnet_id'] = virtual_network_subnet_id
    kwargs['maintenance_configuration_id'] = _complete_maintenance_configuration_id(
        cmd.cli_ctx,
        kwargs['maintenance_configuration_id'])

    # In these regions, ask for confirmation before using geo-redundant
    # backup storage, unless the caller passed 'yes'.
    confirm_regions = ['southeastasia', 'brazilsouth', 'eastasia']
    if not kwargs['yes'] and location.lower() in confirm_regions:
        backup_redundancy = kwargs['storage_account_type']

        if backup_redundancy == 'GRS':
            proceed = prompt_y_n("""Selected value for backup storage redundancy is geo-redundant storage.
             Note that database backups will be geo-replicated to the paired region.
             To learn more about Azure Paired Regions visit https://aka.ms/azure-ragrs-regions.
             Do you want to proceed?""")
            if not proceed:
                return None

        if not backup_redundancy:
            proceed = prompt_y_n("""You have not specified the value for backup storage redundancy
            which will default to geo-redundant storage. Note that database backups will be geo-replicated
            to the paired region. To learn more about Azure Paired Regions visit https://aka.ms/azure-ragrs-regions.
            Do you want to proceed?""")
            if not proceed:
                return None

    kwargs['key_id'] = key_id
    kwargs['primary_user_assigned_identity_id'] = primary_user_assigned_identity_id

    # AD-only auth is sent as True or left unset, never as an explicit False.
    ad_only = True if enable_ad_only_auth else None

    # The tenant id is only looked up when an external admin name is given.
    tenant_id = _get_tenant_id() if external_admin_name is not None else None

    kwargs['administrators'] = ManagedInstanceExternalAdministrator(
        principal_type=external_admin_principal_type,
        login=external_admin_name,
        sid=external_admin_sid,
        azure_ad_only_authentication=ad_only,
        tenant_id=tenant_id)

    return client.begin_create_or_update(
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name,
        parameters=kwargs)
4264
4265
def managed_instance_list(
        client,
        resource_group_name=None,
        expand_ad_admin=False):
    '''
    Lists managed instances in a resource group, or in the whole
    subscription when no resource group is given.
    '''

    expand = 'administrators/activedirectory' if expand_ad_admin else None

    if not resource_group_name:
        # No resource group: list across the subscription
        return client.list(expand)

    return client.list_by_resource_group(resource_group_name=resource_group_name, expand=expand)
4284
4285
def managed_instance_get(
        client,
        resource_group_name,
        managed_instance_name,
        expand_ad_admin=False):
    '''
    Gets a Managed Instance, optionally expanding its Active Directory
    administrator.
    '''

    expand = 'administrators/activedirectory' if expand_ad_admin else None

    # Fetch the single managed instance identified by resource group and name
    return client.get(resource_group_name, managed_instance_name, expand)
4301
4302
def managed_instance_update(
        cmd,
        instance,
        administrator_login_password=None,
        license_type=None,
        vcores=None,
        storage_size_in_gb=None,
        assign_identity=False,
        proxy_override=None,
        public_data_endpoint_enabled=None,
        tier=None,
        family=None,
        minimal_tls_version=None,
        tags=None,
        maintenance_configuration_id=None,
        primary_user_assigned_identity_id=None,
        key_id=None,
        identity_type=None,
        user_assigned_identity_id=None,
        virtual_network_subnet_id=None):
    '''
    Updates a managed instance. Custom update function that applies the
    given parameters to the instance object and returns it.
    '''

    # Once assigned, the identity cannot be removed
    instance.identity = _get_identity_object_from_type(
        assign_identity,
        identity_type,
        user_assigned_identity_id,
        instance.identity)

    # Each of these keeps the instance's current value unless a truthy
    # replacement was passed in.
    keep_unless_given = (
        ('administrator_login_password', administrator_login_password),
        ('license_type', license_type),
        ('v_cores', vcores),
        ('storage_size_in_gb', storage_size_in_gb),
        ('proxy_override', proxy_override),
        ('minimal_tls_version', minimal_tls_version),
    )
    for attribute, requested in keep_unless_given:
        setattr(instance, attribute, requested or getattr(instance, attribute))

    # Clear the sku name so the sku is re-resolved from tier and family.
    current_sku = instance.sku
    current_sku.name = None
    current_sku.tier = tier or current_sku.tier
    current_sku.family = family or current_sku.family
    instance.sku = _find_managed_instance_sku_from_capabilities(
        cmd.cli_ctx,
        instance.location,
        current_sku)

    # These two are tri-state: None means "leave unchanged".
    if public_data_endpoint_enabled is not None:
        instance.public_data_endpoint_enabled = public_data_endpoint_enabled

    if tags is not None:
        instance.tags = tags

    instance.maintenance_configuration_id = _complete_maintenance_configuration_id(
        cmd.cli_ctx,
        maintenance_configuration_id)

    instance.primary_user_assigned_identity_id = (
        primary_user_assigned_identity_id or instance.primary_user_assigned_identity_id)

    instance.key_id = key_id or instance.key_id

    if virtual_network_subnet_id is not None:
        instance.subnet_id = virtual_network_subnet_id

    return instance
4375
4376
4377#####
4378#           sql managed instance key
4379#####
4380
4381
def managed_instance_key_create(
        client,
        resource_group_name,
        managed_instance_name,
        kid=None):
    '''
    Creates a managed instance key of type Azure Key Vault from the
    given key uri.
    '''

    # The key name is derived from the key uri.
    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        managed_instance_name=managed_instance_name,
        key_name=_get_server_key_name_from_uri(kid),
        parameters=ManagedInstanceKey(
            server_key_type=ServerKeyType.AZURE_KEY_VAULT,
            uri=kid))
4402
4403
def managed_instance_key_get(
        client,
        resource_group_name,
        managed_instance_name,
        kid):
    '''
    Gets the managed instance key identified by the given key uri.
    '''

    # The key name is derived from the key uri.
    return client.get(
        resource_group_name=resource_group_name,
        managed_instance_name=managed_instance_name,
        key_name=_get_server_key_name_from_uri(kid))
4419
4420
def managed_instance_key_delete(
        client,
        resource_group_name,
        managed_instance_name,
        kid):
    '''
    Deletes the managed instance key identified by the given key uri.
    '''

    # The key name is derived from the key uri.
    return client.begin_delete(
        resource_group_name=resource_group_name,
        managed_instance_name=managed_instance_name,
        key_name=_get_server_key_name_from_uri(kid))
4436
4437#####
4438#           sql managed instance encryption-protector
4439#####
4440
4441
def managed_instance_encryption_protector_update(
        client,
        resource_group_name,
        managed_instance_name,
        server_key_type,
        kid=None,
        auto_rotation_enabled=None):
    '''
    Updates a managed instance encryption protector.

    Raises CLIError if the key type is not service managed and no key
    uri is provided.
    '''

    service_managed = server_key_type == ServerKeyType.SERVICE_MANAGED

    if service_managed:
        key_name = 'ServiceManaged'
    elif kid is None:
        raise CLIError('A uri must be provided if the server_key_type is AzureKeyVault.')
    else:
        key_name = _get_server_key_name_from_uri(kid)

    protector = ManagedInstanceEncryptionProtector(
        server_key_type=server_key_type,
        server_key_name=key_name,
        auto_rotation_enabled=auto_rotation_enabled)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        managed_instance_name=managed_instance_name,
        encryption_protector_name=EncryptionProtectorName.CURRENT,
        parameters=protector)
4467
4468
def managed_instance_encryption_protector_get(
        client,
        resource_group_name,
        managed_instance_name):
    '''
    Shows the current encryption protector of a managed instance.
    '''

    request = {
        'resource_group_name': resource_group_name,
        'managed_instance_name': managed_instance_name,
        'encryption_protector_name': EncryptionProtectorName.CURRENT,
    }
    return client.get(**request)
4481
4482
4483#####
4484#           sql managed instance ad-admin
4485#####
4486
4487
def mi_ad_admin_set(
        client,
        resource_group_name,
        managed_instance_name,
        **kwargs):
    '''
    Creates a managed instance active directory administrator.

    The tenant id and administrator type are filled in on kwargs before
    kwargs is submitted as the request parameters.
    '''

    kwargs.update(
        tenant_id=_get_tenant_id(),
        administrator_type=AdministratorType.ACTIVE_DIRECTORY)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        managed_instance_name=managed_instance_name,
        administrator_name=AdministratorName.ACTIVE_DIRECTORY,
        parameters=kwargs)
4506
4507
def mi_ad_admin_delete(
        client,
        resource_group_name,
        managed_instance_name):
    '''
    Deletes the active directory administrator of a managed instance.
    '''

    request = {
        'resource_group_name': resource_group_name,
        'managed_instance_name': managed_instance_name,
        'administrator_name': AdministratorName.ACTIVE_DIRECTORY,
    }
    return client.begin_delete(**request)
4521
4522
4523#####
4524#           sql managed instance aad-only
4525#####
4526
4527
def mi_aad_only_disable(
        client,
        resource_group_name,
        managed_instance_name):
    '''
    Disables the AAD-only setting of a managed instance.
    '''

    setting = ManagedInstanceAzureADOnlyAuthentication(azure_ad_only_authentication=False)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        managed_instance_name=managed_instance_name,
        authentication_name=AuthenticationName.DEFAULT,
        parameters=setting)
4544
4545
def mi_aad_only_enable(
        client,
        resource_group_name,
        managed_instance_name):
    '''
    Enables the AAD-only setting of a managed instance.
    '''

    setting = ManagedInstanceAzureADOnlyAuthentication(azure_ad_only_authentication=True)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        managed_instance_name=managed_instance_name,
        authentication_name=AuthenticationName.DEFAULT,
        parameters=setting)
4562
4563
def mi_aad_only_get(
        client,
        resource_group_name,
        managed_instance_name):
    '''
    Gets the AAD-only setting of a managed instance.
    '''

    request = {
        'resource_group_name': resource_group_name,
        'managed_instance_name': managed_instance_name,
        'authentication_name': AuthenticationName.DEFAULT,
    }
    return client.get(**request)
4577
4578###############################################
4579#                sql managed db               #
4580###############################################
4581
4582
def managed_db_create(
        cmd,
        client,
        database_name,
        managed_instance_name,
        resource_group_name,
        **kwargs):
    '''
    Creates a managed database in the location of its managed instance.
    '''

    # The database location is looked up from its managed instance.
    instance_location = _get_managed_instance_location(
        cmd.cli_ctx,
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name)
    kwargs['location'] = instance_location

    return client.begin_create_or_update(
        database_name=database_name,
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name,
        parameters=kwargs)
4603
4604
def managed_db_restore(
        cmd,
        client,
        database_name,
        managed_instance_name,
        resource_group_name,
        target_managed_database_name,
        target_managed_instance_name=None,
        target_resource_group_name=None,
        deleted_time=None,
        **kwargs):
    '''
    Restores an existing managed DB (i.e. create with 'PointInTimeRestore' create mode.)

    The target instance and resource group default to those of the source.
    When deleted_time is given the source is referenced as a restorable
    dropped database, otherwise as a live database.
    '''

    # Default the restore target to the source instance and resource group.
    target_managed_instance_name = target_managed_instance_name or managed_instance_name
    target_resource_group_name = target_resource_group_name or resource_group_name

    # The location is looked up from the source managed instance.
    kwargs['location'] = _get_managed_instance_location(
        cmd.cli_ctx,
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name)

    kwargs['create_mode'] = CreateMode.POINT_IN_TIME_RESTORE

    if not deleted_time:
        kwargs['source_database_id'] = _get_managed_db_resource_id(
            cmd.cli_ctx,
            resource_group_name,
            managed_instance_name,
            database_name)
    else:
        kwargs['restorable_dropped_database_id'] = _get_managed_dropped_db_resource_id(
            cmd.cli_ctx,
            resource_group_name,
            managed_instance_name,
            database_name,
            deleted_time)

    return client.begin_create_or_update(
        database_name=target_managed_database_name,
        managed_instance_name=target_managed_instance_name,
        resource_group_name=target_resource_group_name,
        parameters=kwargs)
4654
4655
def update_short_term_retention_mi(
        cmd,
        client,
        database_name,
        managed_instance_name,
        resource_group_name,
        retention_days,
        deleted_time=None,
        **kwargs):
    '''
    Updates short term retention for a managed database, or for a
    restorable dropped database when deleted_time is given.
    '''

    kwargs['retention_days'] = retention_days

    if not deleted_time:
        return client.begin_create_or_update(
            database_name=database_name,
            managed_instance_name=managed_instance_name,
            resource_group_name=resource_group_name,
            policy_name=ManagedShortTermRetentionPolicyName.DEFAULT,
            parameters=kwargs)

    # Dropped databases are addressed as '<name>,<deletion time>' through a
    # separate operations client.
    dropped_database_id = '{},{}'.format(
        database_name,
        _to_filetimeutc(deleted_time))

    dropped_db_client = \
        get_sql_restorable_dropped_database_managed_backup_short_term_retention_policies_operations(
            cmd.cli_ctx,
            None)

    return dropped_db_client.begin_create_or_update(
        restorable_dropped_database_id=dropped_database_id,
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name,
        policy_name=ManagedShortTermRetentionPolicyName.DEFAULT,
        parameters=kwargs)
4696
4697
def get_short_term_retention_mi(
        cmd,
        client,
        database_name,
        managed_instance_name,
        resource_group_name,
        deleted_time=None):
    '''
    Gets short term retention for a managed database, or for a
    restorable dropped database when deleted_time is given.
    '''

    if not deleted_time:
        return client.get(
            database_name=database_name,
            managed_instance_name=managed_instance_name,
            resource_group_name=resource_group_name,
            policy_name=ManagedShortTermRetentionPolicyName.DEFAULT)

    # Dropped databases are addressed as '<name>,<deletion time>' through a
    # separate operations client.
    dropped_database_id = '{},{}'.format(
        database_name,
        _to_filetimeutc(deleted_time))

    dropped_db_client = \
        get_sql_restorable_dropped_database_managed_backup_short_term_retention_policies_operations(
            cmd.cli_ctx,
            None)

    return dropped_db_client.get(
        restorable_dropped_database_id=dropped_database_id,
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name,
        policy_name=ManagedShortTermRetentionPolicyName.DEFAULT)
4732
4733
def _is_int(retention):
    '''
    Returns True if retention can be converted with int(), False otherwise.
    '''

    try:
        int(retention)
    except (TypeError, ValueError):
        # ValueError: a string that is not an integer literal (e.g. 'P7D').
        # TypeError: a type int() does not accept at all (e.g. None).
        return False
    return True
4740
4741
def update_long_term_retention_mi(
        client,
        database_name,
        managed_instance_name,
        resource_group_name,
        weekly_retention=None,
        monthly_retention=None,
        yearly_retention=None,
        week_of_year=None,
        **kwargs):
    '''
    Updates long term retention for managed database

    A retention value that is a plain integer is treated as a number of days
    and converted to an ISO 8601 duration (e.g. 7 becomes 'P7D'); any other
    value is passed through unchanged.

    Raises CLIError if no retention setting is given, or if yearly retention
    is given without a week of year.
    '''

    if not (weekly_retention or monthly_retention or yearly_retention):
        raise CLIError('Please specify retention setting(s).  See \'--help\' for more details.')

    if yearly_retention and not week_of_year:
        raise CLIError('Please specify week of year for yearly retention.')

    # if an int is provided for retention, convert to ISO 8601 using days
    if weekly_retention and _is_int(weekly_retention):
        weekly_retention = 'P%sD' % weekly_retention

    if monthly_retention and _is_int(monthly_retention):
        monthly_retention = 'P%sD' % monthly_retention

    if yearly_retention and _is_int(yearly_retention):
        yearly_retention = 'P%sD' % yearly_retention

    kwargs['weekly_retention'] = weekly_retention
    kwargs['monthly_retention'] = monthly_retention
    kwargs['yearly_retention'] = yearly_retention
    kwargs['week_of_year'] = week_of_year

    return client.begin_create_or_update(
        database_name=database_name,
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name,
        policy_name=ManagedInstanceLongTermRetentionPolicyName.DEFAULT,
        parameters=kwargs)
4789
4790
def get_long_term_retention_mi(
        client,
        database_name,
        managed_instance_name,
        resource_group_name):
    '''
    Gets the default long term retention policy of a managed database.
    '''

    request = {
        'database_name': database_name,
        'managed_instance_name': managed_instance_name,
        'resource_group_name': resource_group_name,
        'policy_name': ManagedInstanceLongTermRetentionPolicyName.DEFAULT,
    }
    return client.get(**request)
4805
4806
def _get_backup_id_resource_values(backup_id):
    '''
    Extract resource values from the backup id

    The id is split on '/' into alternating <resource type>/<resource name>
    segments, which are returned as a dict keyed by resource type. Quotes and
    leading/trailing slashes are ignored. The four resource types the id must
    contain are matched case-insensitively and always returned under their
    canonical spelling ('locations', 'longTermRetentionManagedInstances',
    'longTermRetentionDatabases', 'longTermRetentionManagedInstanceBackups').

    Raises CLIError if the id does not split into type/name pairs or if any
    of the required resource types is missing.
    '''

    required_keys = (
        'locations',
        'longTermRetentionManagedInstances',
        'longTermRetentionDatabases',
        'longTermRetentionManagedInstanceBackups')

    # Drop stray quotes and surrounding slashes so the id splits into clean pairs.
    backup_id = backup_id.replace('\'', '').replace('"', '').strip('/')

    segments = backup_id.split('/')
    if len(segments) % 2 != 0:
        # An empty id or a dangling resource type cannot be paired up.
        raise CLIError('Please provide a valid resource URI.  See --help for example.')

    # Map any casing of a required resource type onto its canonical spelling.
    canonical_keys = {key.casefold(): key for key in required_keys}

    resources_dict = {}
    for resource_type, resource_name in zip(segments[::2], segments[1::2]):
        key = canonical_keys.get(resource_type.casefold(), resource_type)
        resources_dict[key] = resource_name

    if not all(key in resources_dict for key in required_keys):
        # backup ID should contain all these
        raise CLIError('Please provide a valid resource URI.  See --help for example.')

    return resources_dict
4831
4832
def get_long_term_retention_mi_backup(
        client,
        location_name=None,
        managed_instance_name=None,
        database_name=None,
        backup_name=None,
        backup_id=None):
    '''
    Gets the requested long term retention backup.

    When backup_id is given, the location, instance, database and backup
    names are taken from it instead of from the individual arguments.
    '''

    if backup_id:
        parsed_id = _get_backup_id_resource_values(backup_id)

        location_name, managed_instance_name, database_name, backup_name = (
            parsed_id['locations'],
            parsed_id['longTermRetentionManagedInstances'],
            parsed_id['longTermRetentionDatabases'],
            parsed_id['longTermRetentionManagedInstanceBackups'])

    return client.get(
        location_name=location_name,
        managed_instance_name=managed_instance_name,
        database_name=database_name,
        backup_name=backup_name)
4857
4858
def _list_by_database_long_term_retention_mi_backups(
        client,
        location_name,
        managed_instance_name,
        database_name,
        resource_group_name=None,
        only_latest_per_database=None,
        database_state=None):
    '''
    Gets the long term retention backups for a Managed Database,
    scoped to a resource group when one is given.
    '''

    query = {
        'location_name': location_name,
        'managed_instance_name': managed_instance_name,
        'database_name': database_name,
        'only_latest_per_database': only_latest_per_database,
        'database_state': database_state,
    }

    if resource_group_name:
        return client.list_by_resource_group_database(
            resource_group_name=resource_group_name,
            **query)

    return client.list_by_database(**query)
4888
4889
def _list_by_instance_long_term_retention_mi_backups(
        client,
        location_name,
        managed_instance_name,
        resource_group_name=None,
        only_latest_per_database=None,
        database_state=None):
    '''
    Gets the long term retention backups within a Managed Instance,
    scoped to a resource group when one is given.
    '''

    query = {
        'location_name': location_name,
        'managed_instance_name': managed_instance_name,
        'only_latest_per_database': only_latest_per_database,
        'database_state': database_state,
    }

    if resource_group_name:
        return client.list_by_resource_group_instance(
            resource_group_name=resource_group_name,
            **query)

    return client.list_by_instance(**query)
4916
4917
def _list_by_location_long_term_retention_mi_backups(
        client,
        location_name,
        resource_group_name=None,
        only_latest_per_database=None,
        database_state=None):
    '''
    Gets the long term retention backups within a specified region,
    scoped to a resource group when one is given.
    '''

    query = {
        'location_name': location_name,
        'only_latest_per_database': only_latest_per_database,
        'database_state': database_state,
    }

    if resource_group_name:
        return client.list_by_resource_group_location(
            resource_group_name=resource_group_name,
            **query)

    return client.list_by_location(**query)
4941
4942
def list_long_term_retention_mi_backups(
        client,
        location_name,
        managed_instance_name=None,
        database_name=None,
        resource_group_name=None,
        only_latest_per_database=None,
        database_state=None):
    '''
    Lists the long term retention backups for a specified location, instance, or database.

    The narrowest scope is used: a database within an instance when both
    names are given, an instance when only its name is given, otherwise the
    whole location. A database name without an instance name is ignored.
    '''

    if not managed_instance_name:
        return _list_by_location_long_term_retention_mi_backups(
            client,
            location_name,
            resource_group_name,
            only_latest_per_database,
            database_state)

    if not database_name:
        return _list_by_instance_long_term_retention_mi_backups(
            client,
            location_name,
            managed_instance_name,
            resource_group_name,
            only_latest_per_database,
            database_state)

    return _list_by_database_long_term_retention_mi_backups(
        client,
        location_name,
        managed_instance_name,
        database_name,
        resource_group_name,
        only_latest_per_database,
        database_state)
4983
4984
def delete_long_term_retention_mi_backup(
        client,
        location_name=None,
        managed_instance_name=None,
        database_name=None,
        backup_name=None,
        backup_id=None):
    '''
    Deletes the requested long term retention backup.

    When backup_id is given, the location, instance, database and backup
    names are taken from it instead of from the individual arguments.
    '''

    if backup_id:
        parsed_id = _get_backup_id_resource_values(backup_id)

        location_name, managed_instance_name, database_name, backup_name = (
            parsed_id['locations'],
            parsed_id['longTermRetentionManagedInstances'],
            parsed_id['longTermRetentionDatabases'],
            parsed_id['longTermRetentionManagedInstanceBackups'])

    return client.begin_delete(
        location_name=location_name,
        managed_instance_name=managed_instance_name,
        database_name=database_name,
        backup_name=backup_name)
5009
5010
def restore_long_term_retention_mi_backup(
        cmd,
        client,
        long_term_retention_backup_resource_id,
        target_managed_database_name,
        target_managed_instance_name,
        target_resource_group_name,
        **kwargs):
    '''
    Restores an existing managed DB (i.e. create with 'RestoreLongTermRetentionBackup' create mode.)

    Raises CLIError if any target resource or the backup resource id is
    missing.
    '''

    all_targets_given = (
        target_resource_group_name and target_managed_instance_name and target_managed_database_name)
    if not all_targets_given:
        raise CLIError('Please specify target resource(s). '
                       'Target resource group, target instance, and target database '
                       'are all required for restore LTR backup.')

    if not long_term_retention_backup_resource_id:
        raise CLIError('Please specify a long term retention backup.')

    # The location is looked up from the target managed instance.
    kwargs['location'] = _get_managed_instance_location(
        cmd.cli_ctx,
        managed_instance_name=target_managed_instance_name,
        resource_group_name=target_resource_group_name)

    kwargs.update(
        create_mode=CreateMode.RESTORE_LONG_TERM_RETENTION_BACKUP,
        long_term_retention_backup_resource_id=long_term_retention_backup_resource_id)

    return client.begin_create_or_update(
        database_name=target_managed_database_name,
        managed_instance_name=target_managed_instance_name,
        resource_group_name=target_resource_group_name,
        parameters=kwargs)
5044
5045
def managed_db_log_replay_start(
        cmd,
        client,
        database_name,
        managed_instance_name,
        resource_group_name,
        auto_complete,
        last_backup_name,
        storage_container_uri,
        storage_container_sas_token,
        **kwargs):
    '''
    Start a log replay restore.

    Raises CLIError if auto_complete is set without a last backup name.
    '''

    # The location is looked up from the managed instance.
    instance_location = _get_managed_instance_location(
        cmd.cli_ctx,
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name)
    kwargs['location'] = instance_location

    kwargs['create_mode'] = CreateMode.RESTORE_EXTERNAL_BACKUP

    missing_last_backup = auto_complete and not last_backup_name
    if missing_last_backup:
        raise CLIError('Please specify a last backup name when using auto complete flag.')

    kwargs.update({
        'auto_complete_restore': auto_complete,
        'last_backup_name': last_backup_name,
        'storageContainerUri': storage_container_uri,
        'storageContainerSasToken': storage_container_sas_token,
    })

    return client.begin_create_or_update(
        database_name=database_name,
        managed_instance_name=managed_instance_name,
        resource_group_name=resource_group_name,
        parameters=kwargs)
5084
5085
def managed_db_log_replay_complete_restore(
        client,
        database_name,
        managed_instance_name,
        resource_group_name,
        **kwargs):
    '''
    Complete a log replay restore.

    Any extra keyword arguments are forwarded unchanged as the request
    parameters of client.begin_complete_restore.
    '''

    request = {
        'resource_group_name': resource_group_name,
        'managed_instance_name': managed_instance_name,
        'database_name': database_name,
        'parameters': kwargs,
    }
    return client.begin_complete_restore(**request)
5101
5102
def managed_db_log_replay_get(
        client,
        database_name,
        managed_instance_name,
        resource_group_name):
    '''
    Gets a log replay restore.

    Looks up the restore details named RestoreDetailsName.DEFAULT for the
    given managed database.
    '''

    lookup = {
        'resource_group_name': resource_group_name,
        'managed_instance_name': managed_instance_name,
        'database_name': database_name,
        'restore_details_name': RestoreDetailsName.DEFAULT,
    }
    return client.get(**lookup)
5117
5118###############################################
5119#              sql failover-group             #
5120###############################################
5121
5122
def failover_group_create(
        cmd,
        client,
        resource_group_name,
        server_name,
        failover_group_name,
        partner_server,
        partner_resource_group=None,
        failover_policy=FailoverPolicyType.automatic.value,
        grace_period=1,
        add_db=None):
    '''
    Creates a failover group.

    partner_server is the name of the partner server; its resource id is
    built from the current subscription and from partner_resource_group,
    falling back to resource_group_name when no partner group is given.
    grace_period is given in hours and is ignored for the manual failover
    policy. add_db lists the databases to put into the new group.

    Returns the result of client.begin_create_or_update.
    '''

    # Standard library replacement for the former six.moves shim
    from urllib.parse import quote
    from azure.cli.core.commands.client_factory import get_subscription_id

    # Build the partner server id
    partner_server_id = "/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Sql/servers/{}".format(
        quote(get_subscription_id(cmd.cli_ctx)),
        quote(partner_resource_group or resource_group_name),
        quote(partner_server))

    partner_server = PartnerInfo(id=partner_server_id)

    # Convert grace period from hours to minutes
    grace_period = int(grace_period) * 60

    # No grace period is sent when the failover policy is manual
    if failover_policy == FailoverPolicyType.manual.value:
        grace_period = None

    if add_db is None:
        add_db = []

    # A new group starts with no databases, so only the additions matter
    databases = _get_list_of_databases_for_fg(
        cmd,
        resource_group_name,
        server_name,
        [],
        add_db,
        [])

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        failover_group_name=failover_group_name,
        parameters=FailoverGroup(
            partner_servers=[partner_server],
            databases=databases,
            read_write_endpoint=FailoverGroupReadWriteEndpoint(
                failover_policy=failover_policy,
                failover_with_data_loss_grace_period_minutes=grace_period),
            read_only_endpoint=FailoverGroupReadOnlyEndpoint(
                failover_policy="Disabled")))
5178
5179
def failover_group_update(
        cmd,
        instance,
        resource_group_name,
        server_name,
        failover_policy=None,
        grace_period=None,
        add_db=None,
        remove_db=None):
    '''
    Updates the failover group.

    Applies the policy and grace period changes to instance, then replaces
    instance.databases with the current databases plus add_db minus
    remove_db. Returns the modified instance.
    '''

    _failover_group_update_common(instance, failover_policy, grace_period)

    # Missing add/remove lists mean "no change" in that direction
    to_add = add_db if add_db is not None else []
    to_remove = remove_db if remove_db is not None else []

    instance.databases = _get_list_of_databases_for_fg(
        cmd,
        resource_group_name,
        server_name,
        instance.databases,
        to_add,
        to_remove)

    return instance
5215
5216
def failover_group_failover(
        client,
        resource_group_name,
        server_name,
        failover_group_name,
        allow_data_loss=False):
    '''
    Failover a failover group.

    Does nothing and returns None when the group fetched for this server
    already reports the "Primary" replication role. Otherwise starts a
    failover, using the forced variant when allow_data_loss is set, and
    returns the result of that call.
    '''

    target = {
        'resource_group_name': resource_group_name,
        'server_name': server_name,
        'failover_group_name': failover_group_name,
    }

    current_group = client.get(**target)

    if current_group.replication_role != "Primary":
        # Choose which failover method to use
        start_failover = (client.begin_force_failover_allow_data_loss
                          if allow_data_loss
                          else client.begin_failover)
        return start_failover(**target)

    return None
5245
5246
5247def _get_list_of_databases_for_fg(
5248        cmd,
5249        resource_group_name,
5250        server_name,
5251        databases_in_fg,
5252        add_db,
5253        remove_db):
5254    '''
5255    Gets a list of databases that are supposed to be part of the failover group
5256    after the operation finishes
5257    It consolidates the list of dbs to add and remove with the list of databases
5258    that are already part of the failover group.
5259    '''
5260
5261    add_db_ids = [DatabaseIdentity(cmd.cli_ctx, d, server_name, resource_group_name).id() for d in add_db]
5262
5263    remove_db_ids = [DatabaseIdentity(cmd.cli_ctx, d, server_name, resource_group_name).id() for d in remove_db]
5264
5265    databases = list(({x.lower() for x in databases_in_fg} |
5266                      {x.lower() for x in add_db_ids}) - {x.lower() for x in remove_db_ids})
5267
5268    return databases
5269
5270###############################################
5271#                sql virtual cluster          #
5272###############################################
5273
5274
def virtual_cluster_list(
        client,
        resource_group_name=None):
    '''
    Lists virtual clusters in a resource group or subscription
    '''

    if not resource_group_name:
        # No resource group given: list all virtual clusters in the subscription
        return client.list()

    # Otherwise restrict the listing to the given resource group
    return client.list_by_resource_group(resource_group_name=resource_group_name)
5288
5289
5290###############################################
5291#              sql instance failover group    #
5292###############################################
5293
def instance_failover_group_create(
        cmd,
        client,
        resource_group_name,
        managed_instance,
        failover_group_name,
        partner_managed_instance,
        partner_resource_group,
        failover_policy=FailoverPolicyType.automatic.value,
        grace_period=1):
    '''
    Creates a failover group.

    Both managed instances are fetched first; their ids form the instance
    pair and their locations supply the group location and partner region.
    grace_period is given in hours and is dropped for the manual failover
    policy.
    '''

    instances = get_sql_managed_instances_operations(cmd.cli_ctx, None)
    # pylint: disable=no-member
    primary = instances.get(
        managed_instance_name=managed_instance,
        resource_group_name=resource_group_name)

    partner = instances.get(
        managed_instance_name=partner_managed_instance,
        resource_group_name=partner_resource_group)

    # Describe the primary/partner pairing and the partner's region
    instance_pair = ManagedInstancePairInfo(
        primary_managed_instance_id=primary.id,
        partner_managed_instance_id=partner.id)
    partner_region = PartnerRegionInfo(location=partner.location)

    # Hours to minutes; the conversion runs even for the manual policy
    grace_minutes = int(grace_period) * 60
    if failover_policy == FailoverPolicyType.manual.value:
        grace_minutes = None

    read_write = InstanceFailoverGroupReadWriteEndpoint(
        failover_policy=failover_policy,
        failover_with_data_loss_grace_period_minutes=grace_minutes)
    read_only = InstanceFailoverGroupReadOnlyEndpoint(
        failover_policy="Disabled")

    group = InstanceFailoverGroup(
        managed_instance_pairs=[instance_pair],
        partner_regions=[partner_region],
        read_write_endpoint=read_write,
        read_only_endpoint=read_only)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        location_name=primary.location,
        failover_group_name=failover_group_name,
        parameters=group)
5342
5343
def instance_failover_group_update(
        instance,
        failover_policy=None,
        grace_period=None):
    '''
    Updates the failover group.

    Hands instance, failover_policy and grace_period to the shared update
    helper and returns the same instance.
    '''

    _failover_group_update_common(instance, failover_policy, grace_period)
    return instance
5358
5359
def instance_failover_group_failover(
        client,
        resource_group_name,
        failover_group_name,
        location_name,
        allow_data_loss=False):
    '''
    Failover an instance failover group.

    Does nothing and returns None when the group fetched for this location
    already reports the "Primary" replication role. Otherwise starts a
    failover, using the forced variant when allow_data_loss is set, and
    returns the result of that call.
    '''

    target = {
        'resource_group_name': resource_group_name,
        'failover_group_name': failover_group_name,
        'location_name': location_name,
    }

    current_group = client.get(**target)

    if current_group.replication_role != "Primary":
        # Choose which failover method to use
        start_failover = (client.begin_force_failover_allow_data_loss
                          if allow_data_loss
                          else client.begin_failover)
        return start_failover(**target)

    return None
5388
5389###############################################
5390#              sql server conn-policy         #
5391###############################################
5392
5393
def show_conn_policy(
        client,
        resource_group_name,
        server_name):
    '''
    Shows a connection policy
    '''
    policy_lookup = {
        'resource_group_name': resource_group_name,
        'server_name': server_name,
        'connection_policy_name': ConnectionPolicyName.DEFAULT,
    }
    return client.get(**policy_lookup)
5405
5406
def update_conn_policy(
        client,
        resource_group_name,
        server_name,
        connection_type):
    '''
    Updates a connection policy

    Writes a policy with the given connection_type under the name
    ConnectionPolicyName.DEFAULT.
    '''
    request = {
        'resource_group_name': resource_group_name,
        'server_name': server_name,
        'connection_policy_name': ConnectionPolicyName.DEFAULT,
        'parameters': ServerConnectionPolicy(connection_type=connection_type),
    }
    return client.create_or_update(**request)
5422
5423###############################################
5424#              sql db tde                     #
5425###############################################
5426
5427
def transparent_data_encryptions_set(
        client,
        resource_group_name,
        server_name,
        database_name,
        status,
        **kwargs):
    '''
    Sets a Transparent Data Encryption

    status is added to the remaining keyword arguments, which are then sent
    as the parameters of the request.
    '''
    kwargs.update(status=status)

    request = {
        'resource_group_name': resource_group_name,
        'server_name': server_name,
        'database_name': database_name,
        'transparent_data_encryption_name': TransparentDataEncryptionName.CURRENT,
        'parameters': kwargs,
    }
    return client.create_or_update(**request)
5446
5447
def transparent_data_encryptions_get(
        client,
        resource_group_name,
        server_name,
        database_name):
    '''
    Shows a Transparent Data Encryption
    '''

    lookup = {
        'resource_group_name': resource_group_name,
        'server_name': server_name,
        'database_name': database_name,
        'transparent_data_encryption_name': TransparentDataEncryptionName.CURRENT,
    }
    return client.get(**lookup)
5462
5463
def tde_list_by_configuration(
        client,
        resource_group_name,
        server_name,
        database_name):
    '''
    Lists Transparent Data Encryption
    '''

    query = {
        'resource_group_name': resource_group_name,
        'server_name': server_name,
        'database_name': database_name,
        'transparent_data_encryption_name': TransparentDataEncryptionName.CURRENT,
    }
    return client.list_by_configuration(**query)
5478
5479###############################################
5480#              sql server vnet-rule           #
5481###############################################
5482
5483
def vnet_rule_begin_create_or_update(
        client,
        resource_group_name,
        server_name,
        virtual_network_rule_name,
        virtual_network_subnet_id,
        ignore_missing_vnet_service_endpoint=False):
    '''
    Creates or Updates Virtual Network Rules

    Builds the rule from the subnet id and the
    ignore_missing_vnet_service_endpoint flag, then submits it under the
    given rule name.
    '''

    rule = VirtualNetworkRule(
        virtual_network_subnet_id=virtual_network_subnet_id,
        ignore_missing_vnet_service_endpoint=ignore_missing_vnet_service_endpoint)

    return client.begin_create_or_update(
        resource_group_name=resource_group_name,
        server_name=server_name,
        virtual_network_rule_name=virtual_network_rule_name,
        parameters=rule)
5503