1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import json
7
8from azure.cli.command_modules.network._client_factory import network_client_factory
9from azure.cli.core.azclierror import (ResourceNotFoundError, ArgumentUsageError, InvalidArgumentValueError,
10                                       MutuallyExclusiveArgumentError)
11from azure.cli.core.commands import LongRunningOperation
12from azure.cli.core.commands.client_factory import get_subscription_id
13from azure.mgmt.network.models import ServiceEndpointPropertiesFormat
14from azure.mgmt.web.models import IpSecurityRestriction
15from knack.log import get_logger
16from msrestazure.tools import is_valid_resource_id, resource_id, parse_resource_id
17
18from ._appservice_utils import _generic_site_operation
19from .custom import get_site_configs
20
logger = get_logger(__name__)

# Lower-case header names accepted by _parse_http_headers; it lower-cases each parsed name
# before checking membership in this list and rejects anything else.
ALLOWED_HTTP_HEADER_NAMES = ['x-forwarded-host', 'x-forwarded-for', 'x-azure-fdid', 'x-fd-healthprobe']
24
25
def show_webapp_access_restrictions(cmd, resource_group_name, name, slot=None):
    """Return the main-site and SCM-site access restrictions of a web app as plain data.

    Restriction models are serialized through their ``__dict__`` and decoded again, so the
    returned dict holds only JSON-compatible values.

    :return: dict with keys ``scmIpSecurityRestrictionsUseMain``, ``ipSecurityRestrictions``
        and ``scmIpSecurityRestrictions``.
    """
    site_config = get_site_configs(cmd, resource_group_name, name, slot)

    def _as_plain_data(models):
        # JSON round-trip; objects json cannot encode fall back to their attribute dict.
        encoded = json.dumps(models, default=lambda obj: obj.__dict__)
        return json.loads(encoded)

    return {
        "scmIpSecurityRestrictionsUseMain": site_config.scm_ip_security_restrictions_use_main,
        "ipSecurityRestrictions": _as_plain_data(site_config.ip_security_restrictions),
        "scmIpSecurityRestrictions": _as_plain_data(site_config.scm_ip_security_restrictions),
    }
36
37
def add_webapp_access_restriction(
        cmd, resource_group_name, name, priority, rule_name=None,
        action='Allow', ip_address=None, subnet=None,
        vnet_name=None, description=None, scm_site=False,
        ignore_missing_vnet_service_endpoint=False, slot=None, vnet_resource_group=None,
        service_tag=None, http_headers=None):
    """Add an access restriction rule to a web app (or its SCM site) and save the site config.

    Exactly one rule source must be given: ``subnet`` (a subnet resource ID, or a subnet name
    together with ``vnet_name``), ``ip_address`` or ``service_tag``.

    :param priority: priority assigned to the new rule.
    :param rule_name: optional name of the new rule.
    :param action: rule action; ``'Allow'`` by default.
    :param description: optional rule description.
    :param scm_site: when True the rule is added to the SCM site restrictions instead of the
        main site restrictions.
    :param ignore_missing_vnet_service_endpoint: when True, skip ensuring that the subnet has a
        ``Microsoft.Web`` service endpoint.
    :param vnet_resource_group: resource group used to build the subnet ID from a subnet name;
        defaults to ``resource_group_name``.
    :param http_headers: iterable of ``name=value`` strings attached to the rule as header
        filters (parsed by ``_parse_http_headers``).
    :return: the restriction list (main or SCM) returned by the configuration update.
    :raises MutuallyExclusiveArgumentError: when zero, several, or only empty sources are given.
    :raises ArgumentUsageError: when a rule for the same subnet already exists.
    """
    sources = (service_tag, ip_address, subnet)
    # Exactly one source must be supplied and it must be non-empty: an empty string would pass a
    # plain "is not None" count but match none of the branches below, so no rule would be added.
    if sum(source is not None for source in sources) != 1 or not any(sources):
        err_msg = 'Please specify either: --subnet or --ip-address or --service-tag'
        raise MutuallyExclusiveArgumentError(err_msg)

    # Validate headers before any remote side effect (e.g. creating a subnet service endpoint).
    headers = _parse_http_headers(http_headers=http_headers) if http_headers else None

    configs = get_site_configs(cmd, resource_group_name, name, slot)
    rules_attr = 'scm_ip_security_restrictions' if scm_site else 'ip_security_restrictions'
    access_rules = getattr(configs, rules_attr)
    if access_rules is None:
        # Attach the new list to configs; a detached list would silently drop the rule
        # from the configuration update below.
        access_rules = []
        setattr(configs, rules_attr, access_rules)

    if subnet:
        vnet_rg = vnet_resource_group or resource_group_name
        subnet_id = _validate_subnet(cmd.cli_ctx, subnet, vnet_name, vnet_rg)
        # Reject duplicates before touching the subnet's service endpoints.
        if any(rule.vnet_subnet_resource_id and rule.vnet_subnet_resource_id.lower() == subnet_id.lower()
               for rule in access_rules):
            raise ArgumentUsageError('Service endpoint rule for: ' + subnet_id + ' already exists. '
                                     'Cannot add duplicate service endpoint rules.')
        if not ignore_missing_vnet_service_endpoint:
            _ensure_subnet_service_endpoint(cmd.cli_ctx, subnet_id)
        rule_instance = IpSecurityRestriction(
            name=rule_name, vnet_subnet_resource_id=subnet_id,
            priority=priority, action=action, tag='Default', description=description)
    elif ip_address:
        rule_instance = IpSecurityRestriction(
            name=rule_name, ip_address=ip_address,
            priority=priority, action=action, tag='Default', description=description)
    else:
        # Only service_tag can be non-empty here. The tag is stored in the ip_address field
        # and marked with tag='ServiceTag'.
        rule_instance = IpSecurityRestriction(
            name=rule_name, ip_address=service_tag,
            priority=priority, action=action, tag='ServiceTag', description=description)
    if headers is not None:
        rule_instance.headers = headers
    access_rules.append(rule_instance)

    result = _generic_site_operation(
        cmd.cli_ctx, resource_group_name, name, 'update_configuration', slot, configs)
    return getattr(result, rules_attr)
