1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5from knack.prompting import prompt_pass, NoTTYException
6from knack.util import CLIError
7from knack.log import get_logger
8from msrestazure.tools import parse_resource_id, resource_id, is_valid_resource_id, is_valid_resource_name
9from azure.cli.core.azclierror import ValidationError, ArgumentUsageError
10from azure.cli.core.commands.client_factory import get_mgmt_service_client, get_subscription_id
11from azure.cli.core.commands.validators import (
12    get_default_location_from_resource_group, validate_tags)
13from azure.cli.core.util import parse_proxy_resource_id
14from azure.cli.core.profiles import ResourceType
15from ._flexible_server_util import (get_mysql_versions, get_mysql_skus, get_mysql_storage_size,
16                                    get_mysql_backup_retention, get_mysql_tiers, get_mysql_list_skus_info,
17                                    get_postgres_list_skus_info, get_postgres_versions,
18                                    get_postgres_skus, get_postgres_storage_sizes, get_postgres_tiers,
19                                    _is_resource_name)
20
21logger = get_logger(__name__)
22
23
24# pylint: disable=import-outside-toplevel, raise-missing-from, unbalanced-tuple-unpacking
def _get_resource_group_from_server_name(cli_ctx, server_name):
    """
    Find the resource group of a server by scanning the servers listed by the RDBMS management client.

    :param cli_ctx: CLI context passed to ``get_mgmt_service_client``
    :param str server_name: name of the server
    :return: resource group of the first listed server whose name matches, or None if none matches
    :rtype: str
    """
    servers_client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RDBMS).servers
    # Lazily parse each server id so scanning stops at the first match.
    parsed_ids = (parse_resource_id(server.id) for server in servers_client.list())
    return next((parts['resource_group'] for parts in parsed_ids if parts['name'] == server_name), None)
39
40
def get_combined_validator(validators):
    """
    Wrap several namespace-only validators into one ``(cmd, namespace)`` validator.

    :param validators: iterable of callables, each invoked as ``validator(namespace)``
    :return: a validator that first runs the extra checks for ``... server create`` commands
             (password and default location), then tag validation, then every given validator in order
    """
    def _final_validator_impl(cmd, namespace):
        # Split off (at most) the last two words of the command name.
        name_parts = cmd.name.rsplit(' ', 2)
        is_server_create = name_parts[1] == 'server' and name_parts[2] == 'create'
        if is_server_create:
            # do additional creation validation
            password_validator(namespace)
            get_default_location_from_resource_group(cmd, namespace)

        validate_tags(namespace)

        for run_validator in validators:
            run_validator(namespace)

    return _final_validator_impl
55
56
def configuration_value_validator(ns):
    """
    Reset a missing or whitespace-only configuration value to the system default.

    When ``ns.value`` is None or blank, ``ns.value`` is set to None and ``ns.source`` to
    ``'system-default'``; otherwise the namespace is left untouched.

    :param ns: parsed argument namespace carrying ``value`` (and receiving ``source``)
    """
    current = ns.value
    has_value = current is not None and bool(current.strip())
    if has_value:
        return
    ns.value = None
    ns.source = 'system-default'
62
63
def tls_validator(ns):
    """
    Reject a minimal TLS version when SSL enforcement is explicitly set to something other than 'Enabled'.

    :param ns: parsed argument namespace with ``minimal_tls_version`` and ``ssl_enforcement``
    :raises CLIError: if a TLS version is given while ``ssl_enforcement`` is neither None nor 'Enabled'
    """
    if not ns.minimal_tls_version:
        return
    enforcement = ns.ssl_enforcement
    if enforcement is None or enforcement == 'Enabled':
        return
    raise CLIError('Cannot specify TLS version when ssl_enforcement is explicitly Disabled')
