1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6from knack.util import CLIError
7
8from ._utils import validate_premium_registry
9
10
# Error message handed to validate_premium_registry by the network-rule commands in
# this module (presumably surfaced when the target registry is not a Premium SKU).
NETWORK_RULE_NOT_SUPPORTED = 'Network rules are only supported for managed registries in Premium SKU.'
12
13
def acr_network_rule_list(cmd, registry_name, resource_group_name=None):
    """Return the network rule set of a Premium registry.

    The ``default_action`` attribute is deleted from the returned object
    (presumably so the list output shows only the IP / virtual-network rules).

    :param cmd: the CLI command context.
    :param registry_name: name of the container registry.
    :param resource_group_name: resource group of the registry; optional.
    :return: the registry's ``network_rule_set``, or ``None`` when the
        registry has no network rule set.
    """
    registry, _ = validate_premium_registry(
        cmd, registry_name, resource_group_name, NETWORK_RULE_NOT_SUPPORTED)
    rules = registry.network_rule_set
    # delattr on None would raise AttributeError; only strip the field from a
    # real rule set and otherwise return None as-is.
    if rules is not None:
        delattr(rules, 'default_action')
    return rules
20
21
def acr_network_rule_add(cmd,
                         client,
                         registry_name,
                         subnet=None,
                         vnet_name=None,
                         ip_address=None,
                         resource_group_name=None):
    """Append a virtual-network rule and/or an IP rule to a Premium registry.

    :param cmd: the CLI command context (used for ``cli_ctx`` and ``get_models``).
    :param client: registries client whose ``update`` call persists the change.
    :param registry_name: name of the container registry.
    :param subnet: subnet resource ID, or a subnet name when ``vnet_name`` is given.
    :param vnet_name: virtual network name used to build the subnet ID from a name.
    :param ip_address: IP address or range for a new IP rule.
    :param resource_group_name: resource group of the registry; the value returned
        by ``validate_premium_registry`` is what is used for the update.
    :return: whatever ``client.update`` returns for the modified rule set.
    """
    registry, resource_group_name = validate_premium_registry(
        cmd, registry_name, resource_group_name, NETWORK_RULE_NOT_SUPPORTED)
    network_rules = registry.network_rule_set

    if subnet or vnet_name:
        # Normalise a missing/empty list to [] and keep a handle to the same list
        # object that is stored on the rule set, so appending updates it in place.
        vnet_rules = network_rules.virtual_network_rules or []
        network_rules.virtual_network_rules = vnet_rules
        resolved_subnet_id = _validate_subnet(cmd.cli_ctx, subnet, vnet_name, resource_group_name)
        vnet_rule_model = cmd.get_models('VirtualNetworkRule')
        vnet_rules.append(vnet_rule_model(virtual_network_resource_id=resolved_subnet_id))

    if ip_address:
        address_rules = network_rules.ip_rules or []
        network_rules.ip_rules = address_rules
        ip_rule_model = cmd.get_models('IPRule')
        address_rules.append(ip_rule_model(ip_address_or_range=ip_address))

    update_model = cmd.get_models('RegistryUpdateParameters')
    return client.update(resource_group_name, registry_name,
                         update_model(network_rule_set=network_rules))
46
47
def acr_network_rule_remove(cmd,
                            client,
                            registry_name,
                            subnet=None,
                            vnet_name=None,
                            ip_address=None,
                            resource_group_name=None):
    """Drop matching virtual-network and/or IP rules from a Premium registry.

    Virtual-network rules are matched on their subnet resource ID
    case-insensitively; IP rules are matched on an exact string comparison of
    ``ip_address_or_range``.

    :param cmd: the CLI command context (used for ``cli_ctx`` and ``get_models``).
    :param client: registries client whose ``update`` call persists the change.
    :param registry_name: name of the container registry.
    :param subnet: subnet resource ID, or a subnet name when ``vnet_name`` is given.
    :param vnet_name: virtual network name used to build the subnet ID from a name.
    :param ip_address: IP address or range whose rule should be removed.
    :param resource_group_name: resource group of the registry; the value returned
        by ``validate_premium_registry`` is what is used for the update.
    :return: whatever ``client.update`` returns for the modified rule set.
    """
    registry, resource_group_name = validate_premium_registry(
        cmd, registry_name, resource_group_name, NETWORK_RULE_NOT_SUPPORTED)
    network_rules = registry.network_rule_set

    if subnet or vnet_name:
        target_id = _validate_subnet(cmd.cli_ctx, subnet, vnet_name, resource_group_name).lower()
        # Resource IDs are compared lower-cased on both sides.
        network_rules.virtual_network_rules = [
            rule for rule in (network_rules.virtual_network_rules or [])
            if rule.virtual_network_resource_id.lower() != target_id
        ]

    if ip_address:
        network_rules.ip_rules = [
            rule for rule in (network_rules.ip_rules or [])
            if rule.ip_address_or_range != ip_address
        ]

    update_model = cmd.get_models('RegistryUpdateParameters')
    return client.update(resource_group_name, registry_name,
                         update_model(network_rule_set=network_rules))
71
72
def _validate_subnet(cli_ctx, subnet, vnet_name, resource_group_name):
    """Turn the ``--subnet`` / ``--vnet-name`` arguments into a subnet resource ID.

    Two forms are accepted:

    * ``subnet`` is already a valid resource ID and ``vnet_name`` is not given:
      the ID is returned unchanged.
    * ``subnet`` is a plain name (not a resource ID) and ``vnet_name`` is given:
      a ``Microsoft.Network/virtualNetworks/<vnet>/subnets/<subnet>`` ID is built
      from the current subscription and ``resource_group_name``.

    :raises CLIError: when the arguments match neither form.
    """
    from msrestazure.tools import is_valid_resource_id, resource_id

    looks_like_id = is_valid_resource_id(subnet)

    # Form 1: a full resource ID supplied on its own.
    if looks_like_id and not vnet_name:
        return subnet

    # Anything other than "subnet name + vnet name" is a usage error.
    if not (subnet and vnet_name and not looks_like_id):
        raise CLIError('Usage error: [--subnet ID | --subnet NAME --vnet-name NAME]')

    # Form 2: build the ID from the names (imported lazily, only on this path).
    from azure.cli.core.commands.client_factory import get_subscription_id
    return resource_id(
        subscription=get_subscription_id(cli_ctx),
        resource_group=resource_group_name,
        namespace='Microsoft.Network',
        type='virtualNetworks',
        name=vnet_name,
        child_type_1='subnets',
        child_name_1=subnet)
92