87
88
def remove_webapp_access_restriction(cmd, resource_group_name, name, rule_name=None, action='Allow',
                                     ip_address=None, subnet=None, vnet_name=None, scm_site=False, slot=None,
                                     service_tag=None):
    """Remove the first matching access restriction rule from a web app (or its SCM site).

    A rule is selected by ``rule_name`` alone, or by one source (``ip_address``, ``service_tag``
    or ``subnet``) optionally narrowed by ``rule_name``. In both cases the rule's action must
    equal ``action``. Names, service tags and subnet IDs are compared case-insensitively; IP
    addresses are compared exactly.

    :return: the restriction list (main or SCM) returned by the configuration update.
    :raises MutuallyExclusiveArgumentError: when more than one source is given, or when the
        subnet arguments are inconsistent.
    :raises ResourceNotFoundError: when no rule matches, including when the site has no rules.
    """
    configs = get_site_configs(cmd, resource_group_name, name, slot)
    input_rule_types = (int(service_tag is not None) + int(ip_address is not None) +
                        int(subnet is not None))
    if input_rule_types > 1:
        err_msg = 'Please specify either: --subnet or --ip-address or --service-tag'
        raise MutuallyExclusiveArgumentError(err_msg)

    # The subnet ID does not depend on the rule being inspected, so resolve it once up front
    # instead of once per rule.
    subnet_id = None
    if subnet:
        subnet_id = _validate_subnet(cmd.cli_ctx, subnet, vnet_name, resource_group_name).lower()

    def _name_matches(rule):
        return bool(rule.name) and rule.name.lower() == rule_name.lower()

    def _is_target(rule):
        if rule.action != action:
            return False
        if rule_name and input_rule_types == 0:
            return _name_matches(rule)
        if ip_address:
            source_matches = rule.ip_address == ip_address
        elif service_tag:
            source_matches = bool(rule.ip_address) and rule.ip_address.lower() == service_tag.lower()
        elif subnet:
            source_matches = (bool(rule.vnet_subnet_resource_id) and
                              rule.vnet_subnet_resource_id.lower() == subnet_id)
        else:
            return False
        # When a name is given together with a source, both must match.
        return source_matches and (not rule_name or _name_matches(rule))

    access_rules = configs.scm_ip_security_restrictions if scm_site else configs.ip_security_restrictions
    # access_rules may be None when the site has no rules; treat that as "nothing to match".
    rule_instance = next((rule for rule in access_rules or [] if _is_target(rule)), None)

    if rule_instance is None:
        raise ResourceNotFoundError('No rule found with the specified criteria.\n'
                                    '- If you specify rule name and source, both must match.\n'
                                    '- If you are trying to remove a Deny rule, '
                                    'you must explicitly specify --action Deny')

    access_rules.remove(rule_instance)

    result = _generic_site_operation(
        cmd.cli_ctx, resource_group_name, name, 'update_configuration', slot, configs)
    return result.scm_ip_security_restrictions if scm_site else result.ip_security_restrictions