68
69
def password_validator(ns):
    """
    Make sure an administrator password is present, prompting for it when it was not supplied.

    :param ns: parsed argument namespace with ``administrator_login_password``
    :raises CLIError: if the password is missing and no TTY is available for prompting
    """
    if ns.administrator_login_password:
        return
    try:
        ns.administrator_login_password = prompt_pass(msg='Admin Password: ')
    except NoTTYException:
        raise CLIError('Please specify password in non-interactive mode.')
76
77
def retention_validator(ns):
    """
    Check that ``--backup-retention``, when given, lies between 7 and 35 days (inclusive).

    :param ns: parsed argument namespace with ``backup_retention``
    :raises CLIError: if the value, converted with ``int``, is outside the 7-35 range
    """
    requested = ns.backup_retention
    if requested is None:
        return
    days = int(requested)
    if days < 7 or days > 35:
        raise CLIError('incorrect usage: --backup-retention. Range is 7 to 35 days.')
83
84
# Validates the subnet given by the user (id or name). If a subnet id is given, vnet-name should not be provided.
def validate_subnet(cmd, namespace):
    """
    Normalise the subnet argument into a full subnet resource id.

    Accepted combinations are: no subnet and no vnet name, a subnet resource id without a vnet name,
    or a subnet name together with a vnet name. In the last case the subnet id is built from the
    current subscription, ``namespace.resource_group_name``, the vnet name and the subnet name.
    ``vnet_name`` is removed from the namespace once validation passes.

    :param cmd: command object; ``cmd.cli_ctx`` is used to look up the subscription id
    :param namespace: parsed argument namespace with ``virtual_network_subnet_id`` and ``vnet_name``
    :raises CLIError: for any other combination of subnet and vnet name
    """
    subnet = namespace.virtual_network_subnet_id
    subnet_is_id = is_valid_resource_id(subnet)
    vnet = namespace.vnet_name

    # A subnet given by name (not a resource id) only makes sense together with a vnet name.
    subnet_is_plain_name = bool(subnet) and not subnet_is_id

    if vnet and subnet_is_plain_name:
        namespace.virtual_network_subnet_id = resource_id(
            subscription=get_subscription_id(cmd.cli_ctx),
            resource_group=namespace.resource_group_name,
            namespace='Microsoft.Network',
            type='virtualNetworks',
            name=vnet,
            child_type_1='subnets',
            child_name_1=subnet)
    elif vnet or subnet_is_plain_name:
        # Either a vnet name without a subnet name, or a subnet name without a vnet name.
        raise CLIError('incorrect usage: [--subnet ID | --subnet NAME --vnet-name NAME]')
    delattr(namespace, 'vnet_name')
106
107
def validate_private_endpoint_connection_id(cmd, namespace):
    """
    Resolve the server, resource group and connection name of a private endpoint connection.

    If ``namespace.connection_id`` is given it is parsed and overrides the connection name, server name
    and resource group on the namespace. If a server name is known but no resource group, the resource
    group is looked up from the server name. ``connection_id`` is removed from the namespace afterwards.

    :param cmd: command object; ``cmd.cli_ctx`` is used for the resource group lookup
    :param namespace: parsed argument namespace
    :raises CLIError: if server name, resource group or connection name is still missing
    """
    connection_id = namespace.connection_id
    if connection_id:
        id_parts = parse_proxy_resource_id(connection_id)
        namespace.private_endpoint_connection_name = id_parts['child_name_1']
        namespace.server_name = id_parts['name']
        namespace.resource_group_name = id_parts['resource_group']

    if namespace.server_name and not namespace.resource_group_name:
        namespace.resource_group_name = _get_resource_group_from_server_name(cmd.cli_ctx, namespace.server_name)

    required = (namespace.server_name, namespace.resource_group_name, namespace.private_endpoint_connection_name)
    if not all(required):
        raise CLIError('incorrect usage: [--id ID | --name NAME --server-name NAME]')

    delattr(namespace, 'connection_id')
121
122
def mysql_arguments_validator(db_context, location, tier, sku_name, storage_gb, backup_retention=None,
                              server_name=None, zone=None, standby_availability_zone=None, high_availability=None,
                              subnet=None, public_access=None, version=None, auto_grow=None, replication_role=None,
                              instance=None):
    """
    Run every MySQL flexible-server argument check against the SKU information of ``location``.

    The server name and network arguments are checked first, then the tier. When no tier is supplied
    but an existing ``instance`` is, the instance's tier is used for the remaining checks (retention,
    storage, SKU name, high availability, version and auto grow).

    :param db_context: context object providing ``cmd`` (and whatever ``validate_server_name`` reads)
    :param instance: existing server object, or None
    """
    validate_server_name(db_context, server_name, 'Microsoft.DBforMySQL/flexibleServers')
    sku_info, single_az, _ = get_mysql_list_skus_info(db_context.cmd, location)
    _network_arg_validator(subnet, public_access)

    # The tier has to be validated before anything that looks values up by tier.
    _mysql_tier_validator(tier, sku_info)
    effective_tier = instance.sku.tier if (tier is None and instance is not None) else tier

    _mysql_retention_validator(backup_retention, sku_info, effective_tier)
    _mysql_storage_validator(storage_gb, sku_info, effective_tier, instance)
    _mysql_sku_name_validator(sku_name, sku_info, effective_tier, instance)
    _mysql_high_availability_validator(high_availability, standby_availability_zone, zone, effective_tier,
                                       single_az, auto_grow, instance)
    _mysql_version_validator(version, sku_info, effective_tier, instance)
    _mysql_auto_grow_validator(auto_grow, replication_role, high_availability, instance)
140
141
def _mysql_retention_validator(backup_retention, sku_info, tier):
    """
    Check ``--backup-retention`` against the upper bound reported for the tier.

    The lower bound is fixed at 1 day; the upper bound is the second element of the range returned
    by ``get_mysql_backup_retention``.

    :raises CLIError: if the value, converted with ``int``, is outside the allowed range
    """
    if backup_retention is None:
        return
    retention_range = get_mysql_backup_retention(sku_info, tier)
    days = int(backup_retention)
    upper_bound = retention_range[1]
    if days < 1 or days > upper_bound:
        raise CLIError('incorrect usage: --backup-retention. Range is {} to {} days.'
                       .format(1, upper_bound))
148
149
def _mysql_storage_validator(storage_gb, sku_info, tier, instance):
    """
    Check ``--storage-size`` for a MySQL flexible server.

    For an existing ``instance`` the new size must not be smaller than the current one. The size must
    also lie within the range returned by ``get_mysql_storage_size``, with a floor of 20 GiB.

    :raises CLIError: if the size shrinks the existing storage or is outside the allowed range
    """
    if storage_gb is None:
        return

    if instance:
        current_size = instance.storage.storage_size_gb
        if current_size > storage_gb:
            raise CLIError('Updating storage cannot be smaller than the '
                           'original storage size {} GiB.'.format(current_size))

    size_bounds = get_mysql_storage_size(sku_info, tier)
    floor_gb = 20
    lower_bound = max(floor_gb, size_bounds[0])
    upper_bound = size_bounds[1]
    if not lower_bound <= storage_gb <= upper_bound:
        raise CLIError('Incorrect value for --storage-size. Allowed values(in GiB) : Integers ranging {}-{}'
                       .format(lower_bound, upper_bound))
162
163
def _mysql_tier_validator(tier, sku_info):
    """
    Check that ``--tier``, when given, is one of the tiers listed in ``sku_info``.

    :raises CLIError: if the tier is not among the values returned by ``get_mysql_tiers``
    """
    if not tier:
        return
    allowed_tiers = get_mysql_tiers(sku_info)
    if tier in allowed_tiers:
        return
    raise CLIError('Incorrect value for --tier. Allowed values : {}'.format(allowed_tiers))