138
139
def set_webapp_access_restriction(cmd, resource_group_name, name, use_same_restrictions_for_scm_site, slot=None):
    """Set whether the SCM site reuses the main site's access restrictions.

    :param use_same_restrictions_for_scm_site: coerced with ``bool`` before being stored.
    :return: ``{"scmIpSecurityRestrictionsUseMain": value}``, where ``value`` is read from the
        result of the configuration update.
    """
    site_config = get_site_configs(cmd, resource_group_name, name, slot)
    site_config.scm_ip_security_restrictions_use_main = bool(use_same_restrictions_for_scm_site)

    updated_config = _generic_site_operation(cmd.cli_ctx, resource_group_name, name, 'update_configuration',
                                             slot, site_config)
    return {"scmIpSecurityRestrictionsUseMain": updated_config.scm_ip_security_restrictions_use_main}
151
152
def _validate_subnet(cli_ctx, subnet, vnet_name, resource_group_name):
    """Resolve the ``subnet`` argument into a subnet resource ID.

    Two forms are accepted:

    - a full subnet resource ID with no ``vnet_name``, returned unchanged;
    - a subnet name plus ``vnet_name``. The ID is then built under the current subscription
      in ``resource_group_name``.

    :raises MutuallyExclusiveArgumentError: for any other combination.
    """
    if is_valid_resource_id(subnet):
        if not vnet_name:
            return subnet
    elif subnet and vnet_name:
        return resource_id(
            subscription=get_subscription_id(cli_ctx),
            resource_group=resource_group_name,
            namespace='Microsoft.Network',
            type='virtualNetworks',
            name=vnet_name,
            child_type_1='subnets',
            child_name_1=subnet)
    raise MutuallyExclusiveArgumentError('Please specify either: --subnet ID or (--subnet NAME and --vnet-name NAME)')
168
169
def _ensure_subnet_service_endpoint(cli_ctx, subnet_id):
    """Ensure the subnet identified by ``subnet_id`` has a ``Microsoft.Web`` service endpoint.

    If the endpoint is missing, it is appended and the subnet is updated. The function waits
    for that long-running update to finish.

    :raises ArgumentUsageError: when the subnet belongs to a different subscription than the
        current one, because it cannot be checked from here.
    """
    from azure.cli.core.profiles import AD_HOC_API_VERSIONS, ResourceType
    id_parts = parse_resource_id(subnet_id)
    subscription = id_parts['subscription']
    vnet_rg = id_parts['resource_group']
    vnet_name = id_parts['name']
    subnet_name = id_parts['resource_name']

    if subscription.lower() != get_subscription_id(cli_ctx).lower():
        raise ArgumentUsageError('Cannot validate subnet in different subscription for missing service endpoint.'
                                 ' Use --ignore-missing-endpoint or -i to'
                                 ' skip validation and manually verify service endpoint.')

    network_api_version = AD_HOC_API_VERSIONS[ResourceType.MGMT_NETWORK]['appservice_ensure_subnet']
    vnet_client = network_client_factory(cli_ctx, api_version=network_api_version)
    subnet_obj = vnet_client.subnets.get(vnet_rg, vnet_name, subnet_name)
    subnet_obj.service_endpoints = subnet_obj.service_endpoints or []

    if any(endpoint.service == "Microsoft.Web" for endpoint in subnet_obj.service_endpoints):
        return

    subnet_obj.service_endpoints.append(ServiceEndpointPropertiesFormat(service="Microsoft.Web"))
    poller = vnet_client.subnets.begin_create_or_update(
        vnet_rg, vnet_name, subnet_name, subnet_parameters=subnet_obj)
    # Block until the subnet update completes so the following site update does not conflict.
    LongRunningOperation(cli_ctx)(poller)
201
202
def _parse_http_headers(http_headers):
    """Convert ``name=value`` strings into a header-filter mapping.

    Header names are stripped and lower-cased, and must appear in ``ALLOWED_HTTP_HEADER_NAMES``.
    Values are stripped, and blank values are skipped. At most 8 values are kept per header name.

    :param http_headers: iterable of ``name=value`` strings.
    :return: dict mapping each header name to the list of its values, in input order.
    :raises InvalidArgumentValueError: on an entry without exactly one ``=``, or on a header name
        that is not allowed.
    :raises ArgumentUsageError: when a header name receives more than 8 values.
    """
    logger.info(http_headers)
    parsed = {}
    for raw_header in http_headers:
        pieces = raw_header.split('=')
        if len(pieces) != 2:
            raise InvalidArgumentValueError(
                'Http headers must have a format of `<name>=<value>`: "{}"'.format(raw_header))
        header_name, header_value = pieces[0].strip().lower(), pieces[1].strip()

        if header_name not in ALLOWED_HTTP_HEADER_NAMES:
            raise InvalidArgumentValueError('Invalid http-header name: "{}"'.format(header_name))
        if not header_value:
            # Blank values are ignored rather than rejected.
            continue

        values = parsed.setdefault(header_name, [])
        if len(values) >= 8:
            raise ArgumentUsageError('Only 8 values are allowed for each http-header: "{}"'.format(header_name))
        values.append(header_value)
    return parsed
226