169
170
def _mysql_sku_name_validator(sku_name, sku_info, tier, instance):
    """
    Check that ``--sku-name``, when given, belongs to the selected tier.

    When no tier is passed, the tier of the existing ``instance`` (if any) is used.

    :raises CLIError: if the SKU name is not among the values returned by ``get_mysql_skus``
    """
    if instance is not None and tier is None:
        tier = instance.sku.tier
    if not sku_name:
        return
    allowed_skus = get_mysql_skus(sku_info, tier)
    if sku_name in allowed_skus:
        return
    raise CLIError('Incorrect value for --sku-name. The SKU name does not match tier selection. '
                   'Default value for --tier is Burstable. '
                   'For Memory Optimized and General Purpose you need to specify --tier value explicitly. '
                   'Allowed values for {} tier: {}'.format(tier, allowed_skus))
181
182
def _mysql_version_validator(version, sku_info, tier, instance):
    """
    Check that ``--version``, when given, is available for the selected tier.

    When no tier is passed, the tier of the existing ``instance`` (if any) is used.

    :raises CLIError: if the version is not among the values returned by ``get_mysql_versions``
    """
    if instance is not None and tier is None:
        tier = instance.sku.tier
    if not version:
        return
    allowed_versions = get_mysql_versions(sku_info, tier)
    if version in allowed_versions:
        return
    raise CLIError('Incorrect value for --version. Allowed values : {}'.format(allowed_versions))
190
191
def _mysql_auto_grow_validator(auto_grow, replication_role, high_availability, instance):
    """
    Refuse to disable auto grow on servers that must keep it enabled.

    Missing ``replication_role`` / ``high_availability`` values are taken from the existing ``instance``
    when one is given. Nothing is checked when ``auto_grow`` is None.

    :raises ValidationError: if auto grow is being disabled while the replication role is anything
                             other than the string 'None', or while high availability is 'Enabled'
                             or 'ZoneRedundant'
    """
    if auto_grow is None:
        return

    if instance is not None:
        if replication_role is None:
            replication_role = instance.replication_role
        if high_availability is None:
            high_availability = instance.high_availability.mode

    def _is_disabling():
        return auto_grow.lower() == 'disabled'

    # if replica, cannot be disabled
    if replication_role != 'None' and _is_disabling():
        raise ValidationError("Auto grow feature for replica server cannot be disabled.")
    # if ha, cannot be disabled
    if high_availability in ('Enabled', 'ZoneRedundant') and _is_disabling():
        raise ValidationError("Auto grow feature for high availability server cannot be disabled.")
204
205
def _mysql_high_availability_validator(high_availability, standby_availability_zone, zone, tier, single_az,
                                       auto_grow, instance):
    """
    Check the high availability related arguments of a MySQL flexible server.

    Missing ``tier``, ``auto_grow`` and ``zone`` values are taken from the existing ``instance`` when
    one is given.

    :param str high_availability: requested high availability mode, or None
    :param standby_availability_zone: requested standby zone, or None
    :param zone: availability zone of the server, or None
    :param str tier: SKU tier, or None
    :param single_az: truthy when the region has a single availability zone
    :param str auto_grow: requested auto grow setting, or None
    :param instance: existing server object, or None
    :raises ArgumentUsageError: if high availability is requested for the Burstable tier, zone
        redundancy is requested in a single-zone region, high availability is requested while auto
        grow is disabled, a standby zone is given without zone redundant high availability, or the
        standby zone equals the server zone
    """
    if instance:
        tier = instance.sku.tier if tier is None else tier
        auto_grow = instance.storage.auto_grow if auto_grow is None else auto_grow
        zone = instance.availability_zone if zone is None else zone

    ha_mode = high_availability.lower() if high_availability is not None else None

    if ha_mode is not None and ha_mode != 'disabled':
        if tier == 'Burstable':
            raise ArgumentUsageError("High availability is not supported for Burstable tier")
        if single_az and ha_mode == 'zoneredundant':
            raise ArgumentUsageError("This region is single availability zone. "
                                     "Zone redundant high availability is not supported "
                                     "in a single availability zone region.")
        # The method must be called and the comparison made case-insensitively, otherwise this check
        # can never fire. auto_grow may be None when neither the caller nor an instance provides it.
        if auto_grow is not None and auto_grow.lower() == 'disabled':
            raise ArgumentUsageError("Enabling High Availability requires Auto grow to be turned ON.")

    if standby_availability_zone:
        # Covers None, an empty string and every mode other than zone redundant.
        if ha_mode != 'zoneredundant':
            raise ArgumentUsageError("You need to enable high availability to set standby availability zone.")
        if zone == standby_availability_zone:
            raise ArgumentUsageError("Your server is in availability zone {}. "
                                     "The zone of the server cannot be same as the standby zone.".format(zone))
227
228
def pg_arguments_validator(db_context, location, tier, sku_name, storage_gb, server_name=None, zone=None,
                           standby_availability_zone=None, high_availability=None, subnet=None, public_access=None,
                           version=None, instance=None):
    """
    Run every PostgreSQL flexible-server argument check against the SKU information of ``location``.

    The server name and network arguments are checked first, then the tier. When no tier is supplied
    but an existing ``instance`` is, the instance's tier is used for the remaining checks (storage,
    SKU name, high availability and version).

    :param db_context: context object providing ``cmd`` (and whatever ``validate_server_name`` reads)
    :param instance: existing server object, or None
    """
    validate_server_name(db_context, server_name, 'Microsoft.DBforPostgreSQL/flexibleServers')
    sku_info, single_az = get_postgres_list_skus_info(db_context.cmd, location)
    _network_arg_validator(subnet, public_access)

    # The tier has to be validated before anything that looks values up by tier.
    _pg_tier_validator(tier, sku_info)
    effective_tier = instance.sku.tier if (tier is None and instance is not None) else tier

    _pg_storage_validator(storage_gb, sku_info, effective_tier, instance)
    _pg_sku_name_validator(sku_name, sku_info, effective_tier, instance)
    _pg_high_availability_validator(high_availability, standby_availability_zone, zone, effective_tier,
                                    single_az, instance)
    _pg_version_validator(version, sku_info, effective_tier, instance)
242
243
def _pg_storage_validator(storage_gb, sku_info, tier, instance):
    """
    Check ``--storage-size`` for a PostgreSQL flexible server.

    For an existing ``instance`` the new size must not be smaller than the current one. The size must
    also be one of the values returned by ``get_postgres_storage_sizes``.

    :raises CLIError: if the size shrinks the existing storage or is not an allowed value
    """
    if storage_gb is None:
        return

    if instance is not None:
        current_size = instance.storage.storage_size_gb
        if current_size > storage_gb:
            raise CLIError('Updating storage cannot be smaller than '
                           'the original storage size {} GiB.'.format(current_size))

    allowed_sizes = get_postgres_storage_sizes(sku_info, tier)
    if storage_gb in allowed_sizes:
        return
    # Show the allowed sizes as sorted integers in the error message.
    sorted_sizes = sorted(int(size) for size in allowed_sizes)
    raise CLIError('Incorrect value for --storage-size : Allowed values(in GiB) : {}'
                   .format(sorted_sizes))
256
257
def _pg_tier_validator(tier, sku_info):
    """
    Check that ``--tier``, when given, is one of the tiers listed in ``sku_info``.

    :raises CLIError: if the tier is not among the values returned by ``get_postgres_tiers``
    """
    if not tier:
        return
    allowed_tiers = get_postgres_tiers(sku_info)
    if tier in allowed_tiers:
        return
    raise CLIError('Incorrect value for --tier. Allowed values : {}'.format(allowed_tiers))
263
264
def _pg_sku_name_validator(sku_name, sku_info, tier, instance):
    """
    Check that ``--sku-name``, when given, belongs to the selected tier.

    When no tier is passed, the tier of the existing ``instance`` (if any) is used.

    :raises CLIError: if the SKU name is not among the values returned by ``get_postgres_skus``
    """
    if instance is not None and tier is None:
        tier = instance.sku.tier
    if not sku_name:
        return
    allowed_skus = get_postgres_skus(sku_info, tier)
    if sku_name in allowed_skus:
        return
    raise CLIError('Incorrect value for --sku-name. The SKU name does not match {} tier. '
                   'Specify --tier if you did not. Or CLI will set GeneralPurpose as the default tier. '
                   'Allowed values : {}'.format(tier, allowed_skus))
274
275
def _pg_version_validator(version, sku_info, tier, instance):
    """
    Check that ``--version``, when given, is available for the selected tier.

    When no tier is passed, the tier of the existing ``instance`` (if any) is used.

    :raises CLIError: if the version is not among the values returned by ``get_postgres_versions``
    """
    if instance is not None and tier is None:
        tier = instance.sku.tier
    if not version:
        return
    allowed_versions = get_postgres_versions(sku_info, tier)
    if version in allowed_versions:
        return
    raise CLIError('Incorrect value for --version. Allowed values : {}'.format(allowed_versions))
283
284
def _pg_high_availability_validator(high_availability, standby_availability_zone, zone, tier, single_az, instance):
    """
    Check the high availability related arguments of a PostgreSQL flexible server.

    When no tier is passed, the tier of the existing ``instance`` (if any) is used.

    :param str high_availability: requested high availability setting, or None
    :param standby_availability_zone: requested standby zone, or None
    :param zone: availability zone of the server, or None
    :param str tier: SKU tier, or None
    :param single_az: truthy when the region has a single availability zone
    :param instance: existing server object, or None
    :raises ArgumentUsageError: if high availability is enabled for the Burstable tier or in a
        single-zone region, a standby zone is given while high availability is missing or disabled,
        or the standby zone equals the server zone
    """
    if instance:
        tier = instance.sku.tier if tier is None else tier

    ha_mode = high_availability.lower() if high_availability is not None else None

    if ha_mode == 'enabled':
        if tier == 'Burstable':
            raise ArgumentUsageError("High availability is not supported for Burstable tier")
        if single_az:
            raise ArgumentUsageError("This region is single availability zone. "
                                     "High availability is not supported in a single availability zone region.")

    if standby_availability_zone:
        # An explicit 'Disabled' is as much "not enabled" as an omitted value.
        if not ha_mode or ha_mode == 'disabled':
            raise ArgumentUsageError("You need to enable high availability to set standby availability zone.")
        if zone == standby_availability_zone:
            raise ArgumentUsageError("The zone of the server cannot be same as standby zone.")
300
301
def _network_arg_validator(subnet, public_access):
    """
    Reject the combination of ``--subnet`` and ``--public-access``.

    :raises CLIError: if both arguments are given (neither is None)
    """
    if subnet is None or public_access is None:
        return
    raise CLIError("Incorrect usage : A combination of the parameters --subnet "
                   "and --public-access is invalid. Use either one of them.")
306
307
def maintenance_window_validator(ns):
    """
    Check the format of ``--maintenance-window``: ``<Day>[:<Hour>[:<Minute>]]``.

    The day must be one of Sun..Sat or "Disabled"/"disabled"; the optional hour must be an integer
    in 0-23 and the optional minute an integer in 0-59. Nothing is checked when the argument is
    missing or empty.

    :param ns: parsed argument namespace with ``maintenance_window``
    :raises CLIError: if the value has more than three fields or any field is invalid
    """
    allowed_days = ("Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Disabled", "disabled")

    def _is_int_in_range(text, upper):
        # isdecimal (unlike isdigit) rejects characters such as superscripts that int() cannot parse,
        # so malformed input ends in a CLIError instead of an uncaught ValueError.
        return text.isdecimal() and 0 <= int(text) <= upper

    if not ns.maintenance_window:
        return

    # str.split always yields at least one field, so only the upper bound needs checking.
    fields = ns.maintenance_window.split(':')
    if len(fields) > 3:
        raise CLIError('Incorrect value for --maintenance-window. '
                       'Enter <Day>:<Hour>:<Minute>. Example: "Mon:8:30" to schedule on Monday, 8:30 UTC')
    if fields[0] not in allowed_days:
        raise CLIError('Incorrect value for --maintenance-window. '
                       'The first value means the scheduled day in a week or '
                       'can be "Disabled" to reset maintenance window. '
                       'Allowed values: {"Sun","Mon","Tue","Wed","Thu","Fri","Sat"}')
    if len(fields) >= 2 and not _is_int_in_range(fields[1], 23):
        raise CLIError('Incorrect value for --maintenance-window. '
                       'The second number means the scheduled hour in the scheduled day. '
                       'Allowed values: {0, 1, ... 23}')
    if len(fields) >= 3 and not _is_int_in_range(fields[2], 59):
        raise CLIError('Incorrect value for --maintenance-window. '
                       'The third number means the scheduled minute in the scheduled hour. '
                       'Allowed values: {0, 1, ... 59}')
330
331
def ip_address_validator(ns):
    """
    Check the end and start IP addresses of the namespace, in that order, when they are given.

    :param ns: parsed argument namespace with ``end_ip_address`` and ``start_ip_address``
    :raises CLIError: if a given address is rejected by ``_validate_ip``
    """
    for attribute in ('end_ip_address', 'start_ip_address'):
        address = getattr(ns, attribute)
        if address and not _validate_ip(address):
            raise CLIError('Incorrect value for ip address. '
                           'Ip address should be IPv4 format. Example: 12.12.12.12. ')
337
338
def public_access_validator(ns):
    """
    Check ``--public-access``: 'all', 'none' (case-insensitive), a single IP or an IP range.

    :param ns: parsed argument namespace with ``public_access``
    :raises CLIError: if the value is none of the accepted forms
    """
    if not ns.public_access:
        return

    lowered = ns.public_access.lower()
    if lowered == 'all' or lowered == 'none':
        return
    # One field is a single address, two fields a "<start>-<end>" range.
    if len(lowered.split('-')) in (1, 2) and _validate_ip(lowered):
        return
    raise CLIError('incorrect usage: --public-access. '
                   'Acceptable values are \'all\', \'none\',\'<startIP>\' and '
                   '\'<startIP>-<destinationIP>\' where startIP and destinationIP ranges from '
                   '0.0.0.0 to 255.255.255.255')
348
349
def _validate_ip(ips):
    """
    Tell whether ``ips`` is a single IPv4 address or a ``<start>-<end>`` pair of IPv4 addresses.

    Each address is checked octet by octet with ``_validate_ranges_in_ip``; a regular-expression
    based check was tried before and did not work for '255.255.255.255'.

    :param str ips: address or dash-separated address range
    :return: True if every address is valid, False otherwise (including more than two fields)
    """
    endpoints = ips.split('-')
    if len(endpoints) not in (1, 2):
        return False
    return all(_validate_ranges_in_ip(endpoint) for endpoint in endpoints)
364
365
def _validate_ranges_in_ip(ip):
    """
    Tell whether ``ip`` consists of exactly four dot-separated integers, each in 0-255.

    :param str ip: candidate IPv4 address
    :return: True if the address is valid; False otherwise, including when an octet is not an integer
    :rtype: bool
    """
    octets = ip.split('.')
    if len(octets) != 4:
        return False
    try:
        return all(_valid_range(int(octet)) for octet in octets)
    except ValueError:
        # A non-numeric octet (e.g. 'a.b.c.d') is simply an invalid address, not a crash.
        return False


def _valid_range(addr_range):
    """
    Tell whether an octet value lies in 0-255 (inclusive).

    :param int addr_range: octet value
    :rtype: bool
    """
    return 0 <= addr_range <= 255
378
379
def validate_server_name(db_context, server_name, type_):
    """
    Check the length of a server name and, except for the mysql command group, its availability.

    :param db_context: context object providing ``cf_availability``, ``cmd`` and ``command_group``
    :param str server_name: name to check; nothing is checked when it is empty or None
    :param str type_: resource type sent with the name availability request
    :raises ValidationError: if the name is shorter than 3 or longer than 63 characters, or the
                             availability check reports the name as unavailable
    """
    client = db_context.cf_availability(db_context.cmd.cli_ctx, '_')

    if not server_name:
        return

    if not 3 <= len(server_name) <= 63:
        raise ValidationError("Server name must be at least 3 characters and at most 63 characters.")

    # Name availability is not queried for the mysql command group.
    if db_context.command_group == 'mysql':
        return

    availability = client.execute(name_availability_request={'name': server_name, 'type': type_})
    if not availability.name_available:
        raise ValidationError(availability.message)
396
397
def validate_private_dns_zone(db_context, server_name, private_dns_zone, private_dns_zone_suffix):
    """
    Check a private DNS zone name or id given for a flexible server.

    :param db_context: context object providing ``cmd`` and ``command_group``
    :param str server_name: name of the server
    :param str private_dns_zone: private DNS zone name or resource id
    :param str private_dns_zone_suffix: suffix the zone has to end with
    :raises ValidationError: if the zone equals the server's fully qualified domain name, does not
                             end with the suffix, or is neither a valid resource name nor a valid
                             resource id
    """
    cloud_suffixes = db_context.cmd.cli_ctx.cloud.suffixes
    if db_context.command_group == 'postgres':
        server_endpoint = cloud_suffixes.postgresql_server_endpoint
    else:
        server_endpoint = cloud_suffixes.mysql_server_endpoint

    server_fqdn = server_name + server_endpoint
    if private_dns_zone == server_fqdn:
        raise ValidationError("private dns zone name cannot be same as the server's fully qualified domain name")

    suffix_length = len(private_dns_zone_suffix)
    if private_dns_zone[-suffix_length:] != private_dns_zone_suffix:
        raise ValidationError('The suffix of the private DNS zone should be "{}"'.format(private_dns_zone_suffix))

    # A plain name must be a valid resource name; anything else must be a valid resource id.
    if _is_resource_name(private_dns_zone):
        well_formed = is_valid_resource_name(private_dns_zone)
    else:
        well_formed = is_valid_resource_id(private_dns_zone)
    if not well_formed:
        raise ValidationError("Check if the private dns zone name or Id is in correct format.")
413
414
def validate_mysql_ha_enabled(server):
    """
    Require storage auto grow to be on before high availability is enabled.

    :param server: server object exposing ``storage_profile.storage_autogrow``
    :raises ValidationError: if ``storage_autogrow`` equals "Disabled"
    """
    autogrow_state = server.storage_profile.storage_autogrow
    if autogrow_state == "Disabled":
        raise ValidationError("You need to enable auto grow first to enable high availability.")
418
419
def validate_vnet_location(vnet, location):
    """
    Require the virtual network to be in the given location.

    :param vnet: virtual network object exposing ``location``
    :param str location: location of the server
    :raises ValidationError: if the two locations differ
    """
    vnet_location = vnet.location
    if vnet_location != location:
        raise ValidationError("The location of Vnet should be same as the location of the server")
423
424
def validate_mysql_replica(cmd, server):
    """
    Check that a replica can be created from the given source server.

    :param cmd: command object passed on to ``get_mysql_list_skus_info``
    :param server: source server object exposing ``sku.tier`` and ``location``
    :raises ValidationError: if the source server is in the Burstable tier, or its location is
                             reported as a single availability zone region
    """
    # Tier validation
    source_tier = server.sku.tier
    if source_tier == 'Burstable':
        raise ValidationError("Replication for Burstable servers are not supported. "
                              "Try using GeneralPurpose or MemoryOptimized tiers.")

    # single az validation
    _sku_info, region_is_single_az, _unused = get_mysql_list_skus_info(cmd, server.location)
    if region_is_single_az:
        raise ValidationError("Replica can only be created for multi-availability zone regions. "
                              "The location of the source server is in single availability zone region.")
436