1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5from collections import Counter, OrderedDict
6
7from msrestazure.tools import parse_resource_id, is_valid_resource_id, resource_id
8
9from knack.log import get_logger
10
11from azure.mgmt.trafficmanager.models import MonitorProtocol, ProfileStatus
12
13# pylint: disable=no-self-use,no-member,too-many-lines,unused-argument
14from azure.cli.core.commands import cached_get, cached_put, upsert_to_collection, get_property
15from azure.cli.core.commands.client_factory import get_subscription_id, get_mgmt_service_client
16
17from azure.cli.core.util import CLIError, sdk_no_wait, find_child_item, find_child_collection
18from azure.cli.core.azclierror import InvalidArgumentValueError, RequiredArgumentMissingError, \
19    UnrecognizedArgumentError, ResourceNotFoundError, CLIInternalError
20from azure.cli.core.profiles import ResourceType, supported_api_version
21
22from azure.cli.command_modules.network._client_factory import network_client_factory
23from azure.cli.command_modules.network.zone_file.parse_zone_file import parse_zone_file
24from azure.cli.command_modules.network.zone_file.make_zone_file import make_zone_file
25
26import threading
27import time
28import platform
29import subprocess
30
31logger = get_logger(__name__)
32
33
34# region Utility methods
def _log_pprint_template(template):
    """Write an ARM *template* to the INFO log as pretty-printed JSON.

    The dump is framed by BEGIN/END marker lines so it is easy to spot in verbose output.
    """
    from json import dumps
    logger.info('==== BEGIN TEMPLATE ====')
    rendered = dumps(template, indent=2)
    logger.info(rendered)
    logger.info('==== END TEMPLATE ====')
40
41
def _get_default_name(balancer, property_name, option_name):
    """Return the name (last ID segment) of the single child in ``balancer.<property_name>``.

    Delegates to ``_get_default_value``, which raises CLIError for zero or several candidates.
    """
    return _get_default_value(balancer, property_name, option_name, return_name=True)
44
45
def _get_default_id(balancer, property_name, option_name):
    """Return the full resource ID of the single child in ``balancer.<property_name>``.

    Delegates to ``_get_default_value``, which raises CLIError for zero or several candidates.
    """
    return _get_default_value(balancer, property_name, option_name, return_name=False)
48
49
def _get_default_value(balancer, property_name, option_name, return_name):
    """Pick the only existing child of ``balancer.<property_name>`` as a default value.

    :param balancer: parent resource whose ``property_name`` attribute is an iterable of
        sub-resources exposing an ``id``.
    :param property_name: attribute of *balancer* holding the candidate sub-resources.
    :param option_name: CLI option name, used only in error messages.
    :param return_name: when truthy return the text after the last '/' of the ID,
        otherwise the full ID.
    :raises CLIError: when there is not exactly one candidate.
    """
    candidate_ids = [item.id for item in getattr(balancer, property_name)]
    if not candidate_ids:
        raise CLIError("No existing values found for '{0}'. Create one first and try "
                       "again.".format(option_name))
    if len(candidate_ids) > 1:
        raise CLIError("Multiple possible values found for '{0}': {1}\nSpecify '{0}' "
                       "explicitly.".format(option_name, ', '.join(candidate_ids)))
    chosen = candidate_ids[0]
    if return_name:
        return chosen.rsplit('/', 1)[1]
    return chosen
59
60# endregion
61
62
63# region Generic list commands
def _generic_list(cli_ctx, operation_name, resource_group_name):
    """List resources from one operation group of the network client.

    :param operation_name: attribute name of the operation group, e.g. 'virtual_networks'.
    :param resource_group_name: when given, list only that resource group; otherwise call
        ``list_all`` on the operation group.
    """
    operations = getattr(network_client_factory(cli_ctx), operation_name)
    return operations.list(resource_group_name) if resource_group_name else operations.list_all()
71
72
def list_vnet(cmd, resource_group_name=None):
    """List virtual networks in *resource_group_name*, or subscription-wide when it is omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='virtual_networks',
                         resource_group_name=resource_group_name)
75
76
def list_express_route_circuits(cmd, resource_group_name=None):
    """List ExpressRoute circuits in *resource_group_name*, or subscription-wide when omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='express_route_circuits',
                         resource_group_name=resource_group_name)
79
80
def create_express_route_auth(cmd, resource_group_name, circuit_name, authorization_name):
    """Create or update an authorization on an ExpressRoute circuit.

    The authorization model is sent with no properties set. Returns whatever
    ``begin_create_or_update`` returns (the long-running-operation poller).
    """
    auth_model = cmd.get_models('ExpressRouteCircuitAuthorization')
    authorizations = network_client_factory(cmd.cli_ctx).express_route_circuit_authorizations
    return authorizations.begin_create_or_update(
        resource_group_name, circuit_name, authorization_name, auth_model())
89
90
def list_lbs(cmd, resource_group_name=None):
    """List load balancers in *resource_group_name*, or subscription-wide when it is omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='load_balancers',
                         resource_group_name=resource_group_name)
93
94
def list_nics(cmd, resource_group_name=None):
    """List network interfaces in *resource_group_name*, or subscription-wide when it is omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='network_interfaces',
                         resource_group_name=resource_group_name)
97
98
def list_nsgs(cmd, resource_group_name=None):
    """List network security groups in *resource_group_name*, or subscription-wide when omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='network_security_groups',
                         resource_group_name=resource_group_name)
101
102
def list_nsg_rules(cmd, resource_group_name, network_security_group_name, include_default=False):
    """List the security rules of a network security group.

    :param include_default: when true, the NSG's default security rules are appended after
        its custom rules.
    :returns: the NSG's ``security_rules`` (plus ``default_security_rules`` when requested).
    """
    client = network_client_factory(cmd.cli_ctx).network_security_groups
    nsg = client.get(resource_group_name, network_security_group_name)
    rules = nsg.security_rules
    if include_default:
        # Either collection may be deserialized as None when absent; treat that as "no rules"
        # so the concatenation cannot fail with a TypeError.
        rules = (rules or []) + (nsg.default_security_rules or [])
    return rules
110
111
def list_custom_ip_prefixes(cmd, resource_group_name=None):
    """List custom IP prefixes in *resource_group_name*, or subscription-wide when omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='custom_ip_prefixes',
                         resource_group_name=resource_group_name)
114
115
def list_public_ips(cmd, resource_group_name=None):
    """List public IP addresses in *resource_group_name*, or subscription-wide when omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='public_ip_addresses',
                         resource_group_name=resource_group_name)
118
119
def list_public_ip_prefixes(cmd, resource_group_name=None):
    """List public IP prefixes in *resource_group_name*, or subscription-wide when omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='public_ip_prefixes',
                         resource_group_name=resource_group_name)
122
123
def list_route_tables(cmd, resource_group_name=None):
    """List route tables in *resource_group_name*, or subscription-wide when it is omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='route_tables',
                         resource_group_name=resource_group_name)
126
127
def list_application_gateways(cmd, resource_group_name=None):
    """List application gateways in *resource_group_name*, or subscription-wide when omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='application_gateways',
                         resource_group_name=resource_group_name)
130
131
def list_network_watchers(cmd, resource_group_name=None):
    """List network watchers in *resource_group_name*, or subscription-wide when it is omitted."""
    return _generic_list(cmd.cli_ctx, operation_name='network_watchers',
                         resource_group_name=resource_group_name)
134
135# endregion
136
137
138# region ApplicationGateways
139# pylint: disable=too-many-locals
def _is_v2_sku(sku):
    """Return True when the SKU name contains the 'v2' marker (e.g. 'Standard_v2', 'WAF_v2')."""
    v2_marker = 'v2'
    return v2_marker in sku
142
143
144# pylint: disable=too-many-statements
def create_application_gateway(cmd, application_gateway_name, resource_group_name, location=None,
                               tags=None, no_wait=False, capacity=2,
                               cert_data=None, cert_password=None, key_vault_secret_id=None,
                               frontend_port=None, http_settings_cookie_based_affinity='disabled',
                               http_settings_port=80, http_settings_protocol='Http',
                               routing_rule_type='Basic', servers=None,
                               sku=None,
                               private_ip_address=None, public_ip_address=None,
                               public_ip_address_allocation=None,
                               subnet='default', subnet_address_prefix='10.0.0.0/24',
                               virtual_network_name=None, vnet_address_prefix='10.0.0.0/16',
                               public_ip_address_type=None, subnet_type=None, validate=False,
                               connection_draining_timeout=0, enable_http2=None, min_capacity=None, zones=None,
                               custom_error_pages=None, firewall_policy=None, max_capacity=None,
                               user_assigned_identity=None,
                               enable_private_link=False,
                               private_link_ip_address=None,
                               private_link_subnet='PrivateLinkDefaultSubnet',
                               private_link_subnet_prefix='10.0.1.0/24',
                               private_link_primary=None,
                               trusted_client_cert=None,
                               ssl_profile=None,
                               ssl_profile_id=None,
                               ssl_cert_name=None):
    """Create an application gateway by building and deploying an ARM template.

    The same template may also create a new virtual network/subnet (``subnet_type == 'new'``)
    and a new public IP (``public_ip_address_type == 'new'``). The gateway resource is made to
    depend on those resources.

    When *validate* is set, the template is logged and only validated, not deployed; the
    validation result is returned. Otherwise returns the result of
    ``sdk_no_wait(no_wait, deployments.begin_create_or_update, ...)``.
    """
    from azure.cli.core.util import random_string
    from azure.cli.core.commands.arm import ArmTemplateBuilder
    from azure.cli.command_modules.network._template_builder import (
        build_application_gateway_resource, build_public_ip_resource, build_vnet_resource)

    DeploymentProperties = cmd.get_models('DeploymentProperties', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES)
    IPAllocationMethod = cmd.get_models('IPAllocationMethod')

    tags = tags or {}
    # NOTE(review): sku must be supplied (presumably defaulted by argument registration);
    # None would fail here.
    sku_tier = sku.split('_', 1)[0] if not _is_v2_sku(sku) else sku
    http_listener_protocol = 'https' if (cert_data or key_vault_secret_id) else 'http'
    virtual_network_name = virtual_network_name or '{}Vnet'.format(application_gateway_name)

    # Build up the ARM template
    master_template = ArmTemplateBuilder()
    ag_dependencies = []

    public_ip_id = public_ip_address if is_valid_resource_id(public_ip_address) else None
    subnet_id = subnet if is_valid_resource_id(subnet) else None
    # An explicit private IP implies static allocation; otherwise the service assigns one.
    private_ip_allocation = IPAllocationMethod.static.value if private_ip_address \
        else IPAllocationMethod.dynamic.value

    network_id_template = resource_id(
        subscription=get_subscription_id(cmd.cli_ctx), resource_group=resource_group_name,
        namespace='Microsoft.Network')

    if subnet_type == 'new':
        ag_dependencies.append('Microsoft.Network/virtualNetworks/{}'.format(virtual_network_name))
        vnet = build_vnet_resource(
            cmd, virtual_network_name, location, tags, vnet_address_prefix, subnet,
            subnet_address_prefix,
            enable_private_link=enable_private_link,
            private_link_subnet=private_link_subnet,
            private_link_subnet_prefix=private_link_subnet_prefix)
        master_template.add_resource(vnet)
        subnet_id = '{}/virtualNetworks/{}/subnets/{}'.format(network_id_template,
                                                              virtual_network_name, subnet)

    if public_ip_address_type == 'new':
        ag_dependencies.append('Microsoft.Network/publicIpAddresses/{}'.format(public_ip_address))
        public_ip_sku = None
        if _is_v2_sku(sku):
            # v2 gateways get a Standard SKU public IP with static allocation
            public_ip_sku = 'Standard'
            public_ip_address_allocation = 'Static'
        master_template.add_resource(build_public_ip_resource(cmd, public_ip_address, location,
                                                              tags,
                                                              public_ip_address_allocation,
                                                              None, public_ip_sku, None))
        public_ip_id = '{}/publicIPAddresses/{}'.format(network_id_template,
                                                        public_ip_address)

    private_link_subnet_id = None
    private_link_name = 'PrivateLinkDefaultConfiguration'
    private_link_ip_allocation_method = 'Dynamic'
    if enable_private_link:
        private_link_subnet_id = '{}/virtualNetworks/{}/subnets/{}'.format(network_id_template,
                                                                           virtual_network_name,
                                                                           private_link_subnet)
        private_link_ip_allocation_method = IPAllocationMethod.static.value if private_link_ip_address \
            else IPAllocationMethod.dynamic.value

    app_gateway_resource = build_application_gateway_resource(
        cmd, application_gateway_name, location, tags, sku, sku_tier, capacity, servers, frontend_port,
        private_ip_address, private_ip_allocation, cert_data, cert_password, key_vault_secret_id,
        http_settings_cookie_based_affinity, http_settings_protocol, http_settings_port,
        http_listener_protocol, routing_rule_type, public_ip_id, subnet_id,
        connection_draining_timeout, enable_http2, min_capacity, zones, custom_error_pages,
        firewall_policy, max_capacity, user_assigned_identity,
        enable_private_link, private_link_name,
        private_link_ip_address, private_link_ip_allocation_method, private_link_primary,
        private_link_subnet_id, trusted_client_cert, ssl_profile, ssl_profile_id, ssl_cert_name)

    app_gateway_resource['dependsOn'] = ag_dependencies
    master_template.add_variable(
        'appGwID',
        "[resourceId('Microsoft.Network/applicationGateways', '{}')]".format(
            application_gateway_name))
    master_template.add_resource(app_gateway_resource)
    master_template.add_output('applicationGateway', application_gateway_name, output_type='object')
    if cert_password:
        # passed as a secure template parameter rather than inlined in the template
        master_template.add_secure_parameter('certPassword', cert_password)

    template = master_template.build()
    parameters = master_template.build_parameters()

    # deploy ARM template
    deployment_name = 'ag_deploy_' + random_string(32)
    client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES).deployments
    properties = DeploymentProperties(template=template, parameters=parameters, mode='incremental')
    Deployment = cmd.get_models('Deployment', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES)
    deployment = Deployment(properties=properties)

    if validate:
        _log_pprint_template(template)
        # newer resource API versions expose validation as a long-running operation
        if cmd.supported_api_version(min_api='2019-10-01', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES):
            from azure.cli.core.commands import LongRunningOperation
            validation_poller = client.begin_validate(resource_group_name, deployment_name, deployment)
            return LongRunningOperation(cmd.cli_ctx)(validation_poller)

        return client.validate(resource_group_name, deployment_name, deployment)

    return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, deployment_name, deployment)
272
273
def update_application_gateway(cmd, instance, sku=None, capacity=None, tags=None, enable_http2=None, min_capacity=None,
                               custom_error_pages=None, max_capacity=None):
    """Apply optional property changes to an application gateway *instance* and return it.

    *sku*, *min_capacity* and *max_capacity* are applied only when not None. The remaining
    fields go through ``cmd.update_context(...).set_param`` (presumably skipped when None --
    confirm against azure.cli.core).
    """
    if sku is not None:
        # Non-v2 SKU names are split on the first '_' and the prefix becomes the tier;
        # v2 SKU names are used as the tier unchanged.
        instance.sku.tier = sku.split('_', 1)[0] if not _is_v2_sku(sku) else sku

    try:
        if min_capacity is not None:
            instance.autoscale_configuration.min_capacity = min_capacity
        if max_capacity is not None:
            instance.autoscale_configuration.max_capacity = max_capacity
    except AttributeError:
        # Reached when the gateway has no autoscale configuration object (presumably None).
        # A new one is created from both bounds.
        # NOTE(review): the bound that was not supplied is stored as None -- confirm the
        # service accepts a partial autoscale configuration.
        instance.autoscale_configuration = {
            'min_capacity': min_capacity,
            'max_capacity': max_capacity
        }

    with cmd.update_context(instance) as c:
        c.set_param('sku.name', sku)
        c.set_param('sku.capacity', capacity)
        c.set_param('tags', tags)
        c.set_param('enable_http2', enable_http2)
        c.set_param('custom_error_configurations', custom_error_pages)
    return instance
297
298
def create_ag_authentication_certificate(cmd, resource_group_name, application_gateway_name, item_name,
                                         cert_data, no_wait=False):
    """Add an authentication certificate to an application gateway, or replace the one with the same name.

    The gateway is fetched, the certificate is upserted into ``authentication_certificates``,
    and the whole gateway is written back with ``begin_create_or_update``.
    """
    AuthCert = cmd.get_models('ApplicationGatewayAuthenticationCertificate')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    certificate = AuthCert(data=cert_data, name=item_name)
    upsert_to_collection(gateway, 'authentication_certificates', certificate, 'name')
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
307
308
def update_ag_authentication_certificate(instance, parent, item_name, cert_data):
    """Overwrite the certificate data of *instance* and return the owning gateway *parent*."""
    setattr(instance, 'data', cert_data)
    return parent
312
313
def create_ag_backend_address_pool(cmd, resource_group_name, application_gateway_name, item_name,
                                   servers=None, no_wait=False):
    """Add a backend address pool to an application gateway, or replace the one with the same name.

    :param servers: backend addresses for the pool (passed through unchanged).
    """
    PoolModel = cmd.get_models('ApplicationGatewayBackendAddressPool')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    pool = PoolModel(name=item_name, backend_addresses=servers)
    upsert_to_collection(gateway, 'backend_address_pools', pool, 'name')
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
323
324
def update_ag_backend_address_pool(instance, parent, item_name, servers=None):
    """Replace the pool's backend addresses when *servers* is given; always return *parent*."""
    if servers is None:
        return parent
    instance.backend_addresses = servers
    return parent
329
330
def create_ag_frontend_ip_configuration(cmd, resource_group_name, application_gateway_name, item_name,
                                        public_ip_address=None, subnet=None,
                                        virtual_network_name=None, private_ip_address=None,
                                        private_ip_address_allocation=None, no_wait=False):
    """Add a public or private frontend IP configuration to an application gateway.

    With *public_ip_address* the configuration references that public IP. Otherwise a private
    configuration on *subnet* is created. Its allocation is 'Static' when *private_ip_address*
    is given and 'Dynamic' otherwise; *private_ip_address_allocation* itself is not read here.
    """
    FrontendIPConfig, SubResource = cmd.get_models(
        'ApplicationGatewayFrontendIPConfiguration', 'SubResource')
    ncf = network_client_factory(cmd.cli_ctx)
    gateway = ncf.application_gateways.get(resource_group_name, application_gateway_name)
    if public_ip_address:
        config_kwargs = {'public_ip_address': SubResource(id=public_ip_address)}
    else:
        config_kwargs = {
            'private_ip_address': private_ip_address or None,
            'private_ip_allocation_method': 'Static' if private_ip_address else 'Dynamic',
            'subnet': SubResource(id=subnet),
        }
    frontend_config = FrontendIPConfig(name=item_name, **config_kwargs)
    upsert_to_collection(gateway, 'frontend_ip_configurations', frontend_config, 'name')
    return sdk_no_wait(no_wait, ncf.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
352
353
def update_ag_frontend_ip_configuration(cmd, instance, parent, item_name, public_ip_address=None,
                                        subnet=None, virtual_network_name=None,
                                        private_ip_address=None):
    """Update a frontend IP configuration in place and return the owning gateway *parent*.

    Only arguments that are not None are applied. Setting *private_ip_address* also switches
    the allocation method to 'Static'.
    """
    SubResource = cmd.get_models('SubResource')
    for attr, target_id in (('public_ip_address', public_ip_address), ('subnet', subnet)):
        if target_id is not None:
            setattr(instance, attr, SubResource(id=target_id))
    if private_ip_address is not None:
        instance.private_ip_address = private_ip_address
        instance.private_ip_allocation_method = 'Static'
    return parent
366
367
def create_ag_frontend_port(cmd, resource_group_name, application_gateway_name, item_name, port,
                            no_wait=False):
    """Add a frontend port to an application gateway, or replace the one with the same name."""
    PortModel = cmd.get_models('ApplicationGatewayFrontendPort')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    frontend_port = PortModel(name=item_name, port=port)
    upsert_to_collection(gateway, 'frontend_ports', frontend_port, 'name')
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
377
378
def update_ag_frontend_port(instance, parent, item_name, port=None):
    """Set the port number when *port* is given; always return the owning gateway *parent*."""
    if port is None:
        return parent
    instance.port = port
    return parent
383
384
def create_ag_http_listener(cmd, resource_group_name, application_gateway_name, item_name,
                            frontend_port, frontend_ip=None, host_name=None, ssl_cert=None,
                            ssl_profile=None, firewall_policy=None, no_wait=False, host_names=None):
    """Add an HTTP(S) listener to an application gateway, or replace the one with the same name.

    The listener uses HTTPS when *ssl_cert* is given, and then requires SNI if *host_name* is
    also set. Without *frontend_ip*, the gateway's only frontend IP configuration is used.
    *firewall_policy* is applied only on API >= 2019-09-01 and *ssl_profile* only on
    API >= 2020-06-01.
    """
    ApplicationGatewayHttpListener, SubResource = cmd.get_models('ApplicationGatewayHttpListener', 'SubResource')
    ncf = network_client_factory(cmd.cli_ctx)
    gateway = ncf.application_gateways.get(resource_group_name, application_gateway_name)
    frontend_ip = frontend_ip or _get_default_id(gateway, 'frontend_ip_configurations', '--frontend-ip')
    use_ssl = bool(ssl_cert)
    listener = ApplicationGatewayHttpListener(
        name=item_name,
        frontend_ip_configuration=SubResource(id=frontend_ip),
        frontend_port=SubResource(id=frontend_port),
        host_name=host_name,
        require_server_name_indication=True if (use_ssl and host_name) else None,
        protocol='https' if use_ssl else 'http',
        ssl_certificate=SubResource(id=ssl_cert) if use_ssl else None,
        host_names=host_names
    )

    # (minimum API version, listener attribute, referenced resource id)
    version_gated_refs = (
        ('2019-09-01', 'firewall_policy', firewall_policy),
        ('2020-06-01', 'ssl_profile', ssl_profile),
    )
    for min_api, attr, ref_id in version_gated_refs:
        if cmd.supported_api_version(min_api=min_api):
            setattr(listener, attr, SubResource(id=ref_id) if ref_id else None)

    upsert_to_collection(gateway, 'http_listeners', listener, 'name')
    return sdk_no_wait(no_wait, ncf.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
413
414
def update_ag_http_listener(cmd, instance, parent, item_name, frontend_ip=None, frontend_port=None,
                            host_name=None, ssl_cert=None, ssl_profile=None, firewall_policy=None, host_names=None):
    """Update an HTTP listener in place and return the owning gateway *parent*.

    Only arguments that are not None are applied. An empty *ssl_cert* removes the certificate
    and switches the listener to 'Http'; an empty *host_name*/*host_names* clears that field.
    Afterwards SNI is required only when the listener has a host name and uses HTTPS.
    """
    SubResource = cmd.get_models('SubResource')
    if frontend_ip is not None:
        instance.frontend_ip_configuration = SubResource(id=frontend_ip)
    if frontend_port is not None:
        instance.frontend_port = SubResource(id=frontend_port)
    if ssl_cert is not None:
        instance.ssl_certificate = SubResource(id=ssl_cert) if ssl_cert else None
        instance.protocol = 'Https' if ssl_cert else 'Http'
    if host_name is not None:
        instance.host_name = host_name or None

    if cmd.supported_api_version(min_api='2019-09-01') and firewall_policy is not None:
        instance.firewall_policy = SubResource(id=firewall_policy)

    if cmd.supported_api_version(min_api='2020-06-01') and ssl_profile is not None:
        instance.ssl_profile = SubResource(id=ssl_profile)

    if host_names is not None:
        instance.host_names = host_names or None

    has_host = instance.host_name
    instance.require_server_name_indication = has_host and instance.protocol.lower() == 'https'
    return parent
445
446
def assign_ag_identity(cmd, resource_group_name, application_gateway_name,
                       user_assigned_identity, no_wait=False):
    """Set a single user-assigned managed identity on an application gateway.

    Any existing identity on the gateway is replaced by one of type "UserAssigned" that
    contains only *user_assigned_identity*.
    """
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    ManagedServiceIdentity, UserAssignedIdentityValue = \
        cmd.get_models('ManagedServiceIdentity',
                       'Components1Jq1T4ISchemasManagedserviceidentityPropertiesUserassignedidentitiesAdditionalproperties')  # pylint: disable=line-too-long
    identity_value = UserAssignedIdentityValue()
    gateway.identity = ManagedServiceIdentity(
        type="UserAssigned",
        user_assigned_identities={user_assigned_identity: identity_value}
    )
    return sdk_no_wait(no_wait, gateways.begin_create_or_update, resource_group_name, application_gateway_name,
                       gateway)
467
468
def remove_ag_identity(cmd, resource_group_name, application_gateway_name, no_wait=False):
    """Remove the managed identity from an application gateway.

    If the gateway has no identity, a warning is logged and the fetched gateway is returned
    unchanged, without sending an update request.
    """
    ncf = network_client_factory(cmd.cli_ctx).application_gateways
    ag = ncf.get(resource_group_name, application_gateway_name)
    if ag.identity is None:
        logger.warning("This command will be ignored. The identity doesn't exist.")
        # Honour the warning: there is nothing to remove, so skip the full-gateway PUT.
        return ag
    ag.identity = None

    return sdk_no_wait(no_wait, ncf.begin_create_or_update, resource_group_name, application_gateway_name, ag)
477
478
def show_ag_identity(cmd, resource_group_name, application_gateway_name):
    """Return the managed identity of an application gateway.

    :raises CLIError: if the gateway has no identity assigned.
    """
    ncf = network_client_factory(cmd.cli_ctx).application_gateways
    ag = ncf.get(resource_group_name, application_gateway_name)
    if ag.identity is None:
        # quote characters balanced (the original opened with ' and closed with a backtick)
        raise CLIError("Please first use 'az network application-gateway identity assign' to init the identity.")
    return ag.identity
485
486
def add_ag_private_link(cmd,
                        resource_group_name,
                        application_gateway_name,
                        frontend_ip,
                        private_link_name,
                        private_link_subnet_name_or_id,
                        private_link_subnet_prefix=None,
                        private_link_primary=None,
                        private_link_ip_address=None,
                        no_wait=False):
    """Add a Private Link configuration to an application gateway and bind it to a frontend IP.

    If *private_link_subnet_name_or_id* is a subnet name rather than a resource ID, a subnet
    with that name and *private_link_subnet_prefix* is first created in the gateway's virtual
    network. Network policies for Private Link services are disabled on it. That vnet update
    is awaited before the gateway update is sent.

    :raises CLIError: if the frontend IP doesn't exist or already references a Private Link,
        if the Private Link name is already used, or if the new subnet's name or prefix
        collides with an existing subnet.
    """
    (SubResource, IPAllocationMethod, Subnet,
     ApplicationGatewayPrivateLinkConfiguration,
     ApplicationGatewayPrivateLinkIpConfiguration) = cmd.get_models(
         'SubResource', 'IPAllocationMethod', 'Subnet',
         'ApplicationGatewayPrivateLinkConfiguration', 'ApplicationGatewayPrivateLinkIpConfiguration')

    ncf = network_client_factory(cmd.cli_ctx)

    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)
    private_link_config_id = resource_id(
        subscription=get_subscription_id(cmd.cli_ctx),
        resource_group=resource_group_name,
        namespace='Microsoft.Network',
        type='applicationGateways',
        name=appgw.name,
        child_type_1='privateLinkConfigurations',
        child_name_1=private_link_name
    )

    target_fic = next((fic for fic in appgw.frontend_ip_configurations or [] if fic.name == frontend_ip), None)
    if target_fic is None:
        raise CLIError("Frontend IP doesn't exist")
    # A frontend IP can reference only one Private Link; rebinding it would silently orphan the
    # existing configuration. (Previously this compared against the id of the link being
    # created, which cannot be referenced yet.)
    if target_fic.private_link_configuration:
        raise CLIError('Frontend IP already reference an existing Private Link')

    if appgw.private_link_configurations is None:
        appgw.private_link_configurations = []
    if any(pl.name == private_link_name for pl in appgw.private_link_configurations):
        raise CLIError('Private Link name duplicates')

    # Get the virtual network of this application gateway. It is addressed by its own
    # resource group, which need not be the gateway's.
    vnet_info = parse_resource_id(appgw.gateway_ip_configurations[0].subnet.id)
    vnet_name = vnet_info['name']
    vnet_resource_group = vnet_info['resource_group']
    vnet = ncf.virtual_networks.get(vnet_resource_group, vnet_name)

    if is_valid_resource_id(private_link_subnet_name_or_id):
        private_link_subnet_id = private_link_subnet_name_or_id
    else:
        # Prepare a new subnet for the private link; refuse name/prefix collisions.
        for subnet in vnet.subnets:
            if subnet.name == private_link_subnet_name_or_id:
                raise CLIError('Subnet duplicates')
            # Only compare prefixes when one was given, otherwise None == None on subnets
            # that use address_prefixes would be reported as a false duplicate.
            if private_link_subnet_prefix and (
                    subnet.address_prefix == private_link_subnet_prefix or
                    (subnet.address_prefixes and private_link_subnet_prefix in subnet.address_prefixes)):
                raise CLIError('Subnet prefix duplicates')

        private_link_subnet = Subnet(name=private_link_subnet_name_or_id,
                                     address_prefix=private_link_subnet_prefix,
                                     private_link_service_network_policies='Disabled')
        private_link_subnet_id = resource_id(
            subscription=vnet_info['subscription'],
            resource_group=vnet_resource_group,
            namespace='Microsoft.Network',
            type='virtualNetworks',
            name=vnet_name,
            child_type_1='subnets',
            child_name_1=private_link_subnet_name_or_id
        )
        vnet.subnets.append(private_link_subnet)
        # Wait for the subnet to exist: the gateway update below references it.
        ncf.virtual_networks.begin_create_or_update(vnet_resource_group, vnet_name, vnet).result()

    private_link_ip_allocation_method = IPAllocationMethod.static.value if private_link_ip_address \
        else IPAllocationMethod.dynamic.value
    private_link_ip_config = ApplicationGatewayPrivateLinkIpConfiguration(
        name='PrivateLinkDefaultIPConfiguration',
        private_ip_address=private_link_ip_address,
        private_ip_allocation_method=private_link_ip_allocation_method,
        subnet=SubResource(id=private_link_subnet_id),
        primary=private_link_primary
    )
    private_link_config = ApplicationGatewayPrivateLinkConfiguration(
        name=private_link_name,
        ip_configurations=[private_link_ip_config]
    )

    # associate the private link with the frontend IP configuration
    target_fic.private_link_configuration = SubResource(id=private_link_config_id)

    appgw.private_link_configurations.append(private_link_config)

    return sdk_no_wait(no_wait,
                       ncf.application_gateways.begin_create_or_update,
                       resource_group_name,
                       application_gateway_name, appgw)
587
588
def show_ag_private_link(cmd,
                         resource_group_name,
                         application_gateway_name,
                         private_link_name):
    """Return the Private Link configuration named *private_link_name* on an application gateway.

    :raises CLIError: if the gateway has no Private Link with that name.
    """
    ncf = network_client_factory(cmd.cli_ctx)

    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)

    # the collection may be None on a gateway without any Private Link configuration
    for pl in appgw.private_link_configurations or []:
        if pl.name == private_link_name:
            return pl
    raise CLIError("Private Link doesn't exist")
606
607
def list_ag_private_link(cmd,
                         resource_group_name,
                         application_gateway_name):
    """Return the ``private_link_configurations`` collection of an application gateway as-is."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    return gateway.private_link_configurations
615
616
def remove_ag_private_link(cmd,
                           resource_group_name,
                           application_gateway_name,
                           private_link_name,
                           no_wait=False):
    """Remove a Private Link configuration from an application gateway.

    Every frontend IP configuration that references the removed link is detached first.
    The subnet that was used by the private link is NOT deleted and has to be removed manually.

    :raises CLIError: if the gateway has no Private Link with that name.
    """
    ncf = network_client_factory(cmd.cli_ctx)

    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)

    removed_private_link = next(
        (pl for pl in appgw.private_link_configurations or [] if pl.name == private_link_name), None)
    if removed_private_link is None:
        raise CLIError("Private Link doesn't exist")

    # Azure resource IDs are case-insensitive, so compare them lower-cased.
    removed_id = (removed_private_link.id or '').lower()
    for fic in appgw.frontend_ip_configurations or []:
        ref = fic.private_link_configuration
        if ref and ref.id and ref.id.lower() == removed_id:
            fic.private_link_configuration = None

    appgw.private_link_configurations.remove(removed_private_link)
    return sdk_no_wait(no_wait,
                       ncf.application_gateways.begin_create_or_update,
                       resource_group_name,
                       application_gateway_name,
                       appgw)
650
651
652# region application-gateway trusted-client-certificates
def add_trusted_client_certificate(cmd, resource_group_name, application_gateway_name, client_cert_name,
                                   client_cert_data, no_wait=False):
    """Add a trusted client certificate to an application gateway.

    :param client_cert_name: name of the new certificate entry.
    :param client_cert_data: certificate data stored on the new entry.
    :return: the poller (or no-wait result) of the gateway update.
    """
    ncf = network_client_factory(cmd.cli_ctx)
    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)
    ApplicationGatewayTrustedClientCertificate = cmd.get_models('ApplicationGatewayTrustedClientCertificate')
    cert = ApplicationGatewayTrustedClientCertificate(name=client_cert_name, data=client_cert_data)
    # The collection may be None when the gateway has no trusted client certificates yet (TODO confirm against
    # the SDK model); start an empty list so the first certificate can be appended.
    if appgw.trusted_client_certificates is None:
        appgw.trusted_client_certificates = []
    appgw.trusted_client_certificates.append(cert)

    return sdk_no_wait(no_wait, ncf.application_gateways.begin_create_or_update, resource_group_name,
                       application_gateway_name, appgw)
663
664
def update_trusted_client_certificate(cmd, resource_group_name, application_gateway_name, client_cert_name,
                                      client_cert_data, no_wait=False):
    """Replace the data of an existing trusted client certificate on an application gateway.

    :raises ResourceNotFoundError: if no certificate named ``client_cert_name`` exists.
    :return: the poller (or no-wait result) of the gateway update.
    """
    client = network_client_factory(cmd.cli_ctx)
    gateway = client.application_gateways.get(resource_group_name, application_gateway_name)

    match = next((c for c in gateway.trusted_client_certificates if c.name == client_cert_name), None)
    if match is None:
        raise ResourceNotFoundError(f"Trusted client certificate {client_cert_name} doesn't exist")
    match.data = client_cert_data

    return sdk_no_wait(no_wait, client.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
679
680
def list_trusted_client_certificate(cmd, resource_group_name, application_gateway_name):
    """Return the trusted client certificates configured on an application gateway."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    return gateway.trusted_client_certificates
685
686
def remove_trusted_client_certificate(cmd, resource_group_name, application_gateway_name, client_cert_name,
                                      no_wait=False):
    """Remove a trusted client certificate (by name) from an application gateway.

    :raises ResourceNotFoundError: if no certificate named ``client_cert_name`` exists.
    :return: the poller (or no-wait result) of the gateway update.
    """
    client = network_client_factory(cmd.cli_ctx)
    gateway = client.application_gateways.get(resource_group_name, application_gateway_name)

    certificates = gateway.trusted_client_certificates
    doomed = next((c for c in certificates if c.name == client_cert_name), None)
    if doomed is None:
        raise ResourceNotFoundError(f"Trusted client certificate {client_cert_name} doesn't exist")
    certificates.remove(doomed)

    return sdk_no_wait(no_wait, client.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
701
702
def show_trusted_client_certificate(cmd, resource_group_name, application_gateway_name, client_cert_name):
    """Return the trusted client certificate named ``client_cert_name`` on an application gateway.

    :raises ResourceNotFoundError: if no such certificate exists.
    """
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)

    for certificate in gateway.trusted_client_certificates:
        if certificate.name == client_cert_name:
            return certificate
    raise ResourceNotFoundError(f"Trusted client certificate {client_cert_name} doesn't exist")
716
717
def show_ag_backend_health(cmd, client, resource_group_name, application_gateway_name, expand=None,
                           protocol=None, host=None, path=None, timeout=None, host_name_from_http_settings=None,
                           match_body=None, match_status_codes=None, address_pool=None, http_settings=None):
    """Show the backend health of an application gateway, optionally through an on-demand probe.

    If any probe argument is supplied and the API version is at least 2019-04-01, an
    ``ApplicationGatewayOnDemandProbe`` is built from those arguments and ``begin_backend_health_on_demand`` is
    used. Otherwise the plain ``begin_backend_health`` operation runs. In both cases the long-running operation is
    waited on and its result is returned.

    ``address_pool`` and ``http_settings`` accept either a full resource ID or the bare name of a child resource of
    this gateway.
    """
    from azure.cli.core.commands import LongRunningOperation

    # A tuple + `is not None` test instead of a set literal: a list-valued argument (e.g. match_status_codes,
    # presumably a list of status ranges) is unhashable and would make a set literal raise TypeError.
    on_demand_arguments = (protocol, host, path, timeout, host_name_from_http_settings, match_body,
                           match_status_codes, address_pool, http_settings)
    if any(arg is not None for arg in on_demand_arguments) and cmd.supported_api_version(min_api='2019-04-01'):
        SubResource, ApplicationGatewayOnDemandProbe, ApplicationGatewayProbeHealthResponseMatch = cmd.get_models(
            "SubResource", "ApplicationGatewayOnDemandProbe", "ApplicationGatewayProbeHealthResponseMatch")

        def _gateway_child_id(value, child_type):
            # Expand a bare child name into a full resource ID under this gateway; pass full IDs through.
            if is_valid_resource_id(value):
                return value
            return resource_id(
                subscription=get_subscription_id(cmd.cli_ctx),
                resource_group=resource_group_name,
                namespace='Microsoft.Network',
                type='applicationGateways',
                name=application_gateway_name,
                child_type_1=child_type,
                child_name_1=value
            )

        probe_request = ApplicationGatewayOnDemandProbe(
            protocol=protocol,
            host=host,
            path=path,
            timeout=timeout,
            pick_host_name_from_backend_http_settings=host_name_from_http_settings
        )
        if match_body is not None or match_status_codes is not None:
            probe_request.match = ApplicationGatewayProbeHealthResponseMatch(
                body=match_body,
                status_codes=match_status_codes,
            )
        if address_pool is not None:
            probe_request.backend_address_pool = SubResource(
                id=_gateway_child_id(address_pool, 'backendAddressPools'))
        if http_settings is not None:
            probe_request.backend_http_settings = SubResource(
                id=_gateway_child_id(http_settings, 'backendHttpSettingsCollection'))
        return LongRunningOperation(cmd.cli_ctx)(client.begin_backend_health_on_demand(
            resource_group_name, application_gateway_name, probe_request, expand))

    return LongRunningOperation(cmd.cli_ctx)(client.begin_backend_health(
        resource_group_name, application_gateway_name, expand))
768
769# endregion
770
771
772# region application-gateway ssl-profile
def add_ssl_profile(cmd, resource_group_name, application_gateway_name, ssl_profile_name, policy_name=None,
                    policy_type=None, min_protocol_version=None, cipher_suites=None, disabled_ssl_protocols=None,
                    trusted_client_certificates=None, client_auth_configuration=None, no_wait=False):
    """Add an SSL profile to an application gateway.

    :param trusted_client_certificates: optional iterable of trusted client certificate resource IDs.
    :param client_auth_configuration: 'True'/'False' (a bool is accepted too); when given, sets
        ``verify_client_cert_issuer_dn`` on the profile's client auth configuration.
    :return: the poller (or no-wait result) of the gateway update.
    """
    ncf = network_client_factory(cmd.cli_ctx)
    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)
    (SubResource,
     ApplicationGatewaySslPolicy,
     ApplicationGatewayClientAuthConfiguration,
     ApplicationGatewaySslProfile) = cmd.get_models('SubResource',
                                                    'ApplicationGatewaySslPolicy',
                                                    'ApplicationGatewayClientAuthConfiguration',
                                                    'ApplicationGatewaySslProfile')
    sr_trusted_client_certificates = [SubResource(id=item) for item in trusted_client_certificates] \
        if trusted_client_certificates else None
    ssl_policy = ApplicationGatewaySslPolicy(policy_name=policy_name, policy_type=policy_type,
                                             min_protocol_version=min_protocol_version,
                                             cipher_suites=cipher_suites, disabled_ssl_protocols=disabled_ssl_protocols)
    client_auth = None
    if client_auth_configuration is not None:
        # The value may arrive as the string 'True'/'False'. The raw string must not reach the bool field: any
        # non-empty string is truthy, so 'False' would end up serialized as true. Normalize to a real bool.
        client_auth = ApplicationGatewayClientAuthConfiguration(
            verify_client_cert_issuer_dn=str(client_auth_configuration).lower() == 'true')
    ssl_profile = ApplicationGatewaySslProfile(trusted_client_certificates=sr_trusted_client_certificates,
                                               ssl_policy=ssl_policy, client_auth_configuration=client_auth,
                                               name=ssl_profile_name)
    # The collection may be None on a gateway without SSL profiles; start an empty list in that case.
    if appgw.ssl_profiles is None:
        appgw.ssl_profiles = []
    appgw.ssl_profiles.append(ssl_profile)
    return sdk_no_wait(no_wait, ncf.application_gateways.begin_create_or_update, resource_group_name,
                       application_gateway_name, appgw)
798
799
def update_ssl_profile(cmd, resource_group_name, application_gateway_name, ssl_profile_name, policy_name=None,
                       policy_type=None, min_protocol_version=None, cipher_suites=None, disabled_ssl_protocols=None,
                       trusted_client_certificates=None, client_auth_configuration=None, no_wait=False):
    """Update an existing SSL profile of an application gateway in place.

    Only arguments that are not None are applied; everything else on the profile is left unchanged.

    :param trusted_client_certificates: iterable of trusted client certificate resource IDs; replaces the list.
    :param client_auth_configuration: 'True'/'False' (a bool is accepted too); replaces the profile's client auth
        configuration with one whose ``verify_client_cert_issuer_dn`` is set accordingly.
    :raises ResourceNotFoundError: if no SSL profile named ``ssl_profile_name`` exists.
    :return: the poller (or no-wait result) of the gateway update.
    """
    ncf = network_client_factory(cmd.cli_ctx)
    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)

    # A None collection is treated as empty so the user gets the not-found error instead of a TypeError.
    instance = next((profile for profile in appgw.ssl_profiles or [] if profile.name == ssl_profile_name), None)
    if instance is None:
        raise ResourceNotFoundError(f"Ssl profiles {ssl_profile_name} doesn't exist")

    policy_updates = {
        'policy_name': policy_name,
        'policy_type': policy_type,
        'min_protocol_version': min_protocol_version,
        'cipher_suites': cipher_suites,
        'disabled_ssl_protocols': disabled_ssl_protocols,
    }
    policy_updates = {attr: value for attr, value in policy_updates.items() if value is not None}
    if policy_updates:
        if instance.ssl_policy is None:
            # A profile created outside this command may carry no policy object; create one to hold the values.
            ApplicationGatewaySslPolicy = cmd.get_models('ApplicationGatewaySslPolicy')
            instance.ssl_policy = ApplicationGatewaySslPolicy()
        for attr, value in policy_updates.items():
            setattr(instance.ssl_policy, attr, value)
    if trusted_client_certificates is not None:
        SubResource = cmd.get_models('SubResource')
        instance.trusted_client_certificates = [SubResource(id=item) for item in trusted_client_certificates]
    if client_auth_configuration is not None:
        ApplicationGatewayClientAuthConfiguration = cmd.get_models('ApplicationGatewayClientAuthConfiguration')
        instance.client_auth_configuration = ApplicationGatewayClientAuthConfiguration(
            verify_client_cert_issuer_dn=str(client_auth_configuration).lower() == 'true'
        )

    return sdk_no_wait(no_wait, ncf.application_gateways.begin_create_or_update, resource_group_name,
                       application_gateway_name, appgw)
835
836
def list_ssl_profile(cmd, resource_group_name, application_gateway_name):
    """Return the SSL profiles configured on an application gateway."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    return gateway.ssl_profiles
841
842
def remove_ssl_profile(cmd, resource_group_name, application_gateway_name, ssl_profile_name, no_wait=False):
    """Remove the SSL profile named ``ssl_profile_name`` from an application gateway.

    :raises ResourceNotFoundError: if no such profile exists.
    :return: the poller (or no-wait result) of the gateway update.
    """
    client = network_client_factory(cmd.cli_ctx)
    gateway = client.application_gateways.get(resource_group_name, application_gateway_name)

    profiles = gateway.ssl_profiles
    doomed = next((p for p in profiles if p.name == ssl_profile_name), None)
    if doomed is None:
        raise ResourceNotFoundError(f"Ssl profiles {ssl_profile_name} doesn't exist")
    profiles.remove(doomed)

    return sdk_no_wait(no_wait, client.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
856
857
def show_ssl_profile(cmd, resource_group_name, application_gateway_name, ssl_profile_name):
    """Return the SSL profile named ``ssl_profile_name`` of an application gateway.

    :raises ResourceNotFoundError: if no such profile exists.
    """
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)

    for profile in gateway.ssl_profiles:
        if profile.name == ssl_profile_name:
            return profile
    raise ResourceNotFoundError(f"Ssl profiles {ssl_profile_name} doesn't exist")
870
871# endregion
872
873
def add_ag_private_link_ip(cmd,
                           resource_group_name,
                           application_gateway_name,
                           private_link_name,
                           private_link_ip_name,
                           private_link_primary=False,
                           private_link_ip_address=None,
                           no_wait=False):
    """Add an IP configuration to an existing private link configuration of an application gateway.

    The new IP configuration is placed in the subnet of the private link's first IP configuration. Its allocation
    method is Static when ``private_link_ip_address`` is given and Dynamic otherwise.

    :raises CLIError: if the gateway has no private link configuration named ``private_link_name``.
    :return: the poller (or no-wait result) of the gateway update.
    """
    ncf = network_client_factory(cmd.cli_ctx)

    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)

    # A None collection is treated as empty so the user gets the "doesn't exist" error instead of a TypeError.
    target_private_link = next(
        (pl for pl in appgw.private_link_configurations or [] if pl.name == private_link_name), None)
    if target_private_link is None:
        raise CLIError("Private Link doesn't exist")

    (SubResource, IPAllocationMethod,
     ApplicationGatewayPrivateLinkIpConfiguration) = \
        cmd.get_models('SubResource', 'IPAllocationMethod',
                       'ApplicationGatewayPrivateLinkIpConfiguration')

    # Reuse the subnet of the private link's first IP configuration for the new one.
    private_link_subnet_id = target_private_link.ip_configurations[0].subnet.id

    private_link_ip_allocation_method = IPAllocationMethod.static.value if private_link_ip_address \
        else IPAllocationMethod.dynamic.value
    private_link_ip_config = ApplicationGatewayPrivateLinkIpConfiguration(
        name=private_link_ip_name,
        private_ip_address=private_link_ip_address,
        private_ip_allocation_method=private_link_ip_allocation_method,
        subnet=SubResource(id=private_link_subnet_id),
        primary=private_link_primary
    )

    target_private_link.ip_configurations.append(private_link_ip_config)

    return sdk_no_wait(no_wait,
                       ncf.application_gateways.begin_create_or_update,
                       resource_group_name,
                       application_gateway_name,
                       appgw)
918
919
def show_ag_private_link_ip(cmd,
                            resource_group_name,
                            application_gateway_name,
                            private_link_name,
                            private_link_ip_name):
    """Return one IP configuration of a private link configuration on an application gateway.

    :raises CLIError: if the private link configuration, or the IP configuration within it, does not exist.
    """
    ncf = network_client_factory(cmd.cli_ctx)

    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)

    # None collections are treated as empty so the user gets the "doesn't exist" errors instead of a TypeError.
    target_private_link = next(
        (pl for pl in appgw.private_link_configurations or [] if pl.name == private_link_name), None)
    if target_private_link is None:
        raise CLIError("Private Link doesn't exist")

    target_private_link_ip_config = next(
        (pic for pic in target_private_link.ip_configurations or [] if pic.name == private_link_ip_name), None)
    if target_private_link_ip_config is None:
        raise CLIError("IP Configuration doesn't exist")

    return target_private_link_ip_config
946
947
def list_ag_private_link_ip(cmd,
                            resource_group_name,
                            application_gateway_name,
                            private_link_name):
    """Return the IP configurations of a private link configuration on an application gateway.

    :raises CLIError: if the gateway has no private link configuration named ``private_link_name``.
    """
    ncf = network_client_factory(cmd.cli_ctx)

    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)

    # A None collection is treated as empty so the user gets the "doesn't exist" error instead of a TypeError.
    target_private_link = next(
        (pl for pl in appgw.private_link_configurations or [] if pl.name == private_link_name), None)
    if target_private_link is None:
        raise CLIError("Private Link doesn't exist")

    return target_private_link.ip_configurations
965
966
def remove_ag_private_link_ip(cmd,
                              resource_group_name,
                              application_gateway_name,
                              private_link_name,
                              private_link_ip_name,
                              no_wait=False):
    """Remove one IP configuration from a private link configuration of an application gateway.

    :raises CLIError: if the private link configuration, or the IP configuration within it, does not exist.
    :return: the poller (or no-wait result) of the gateway update.
    """
    ncf = network_client_factory(cmd.cli_ctx)

    appgw = ncf.application_gateways.get(resource_group_name, application_gateway_name)

    # None collections are treated as empty so the user gets the "doesn't exist" errors instead of a TypeError.
    target_private_link = next(
        (pl for pl in appgw.private_link_configurations or [] if pl.name == private_link_name), None)
    if target_private_link is None:
        raise CLIError("Private Link doesn't exist")

    removed_ip_config = next(
        (pic for pic in target_private_link.ip_configurations or [] if pic.name == private_link_ip_name), None)
    if removed_ip_config is None:
        raise CLIError("IP Configuration doesn't exist")
    target_private_link.ip_configurations.remove(removed_ip_config)

    return sdk_no_wait(no_wait,
                       ncf.application_gateways.begin_create_or_update,
                       resource_group_name,
                       application_gateway_name,
                       appgw)
998
999
def create_ag_backend_http_settings_collection(cmd, resource_group_name, application_gateway_name, item_name, port,
                                               probe=None, protocol='http', cookie_based_affinity=None, timeout=None,
                                               no_wait=False, connection_draining_timeout=0,
                                               host_name=None, host_name_from_backend_pool=None,
                                               affinity_cookie_name=None, enable_probe=None, path=None,
                                               auth_certs=None, root_certs=None):
    """Create (or replace, by name) a backend HTTP settings entry on an application gateway.

    Properties added in later API versions are only set when the active profile supports them:
    authentication certificates (2016-09-01), connection draining (2016-12-01), host-name / affinity-cookie /
    probe / path options (2017-06-01) and trusted root certificates (2019-04-01).

    :return: the poller (or no-wait result) of the gateway update.
    """
    HttpSettings, ConnectionDraining, SubResource = cmd.get_models(
        'ApplicationGatewayBackendHttpSettings', 'ApplicationGatewayConnectionDraining', 'SubResource')
    client = network_client_factory(cmd.cli_ctx)
    gateway = client.application_gateways.get(resource_group_name, application_gateway_name)

    def _refs(ids):
        # Turn an optional iterable of resource IDs into SubResource references (None -> empty list).
        return [SubResource(id=ref_id) for ref_id in ids or []]

    settings = HttpSettings(
        name=item_name,
        port=port,
        protocol=protocol,
        cookie_based_affinity=cookie_based_affinity or 'Disabled',
        request_timeout=timeout,
        probe=SubResource(id=probe) if probe else None)

    if cmd.supported_api_version(min_api='2016-09-01'):
        settings.authentication_certificates = _refs(auth_certs)
    if cmd.supported_api_version(min_api='2016-12-01'):
        # Draining is enabled only for a non-zero timeout; the drain time falls back to 1 second otherwise.
        settings.connection_draining = ConnectionDraining(
            enabled=bool(connection_draining_timeout),
            drain_timeout_in_sec=connection_draining_timeout or 1)
    if cmd.supported_api_version(min_api='2017-06-01'):
        settings.host_name = host_name
        settings.pick_host_name_from_backend_address = host_name_from_backend_pool
        settings.affinity_cookie_name = affinity_cookie_name
        settings.probe_enabled = enable_probe
        settings.path = path
    if cmd.supported_api_version(min_api='2019-04-01'):
        settings.trusted_root_certificates = _refs(root_certs)

    upsert_to_collection(gateway, 'backend_http_settings_collection', settings, 'name')
    return sdk_no_wait(no_wait, client.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1034
1035
def update_ag_backend_http_settings_collection(cmd, instance, parent, item_name, port=None, probe=None, protocol=None,
                                               cookie_based_affinity=None, timeout=None,
                                               connection_draining_timeout=None,
                                               host_name=None, host_name_from_backend_pool=None,
                                               affinity_cookie_name=None, enable_probe=None, path=None,
                                               auth_certs=None, root_certs=None):
    """Apply the supplied settings to an existing backend HTTP settings entry (``instance``).

    Arguments left as None are not touched. For ``auth_certs`` / ``root_certs`` an empty string clears the
    reference list; any other value is treated as an iterable of certificate resource IDs.

    :return: ``parent`` unchanged.
    """
    SubResource = cmd.get_models('SubResource')

    def _refs(ids):
        # "" means "clear the list"; anything else is a collection of resource IDs.
        return None if ids == "" else [SubResource(id=cert_id) for cert_id in ids]

    if auth_certs is not None:
        instance.authentication_certificates = _refs(auth_certs)
    if root_certs is not None:
        instance.trusted_root_certificates = _refs(root_certs)
    if probe is not None:
        instance.probe = SubResource(id=probe)
    if connection_draining_timeout is not None:
        instance.connection_draining = {
            'enabled': bool(connection_draining_timeout),
            'drain_timeout_in_sec': connection_draining_timeout or 1
        }

    plain_updates = (
        ('port', port),
        ('protocol', protocol),
        ('cookie_based_affinity', cookie_based_affinity),
        ('request_timeout', timeout),
        ('host_name', host_name),
        ('pick_host_name_from_backend_address', host_name_from_backend_pool),
        ('affinity_cookie_name', affinity_cookie_name),
        ('probe_enabled', enable_probe),
        ('path', path),
    )
    for attr, value in plain_updates:
        if value is not None:
            setattr(instance, attr, value)
    return parent
1077
1078
def create_ag_redirect_configuration(cmd, resource_group_name, application_gateway_name, item_name, redirect_type,
                                     target_listener=None, target_url=None, include_path=None,
                                     include_query_string=None, no_wait=False):
    """Create (or replace, by name) a redirect configuration on an application gateway.

    :param target_listener: optional listener resource ID to redirect to (wrapped in a SubResource).
    :param target_url: optional URL to redirect to.
    :return: the poller (or no-wait result) of the gateway update.
    """
    RedirectConfiguration, SubResource = cmd.get_models(
        'ApplicationGatewayRedirectConfiguration', 'SubResource')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)

    listener_ref = SubResource(id=target_listener) if target_listener else None
    config = RedirectConfiguration(name=item_name, redirect_type=redirect_type, target_listener=listener_ref,
                                   target_url=target_url, include_path=include_path,
                                   include_query_string=include_query_string)
    upsert_to_collection(gateway, 'redirect_configurations', config, 'name')
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1095
1096
def update_ag_redirect_configuration(cmd, instance, parent, item_name, redirect_type=None,
                                     target_listener=None, target_url=None, include_path=None,
                                     include_query_string=None, raw=False):
    """Apply the supplied changes to an existing redirect configuration (``instance``).

    ``redirect_type``, ``target_listener`` and ``target_url`` are applied only when truthy. The include flags are
    applied when not None.

    :return: ``parent`` unchanged.
    """
    SubResource = cmd.get_models('SubResource')
    if redirect_type:
        instance.redirect_type = redirect_type
    # A redirect targets either a listener or a URL, so setting one clears the other.
    # If both are given, the URL is applied last and therefore wins.
    if target_listener:
        instance.target_listener, instance.target_url = SubResource(id=target_listener), None
    if target_url:
        instance.target_listener, instance.target_url = None, target_url
    for attr, value in (('include_path', include_path), ('include_query_string', include_query_string)):
        if value is not None:
            setattr(instance, attr, value)
    return parent
1114
1115
def create_ag_rewrite_rule_set(cmd, resource_group_name, application_gateway_name, item_name, no_wait=False):
    """Create (or replace, by name) an empty rewrite rule set on an application gateway.

    :return: the no-wait poller when ``no_wait`` is set; otherwise the rule set as it appears in the updated
        gateway after the operation completes.
    """
    RewriteRuleSet = cmd.get_models('ApplicationGatewayRewriteRuleSet')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    upsert_to_collection(gateway, 'rewrite_rule_sets', RewriteRuleSet(name=item_name), 'name')

    poller = sdk_no_wait(no_wait, gateways.begin_create_or_update,
                         resource_group_name, application_gateway_name, gateway)
    if no_wait:
        return poller
    updated_gateway = poller.result()
    return find_child_item(updated_gateway, item_name, path='rewrite_rule_sets', key_path='name')
1129
1130
def update_ag_rewrite_rule_set(instance, parent, item_name):
    """Generic-update hook for a rewrite rule set.

    Nothing on ``instance`` is modified; ``parent`` is returned unchanged.
    """
    return parent
1133
1134
def create_ag_rewrite_rule(cmd, resource_group_name, application_gateway_name, rule_set_name, rule_name,
                           sequence=None, request_headers=None, response_headers=None, no_wait=False,
                           modified_path=None, modified_query_string=None, enable_reroute=None):
    """Create (or replace, by name) a rewrite rule inside an existing rewrite rule set.

    A URL configuration is attached only when at least one of ``modified_path``, ``modified_query_string`` or
    ``enable_reroute`` is truthy.

    :return: the no-wait poller when ``no_wait`` is set; otherwise the rule as it appears in the updated gateway.
    """
    RewriteRule, RuleActionSet, UrlConfiguration = cmd.get_models(
        'ApplicationGatewayRewriteRule', 'ApplicationGatewayRewriteRuleActionSet', 'ApplicationGatewayUrlConfiguration')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    rule_set = find_child_item(gateway, rule_set_name, path='rewrite_rule_sets', key_path='name')

    url_configuration = None
    if modified_path or modified_query_string or enable_reroute:
        url_configuration = UrlConfiguration(modified_path=modified_path,
                                             modified_query_string=modified_query_string,
                                             reroute=enable_reroute)
    action_set = RuleActionSet(request_header_configurations=request_headers,
                               response_header_configurations=response_headers,
                               url_configuration=url_configuration)
    rule = RewriteRule(name=rule_name, rule_sequence=sequence, action_set=action_set)
    upsert_to_collection(rule_set, 'rewrite_rules', rule, 'name')

    poller = sdk_no_wait(no_wait, gateways.begin_create_or_update,
                         resource_group_name, application_gateway_name, gateway)
    if no_wait:
        return poller
    updated_gateway = poller.result()
    return find_child_item(updated_gateway, rule_set_name, rule_name,
                           path='rewrite_rule_sets.rewrite_rules', key_path='name.name')
1168
1169
def update_ag_rewrite_rule(instance, parent, cmd, rule_set_name, rule_name, sequence=None,
                           request_headers=None, response_headers=None,
                           modified_path=None, modified_query_string=None, enable_reroute=None):
    """Apply the supplied changes to an existing rewrite rule (``instance``).

    Arguments left as None are not applied. The URL options are merged into the rule's existing URL
    configuration. Passing only some of ``modified_path`` / ``modified_query_string`` / ``enable_reroute`` keeps
    the other fields, and ``enable_reroute=False`` on its own does turn rerouting off.

    :return: ``parent`` unchanged.
    """
    with cmd.update_context(instance) as c:
        c.set_param('rule_sequence', sequence)
        c.set_param('action_set.request_header_configurations', request_headers)
        c.set_param('action_set.response_header_configurations', response_headers)
        url_configuration = None
        # Check `is not None` rather than truthiness so that an explicit False (e.g. enable_reroute=False) counts.
        if any(arg is not None for arg in (modified_path, modified_query_string, enable_reroute)):
            ApplicationGatewayUrlConfiguration = cmd.get_models('ApplicationGatewayUrlConfiguration')
            current = getattr(getattr(instance, 'action_set', None), 'url_configuration', None)
            url_configuration = ApplicationGatewayUrlConfiguration(
                modified_path=(modified_path if modified_path is not None
                               else getattr(current, 'modified_path', None)),
                modified_query_string=(modified_query_string if modified_query_string is not None
                                       else getattr(current, 'modified_query_string', None)),
                reroute=enable_reroute if enable_reroute is not None else getattr(current, 'reroute', None))
        c.set_param('action_set.url_configuration', url_configuration)
    return parent
1185
1186
def show_ag_rewrite_rule(cmd, resource_group_name, application_gateway_name, rule_set_name, rule_name):
    """Return the rewrite rule ``rule_name`` from rule set ``rule_set_name`` of an application gateway."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    return find_child_item(gateway, rule_set_name, rule_name,
                           path='rewrite_rule_sets.rewrite_rules', key_path='name.name')
1192
1193
def list_ag_rewrite_rules(cmd, resource_group_name, application_gateway_name, rule_set_name):
    """Return the rewrite rules of rule set ``rule_set_name`` on an application gateway."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    return find_child_collection(gateway, rule_set_name,
                                 path='rewrite_rule_sets.rewrite_rules', key_path='name')
1198
1199
def delete_ag_rewrite_rule(cmd, resource_group_name, application_gateway_name, rule_set_name, rule_name, no_wait=None):
    """Delete a rewrite rule from a rewrite rule set of an application gateway.

    Unless ``no_wait`` is set, block until the gateway update completes so that failures are reported.
    Returns None in both cases, so the command prints nothing.
    """
    client = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = client.get(resource_group_name, application_gateway_name)
    rule_set = find_child_item(gateway, rule_set_name, path='rewrite_rule_sets', key_path='name')
    rule = find_child_item(rule_set, rule_name, path='rewrite_rules', key_path='name')
    rule_set.rewrite_rules.remove(rule)
    poller = sdk_no_wait(no_wait, client.begin_create_or_update,
                         resource_group_name, application_gateway_name, gateway)
    # The poller is not returned (a delete produces no output), so nothing else would wait on it; wait here so
    # the command honours no_wait=False and surfaces long-running-operation failures.
    if not no_wait:
        poller.result()
1207
1208
def create_ag_rewrite_rule_condition(cmd, resource_group_name, application_gateway_name, rule_set_name, rule_name,
                                     variable, no_wait=False, pattern=None, ignore_case=None, negate=None):
    """Add a condition to a rewrite rule; it is upserted into the rule's conditions keyed by ``variable``.

    :return: the no-wait poller when ``no_wait`` is set; otherwise the condition as it appears in the updated
        gateway.
    """
    RuleCondition = cmd.get_models('ApplicationGatewayRewriteRuleCondition')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    rule = find_child_item(gateway, rule_set_name, rule_name,
                           path='rewrite_rule_sets.rewrite_rules', key_path='name.name')

    condition = RuleCondition(variable=variable, pattern=pattern, ignore_case=ignore_case, negate=negate)
    upsert_to_collection(rule, 'conditions', condition, 'variable')

    poller = sdk_no_wait(no_wait, gateways.begin_create_or_update,
                         resource_group_name, application_gateway_name, gateway)
    if no_wait:
        return poller
    updated_gateway = poller.result()
    return find_child_item(updated_gateway, rule_set_name, rule_name, variable,
                           path='rewrite_rule_sets.rewrite_rules.conditions', key_path='name.name.variable')
1230
1231
def update_ag_rewrite_rule_condition(instance, parent, cmd, rule_set_name, rule_name, variable, pattern=None,
                                     ignore_case=None, negate=None):
    """Apply ``pattern``, ``ignore_case`` and ``negate`` to an existing rewrite rule condition.

    Each value is handed to the update context's ``set_param``.

    :return: ``parent`` unchanged.
    """
    changes = (('pattern', pattern), ('ignore_case', ignore_case), ('negate', negate))
    with cmd.update_context(instance) as ctx:
        for prop, value in changes:
            ctx.set_param(prop, value)
    return parent
1239
1240
def show_ag_rewrite_rule_condition(cmd, resource_group_name, application_gateway_name, rule_set_name,
                                   rule_name, variable):
    """Return the condition on ``variable`` of rewrite rule ``rule_name`` in rule set ``rule_set_name``."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    lookup_path = 'rewrite_rule_sets.rewrite_rules.conditions'
    return find_child_item(gateway, rule_set_name, rule_name, variable,
                           path=lookup_path, key_path='name.name.variable')
1247
1248
def list_ag_rewrite_rule_conditions(cmd, resource_group_name, application_gateway_name, rule_set_name, rule_name):
    """Return every condition attached to a rewrite rule of an application gateway."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    return find_child_collection(gateways.get(resource_group_name, application_gateway_name),
                                 rule_set_name, rule_name,
                                 path='rewrite_rule_sets.rewrite_rules.conditions', key_path='name.name')
1254
1255
def delete_ag_rewrite_rule_condition(cmd, resource_group_name, application_gateway_name, rule_set_name,
                                     rule_name, variable, no_wait=None):
    """Remove the condition keyed by ``variable`` from a rewrite rule and submit the updated gateway.

    Returns None; the ``sdk_no_wait`` result of the update is not returned.
    """
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    owning_rule = find_child_item(gateway, rule_set_name, rule_name,
                                  path='rewrite_rule_sets.rewrite_rules', key_path='name.name')
    doomed = find_child_item(owning_rule, variable, path='conditions', key_path='variable')
    owning_rule.conditions.remove(doomed)
    sdk_no_wait(no_wait, gateways.begin_create_or_update,
                resource_group_name, application_gateway_name, gateway)
1265
1266
def create_ag_probe(cmd, resource_group_name, application_gateway_name, item_name, protocol, host,
                    path, interval=30, timeout=120, threshold=8, no_wait=False, host_name_from_http_settings=None,
                    min_servers=None, match_body=None, match_status_codes=None, port=None):
    """Add (or replace by name) a health probe on an application gateway.

    Fields introduced by later API versions are only populated when the active
    profile supports them: host-name pick-up, minimum servers and match criteria
    from 2017-06-01, and the probe port from 2019-04-01.

    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    probe_model, match_model = cmd.get_models(
        'ApplicationGatewayProbe', 'ApplicationGatewayProbeHealthResponseMatch')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)

    probe = probe_model(name=item_name, protocol=protocol, host=host, path=path,
                        interval=interval, timeout=timeout, unhealthy_threshold=threshold)
    if cmd.supported_api_version(min_api='2017-06-01'):
        probe.pick_host_name_from_backend_http_settings = host_name_from_http_settings
        probe.min_servers = min_servers
        probe.match = match_model(body=match_body, status_codes=match_status_codes)
    if cmd.supported_api_version(min_api='2019-04-01'):
        probe.port = port

    upsert_to_collection(gateway, 'probes', probe, 'name')
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1292
1293
def update_ag_probe(cmd, instance, parent, item_name, protocol=None, host=None, path=None,
                    interval=None, timeout=None, threshold=None, host_name_from_http_settings=None,
                    min_servers=None, match_body=None, match_status_codes=None, port=None):
    """Apply every supplied (non-None) setting to an existing probe ``instance``.

    ``threshold`` maps onto ``unhealthy_threshold`` and ``host_name_from_http_settings``
    onto ``pick_host_name_from_backend_http_settings``. Match criteria are created on
    demand when the probe has none and a match option is given.

    Returns ``parent`` so the caller can persist the whole gateway.
    """
    plain_fields = (
        ('protocol', protocol),
        ('host', host),
        ('path', path),
        ('interval', interval),
        ('timeout', timeout),
        ('unhealthy_threshold', threshold),
        ('pick_host_name_from_backend_http_settings', host_name_from_http_settings),
        ('min_servers', min_servers),
        ('port', port),
    )
    for field_name, field_value in plain_fields:
        if field_value is not None:
            setattr(instance, field_name, field_value)

    if match_body is None and match_status_codes is None:
        return parent

    match_model = cmd.get_models('ApplicationGatewayProbeHealthResponseMatch')
    if not instance.match:
        instance.match = match_model()
    if match_body is not None:
        instance.match.body = match_body
    if match_status_codes is not None:
        instance.match.status_codes = match_status_codes
    return parent
1324
1325
def create_ag_request_routing_rule(cmd, resource_group_name, application_gateway_name, item_name,
                                   address_pool=None, http_settings=None, http_listener=None, redirect_config=None,
                                   url_path_map=None, rule_type='Basic', no_wait=False, rewrite_rule_set=None,
                                   priority=None):
    """Add (or replace by name) a request routing rule on an application gateway.

    When no redirect configuration is given, a missing backend address pool or
    HTTP settings reference falls back to ``_get_default_id`` for the gateway.
    A missing listener always falls back to ``_get_default_id``.

    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    ApplicationGatewayRequestRoutingRule, SubResource = cmd.get_models(
        'ApplicationGatewayRequestRoutingRule', 'SubResource')
    ncf = network_client_factory(cmd.cli_ctx)
    ag = ncf.application_gateways.get(resource_group_name, application_gateway_name)
    if not address_pool and not redirect_config:
        address_pool = _get_default_id(ag, 'backend_address_pools', '--address-pool')
    if not http_settings and not redirect_config:
        http_settings = _get_default_id(ag, 'backend_http_settings_collection', '--http-settings')
    if not http_listener:
        http_listener = _get_default_id(ag, 'http_listeners', '--http-listener')
    new_rule = ApplicationGatewayRequestRoutingRule(
        name=item_name,
        rule_type=rule_type,
        priority=priority,
        backend_address_pool=SubResource(id=address_pool) if address_pool else None,
        backend_http_settings=SubResource(id=http_settings) if http_settings else None,
        http_listener=SubResource(id=http_listener),
        url_path_map=SubResource(id=url_path_map) if url_path_map else None)
    if cmd.supported_api_version(min_api='2017-06-01'):
        new_rule.redirect_configuration = SubResource(id=redirect_config) if redirect_config else None

    # Gate on the argument's own name. Finding the name by object identity in locals()
    # is unreliable: when rewrite_rule_set is None it matches whichever unset argument
    # comes first, because None is a singleton shared by all of them.
    if cmd.supported_api_version(parameter_name='rewrite_rule_set'):
        new_rule.rewrite_rule_set = SubResource(id=rewrite_rule_set) if rewrite_rule_set else None
    upsert_to_collection(ag, 'request_routing_rules', new_rule, 'name')
    return sdk_no_wait(no_wait, ncf.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, ag)
1357
1358
def update_ag_request_routing_rule(cmd, instance, parent, item_name, address_pool=None,
                                   http_settings=None, http_listener=None, redirect_config=None, url_path_map=None,
                                   rule_type=None, rewrite_rule_set=None, priority=None):
    """Apply the supplied (non-None) settings to an existing request routing rule.

    Resource-ID arguments are wrapped in ``SubResource`` references. ``priority`` goes
    through ``cmd.update_context``.

    Returns ``parent`` so the caller can persist the whole gateway.
    """
    SubResource = cmd.get_models('SubResource')
    reference_fields = {
        'backend_address_pool': address_pool,
        'backend_http_settings': http_settings,
        'redirect_configuration': redirect_config,
        'http_listener': http_listener,
        'url_path_map': url_path_map,
        'rewrite_rule_set': rewrite_rule_set,
    }
    for attr_name, ref_id in reference_fields.items():
        if ref_id is not None:
            setattr(instance, attr_name, SubResource(id=ref_id))
    if rule_type is not None:
        instance.rule_type = rule_type
    with cmd.update_context(instance) as ctx:
        ctx.set_param('priority', priority)
    return parent
1380
1381
def create_ag_ssl_certificate(cmd, resource_group_name, application_gateway_name, item_name, cert_data=None,
                              cert_password=None, key_vault_secret_id=None, no_wait=False):
    """Add (or replace by name) an SSL certificate on an application gateway.

    The certificate model receives the inline data/password and the Key Vault
    secret id exactly as passed. Returns the ``sdk_no_wait`` result of the update.
    """
    cert_model = cmd.get_models('ApplicationGatewaySslCertificate')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    certificate = cert_model(name=item_name, data=cert_data, password=cert_password,
                             key_vault_secret_id=key_vault_secret_id)
    upsert_to_collection(gateway, 'ssl_certificates', certificate, 'name')
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1392
1393
def update_ag_ssl_certificate(instance, parent, item_name,
                              cert_data=None, cert_password=None, key_vault_secret_id=None):
    """Overwrite the supplied (non-None) fields of an existing SSL certificate ``instance``.

    Returns ``parent`` so the caller can persist the whole gateway.
    """
    changes = (('data', cert_data), ('password', cert_password), ('key_vault_secret_id', key_vault_secret_id))
    for attr_name, new_value in changes:
        if new_value is not None:
            setattr(instance, attr_name, new_value)
    return parent
1403
1404
def set_ag_ssl_policy_2017_03_01(cmd, resource_group_name, application_gateway_name, disabled_ssl_protocols=None,
                                 clear=False, no_wait=False):
    """Replace (or, with ``clear``, remove) the SSL policy of an application gateway.

    This is the 2017-03-01 API shape, where the policy only lists disabled protocols.
    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    policy_model = cmd.get_models('ApplicationGatewaySslPolicy')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    if clear:
        gateway.ssl_policy = None
    else:
        gateway.ssl_policy = policy_model(disabled_ssl_protocols=disabled_ssl_protocols)
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1413
1414
def set_ag_ssl_policy_2017_06_01(cmd, resource_group_name, application_gateway_name, policy_name=None, policy_type=None,
                                 disabled_ssl_protocols=None, cipher_suites=None, min_protocol_version=None,
                                 no_wait=False):
    """Replace the SSL policy of an application gateway (2017-06-01 API shape).

    An explicitly passed ``policy_type`` is honoured. When it is omitted, the type is
    inferred:
      * ``policy_name`` given -> predefined policy;
      * otherwise ``cipher_suites`` or ``min_protocol_version`` given -> custom policy;
      * otherwise the type is left as None.

    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    ApplicationGatewaySslPolicy, ApplicationGatewaySslPolicyType = cmd.get_models(
        'ApplicationGatewaySslPolicy', 'ApplicationGatewaySslPolicyType')
    ncf = network_client_factory(cmd.cli_ctx).application_gateways
    ag = ncf.get(resource_group_name, application_gateway_name)
    # Only infer the type when the caller did not supply one. Unconditionally resetting
    # it to None would silently discard a user-provided --policy-type.
    if policy_type is None:
        if policy_name:
            policy_type = ApplicationGatewaySslPolicyType.predefined.value
        elif cipher_suites or min_protocol_version:
            policy_type = ApplicationGatewaySslPolicyType.custom.value
    ag.ssl_policy = ApplicationGatewaySslPolicy(
        policy_name=policy_name,
        policy_type=policy_type,
        disabled_ssl_protocols=disabled_ssl_protocols,
        cipher_suites=cipher_suites,
        min_protocol_version=min_protocol_version)
    return sdk_no_wait(no_wait, ncf.begin_create_or_update, resource_group_name, application_gateway_name, ag)
1434
1435
def show_ag_ssl_policy(cmd, resource_group_name, application_gateway_name):
    """Return the ``ssl_policy`` of the given application gateway."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    return gateway.ssl_policy
1439
1440
def create_ag_trusted_root_certificate(cmd, resource_group_name, application_gateway_name, item_name, no_wait=False,
                                       cert_data=None, keyvault_secret=None):
    """Add (or replace by name) a trusted root certificate on an application gateway.

    ``keyvault_secret`` is stored as the certificate's ``key_vault_secret_id``.
    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    root_cert_model = cmd.get_models('ApplicationGatewayTrustedRootCertificate')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    upsert_to_collection(gateway, 'trusted_root_certificates',
                         root_cert_model(name=item_name, data=cert_data, key_vault_secret_id=keyvault_secret),
                         'name')
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1451
1452
def update_ag_trusted_root_certificate(instance, parent, item_name, cert_data=None, keyvault_secret=None):
    """Overwrite the supplied (non-None) fields of an existing trusted root certificate.

    Returns ``parent`` so the caller can persist the whole gateway.
    """
    for attr_name, new_value in (('data', cert_data), ('key_vault_secret_id', keyvault_secret)):
        if new_value is not None:
            setattr(instance, attr_name, new_value)
    return parent
1459
1460
def create_ag_url_path_map(cmd, resource_group_name, application_gateway_name, item_name, paths,
                           address_pool=None, http_settings=None, redirect_config=None, rewrite_rule_set=None,
                           default_address_pool=None, default_http_settings=None, default_redirect_config=None,
                           no_wait=False, rule_name='default', default_rewrite_rule_set=None, firewall_policy=None):
    """Add (or replace by name) a URL path map containing a single path rule ``rule_name``.

    When a ``default_*`` option is not given, the map's default backend pool, HTTP
    settings and redirect configuration are copied from the path rule.
    Version-dependent fields are only set when the active API profile supports them.

    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    ApplicationGatewayUrlPathMap, ApplicationGatewayPathRule, SubResource = cmd.get_models(
        'ApplicationGatewayUrlPathMap', 'ApplicationGatewayPathRule', 'SubResource')
    ncf = network_client_factory(cmd.cli_ctx)
    ag = ncf.application_gateways.get(resource_group_name, application_gateway_name)

    new_rule = ApplicationGatewayPathRule(
        name=rule_name,
        backend_address_pool=SubResource(id=address_pool) if address_pool else None,
        backend_http_settings=SubResource(id=http_settings) if http_settings else None,
        paths=paths
    )
    new_map = ApplicationGatewayUrlPathMap(
        name=item_name,
        default_backend_address_pool=SubResource(id=default_address_pool) if default_address_pool else None,
        default_backend_http_settings=SubResource(id=default_http_settings) if default_http_settings else None,
        path_rules=[])
    if cmd.supported_api_version(min_api='2017-06-01'):
        new_rule.redirect_configuration = SubResource(id=redirect_config) if redirect_config else None
        new_map.default_redirect_configuration = \
            SubResource(id=default_redirect_config) if default_redirect_config else None

    # Gate on the argument's own name. Finding the name by object identity in locals()
    # is unreliable: when rewrite_rule_set is None it matches whichever unset argument
    # comes first, because None is a singleton shared by all of them.
    if cmd.supported_api_version(parameter_name='rewrite_rule_set'):
        new_rule.rewrite_rule_set = SubResource(id=rewrite_rule_set) if rewrite_rule_set else None
        new_map.default_rewrite_rule_set = \
            SubResource(id=default_rewrite_rule_set) if default_rewrite_rule_set else None

    if cmd.supported_api_version(min_api='2019-09-01'):
        new_rule.firewall_policy = SubResource(id=firewall_policy) if firewall_policy else None

    # pull defaults from the rule specific properties if the default-* option isn't specified
    if new_rule.backend_address_pool and not new_map.default_backend_address_pool:
        new_map.default_backend_address_pool = new_rule.backend_address_pool

    if new_rule.backend_http_settings and not new_map.default_backend_http_settings:
        new_map.default_backend_http_settings = new_rule.backend_http_settings

    # The redirect attributes are only assigned above for API >= 2017-06-01, and older
    # models may not define them at all, so read them defensively.
    rule_redirect = getattr(new_rule, 'redirect_configuration', None)
    if rule_redirect and not getattr(new_map, 'default_redirect_configuration', None):
        new_map.default_redirect_configuration = rule_redirect

    new_map.path_rules.append(new_rule)
    upsert_to_collection(ag, 'url_path_maps', new_map, 'name')
    return sdk_no_wait(no_wait, ncf.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, ag)
1509
1510
def update_ag_url_path_map(cmd, instance, parent, item_name, default_address_pool=None,
                           default_http_settings=None, default_redirect_config=None, raw=False,
                           default_rewrite_rule_set=None):
    """Update the default references of an existing URL path map ``instance``.

    For each ``default_*`` argument:
      * ``''`` clears the reference (sets it to None);
      * any other truthy value becomes a ``SubResource`` reference;
      * None (or another falsy value) leaves the reference untouched.

    ``raw`` is accepted but not used here. Returns ``parent``.
    """
    SubResource = cmd.get_models('SubResource')
    defaults = (
        ('default_backend_address_pool', default_address_pool),
        ('default_backend_http_settings', default_http_settings),
        ('default_redirect_configuration', default_redirect_config),
        ('default_rewrite_rule_set', default_rewrite_rule_set),
    )
    for attr_name, ref_id in defaults:
        if ref_id == '':
            setattr(instance, attr_name, None)
        elif ref_id:
            setattr(instance, attr_name, SubResource(id=ref_id))
    return parent
1535
1536
def create_ag_url_path_map_rule(cmd, resource_group_name, application_gateway_name, url_path_map_name,
                                item_name, paths, address_pool=None, http_settings=None, redirect_config=None,
                                firewall_policy=None, no_wait=False, rewrite_rule_set=None):
    """Add (or replace by name) a path rule inside an existing URL path map.

    An unspecified backend pool, HTTP settings or redirect configuration falls back to
    the path map's default. There are two exceptions: the default backend pool is not
    inherited when ``redirect_config`` is given, and the default redirect is not
    inherited when ``address_pool`` is given.

    Raises:
        CLIError: if both ``address_pool`` and ``redirect_config`` are given, or if the
            URL path map does not exist.

    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    ApplicationGatewayPathRule, SubResource = cmd.get_models('ApplicationGatewayPathRule', 'SubResource')
    if address_pool and redirect_config:
        raise CLIError("Cannot reference a BackendAddressPool when Redirect Configuration is specified.")
    ncf = network_client_factory(cmd.cli_ctx)
    ag = ncf.application_gateways.get(resource_group_name, application_gateway_name)
    url_map = next((x for x in ag.url_path_maps if x.name == url_path_map_name), None)
    if not url_map:
        raise CLIError('URL path map "{}" not found.'.format(url_path_map_name))
    default_backend_pool = SubResource(id=url_map.default_backend_address_pool.id) \
        if (url_map.default_backend_address_pool and not redirect_config) else None
    default_http_settings = SubResource(id=url_map.default_backend_http_settings.id) \
        if url_map.default_backend_http_settings else None
    new_rule = ApplicationGatewayPathRule(
        name=item_name,
        paths=paths,
        backend_address_pool=SubResource(id=address_pool) if address_pool else default_backend_pool,
        backend_http_settings=SubResource(id=http_settings) if http_settings else default_http_settings)
    if cmd.supported_api_version(min_api='2017-06-01'):
        default_redirect = SubResource(id=url_map.default_redirect_configuration.id) \
            if (url_map.default_redirect_configuration and not address_pool) else None
        new_rule.redirect_configuration = SubResource(id=redirect_config) if redirect_config else default_redirect

    # Gate on the argument's own name. Finding the name by object identity in locals()
    # is unreliable: when rewrite_rule_set is None it matches whichever unset argument
    # comes first, because None is a singleton shared by all of them.
    if cmd.supported_api_version(parameter_name='rewrite_rule_set'):
        new_rule.rewrite_rule_set = SubResource(id=rewrite_rule_set) if rewrite_rule_set else None

    if cmd.supported_api_version(min_api='2019-09-01'):
        new_rule.firewall_policy = SubResource(id=firewall_policy) if firewall_policy else None

    upsert_to_collection(url_map, 'path_rules', new_rule, 'name')
    return sdk_no_wait(no_wait, ncf.application_gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, ag)
1572
1573
def delete_ag_url_path_map_rule(cmd, resource_group_name, application_gateway_name, url_path_map_name,
                                item_name, no_wait=False):
    """Remove the path rule(s) named ``item_name`` (case-insensitive) from a URL path map.

    Raises:
        CLIError: if the URL path map does not exist.

    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    url_map = next((candidate for candidate in gateway.url_path_maps if candidate.name == url_path_map_name), None)
    if not url_map:
        raise CLIError('URL path map "{}" not found.'.format(url_path_map_name))
    url_map.path_rules = [path_rule for path_rule in url_map.path_rules
                          if path_rule.name.lower() != item_name.lower()]
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1585
1586
def set_ag_waf_config_2016_09_01(cmd, resource_group_name, application_gateway_name, enabled,
                                 firewall_mode=None,
                                 no_wait=False):
    """Replace the WAF configuration of an application gateway (2016-09-01 API shape).

    ``enabled`` is compared with the string ``'true'``; any other value disables the WAF.
    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    waf_config_model = cmd.get_models('ApplicationGatewayWebApplicationFirewallConfiguration')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    is_enabled = enabled == 'true'
    gateway.web_application_firewall_configuration = waf_config_model(enabled=is_enabled,
                                                                      firewall_mode=firewall_mode)
    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1599
1600
def set_ag_waf_config_2017_03_01(cmd, resource_group_name, application_gateway_name, enabled,
                                 firewall_mode=None,
                                 rule_set_type='OWASP', rule_set_version=None,
                                 disabled_rule_groups=None,
                                 disabled_rules=None, no_wait=False,
                                 request_body_check=None, max_request_body_size=None, file_upload_limit=None,
                                 exclusions=None):
    """Replace the WAF configuration of an application gateway (2017-03-01+ API shape).

    ``enabled`` is compared with the string ``'true'``. Whole rule groups named in
    ``disabled_rule_groups`` are disabled by name. Individual ``disabled_rules`` are
    resolved to their owning groups through ``list_ag_waf_rule_sets`` for the chosen
    rule set type and version. Body-check, size and exclusion settings are only applied
    on API 2018-08-01 and later.

    Returns the ``sdk_no_wait`` result of the gateway update.
    """
    waf_config_model = cmd.get_models('ApplicationGatewayWebApplicationFirewallConfiguration')
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    waf_config = waf_config_model(
        enabled=(enabled == 'true'), firewall_mode=firewall_mode, rule_set_type=rule_set_type,
        rule_set_version=rule_set_version)
    gateway.web_application_firewall_configuration = waf_config

    if disabled_rule_groups or disabled_rules:
        disabled_group_model = cmd.get_models('ApplicationGatewayFirewallDisabledRuleGroup')
        # Whole groups are disabled by name alone.
        disabled = [disabled_group_model(rule_group_name=group_name)
                    for group_name in disabled_rule_groups or []]
        if disabled_rules:
            # Individual rules need their owning group, so look them up in the available rule sets.
            available = list_ag_waf_rule_sets(gateways, _type=rule_set_type, version=rule_set_version, group='*')
            all_groups = (rule_group for rule_set in available for rule_group in rule_set.rule_groups)
            for rule_group in all_groups:
                matched_ids = [rule.rule_id for rule in rule_group.rules if str(rule.rule_id) in disabled_rules]
                if matched_ids:
                    disabled.append(disabled_group_model(rule_group_name=rule_group.rule_group_name,
                                                         rules=matched_ids))
        waf_config.disabled_rule_groups = disabled

    if cmd.supported_api_version(min_api='2018-08-01'):
        waf_config.request_body_check = request_body_check
        waf_config.max_request_body_size_in_kb = max_request_body_size
        waf_config.file_upload_limit_in_mb = file_upload_limit
        waf_config.exclusions = exclusions

    return sdk_no_wait(no_wait, gateways.begin_create_or_update,
                       resource_group_name, application_gateway_name, gateway)
1651
1652
def show_ag_waf_config(cmd, resource_group_name, application_gateway_name):
    """Return the ``web_application_firewall_configuration`` of the given application gateway."""
    gateways = network_client_factory(cmd.cli_ctx).application_gateways
    gateway = gateways.get(resource_group_name, application_gateway_name)
    return gateway.web_application_firewall_configuration
1656
1657
def list_ag_waf_rule_sets(client, _type=None, version=None, group=None):
    """List the available WAF rule sets, optionally filtered.

    Args:
        client: object exposing ``list_available_waf_rule_sets()`` whose ``.value`` is the rule sets.
        _type: case-insensitive rule set type filter.
        version: case-insensitive rule set version filter.
        group: case-insensitive rule group name filter; ``'*'`` keeps every group with its
            rules. When omitted, every group is kept but its ``rules`` are set to None.

    Rule sets left with no groups are dropped. The returned rule set objects are the
    client's own objects, with ``rule_groups`` (and possibly ``rules``) mutated in place.
    """
    def _wanted(rule_set):
        if _type and _type.lower() != rule_set.rule_set_type.lower():
            return False
        return not (version and version.lower() != rule_set.rule_set_version.lower())

    def _pick_groups(rule_set):
        if not group:
            # Summary mode: keep every group but strip the individual rules.
            for rule_group in rule_set.rule_groups:
                rule_group.rules = None
            return [rule_group for rule_group in rule_set.rule_groups]
        return [rule_group for rule_group in rule_set.rule_groups
                if group.lower() == rule_group.rule_group_name.lower() or group == '*']

    selected = []
    for rule_set in client.list_available_waf_rule_sets().value:
        if not _wanted(rule_set):
            continue
        kept_groups = _pick_groups(rule_set)
        if kept_groups:
            rule_set.rule_groups = kept_groups
            selected.append(rule_set)
    return selected
1681
1682
1683# endregion
1684
1685
1686# region ApplicationGatewayWAFPolicy
def create_ag_waf_policy(cmd, client, resource_group_name, policy_name,
                         location=None, tags=None, rule_set_type='OWASP',
                         rule_set_version='3.0'):
    """Create (or overwrite) a WAF policy seeded with exactly one managed rule set.

    See https://docs.microsoft.com/en-us/azure/application-gateway/waf-overview

    Returns whatever ``client.create_or_update`` returns.
    """
    policy_cls, definition_cls, rule_set_cls = cmd.get_models(
        'WebApplicationFirewallPolicy', 'ManagedRulesDefinition', 'ManagedRuleSet')
    # Mandatory default managed rule set, created without any group overrides.
    seed_rule_set = rule_set_cls(rule_set_type=rule_set_type, rule_set_version=rule_set_version)
    policy = policy_cls(location=location, tags=tags,
                        managed_rules=definition_cls(managed_rule_sets=[seed_rule_set]))
    return client.create_or_update(resource_group_name, policy_name, policy)
1701
1702
def update_ag_waf_policy(cmd, instance, tags=None):
    """Apply ``tags`` to a WAF policy ``instance`` via ``cmd.update_context`` and return the instance."""
    updater = cmd.update_context(instance)
    with updater as ctx:
        ctx.set_param('tags', tags)
    return instance
1707
1708
def list_ag_waf_policies(cmd, resource_group_name=None):
    """List WAF policies, scoped to ``resource_group_name`` when one is given (via ``_generic_list``)."""
    operation_group = 'web_application_firewall_policies'
    return _generic_list(cmd.cli_ctx, operation_group, resource_group_name)
1711# endregion
1712
1713
1714# region ApplicationGatewayWAFPolicyRules PolicySettings
def update_waf_policy_setting(cmd, instance,
                              state=None, mode=None,
                              max_request_body_size_in_kb=None, file_upload_limit_in_mb=None,
                              request_body_check=None):
    """Update the ``policy_settings`` of a WAF policy ``instance`` in place.

    Only arguments that are not None are applied, so omitted options keep their current
    values. ``request_body_check`` defaults to None rather than False. With a False
    default and an ``is not None`` guard, every update that omitted the flag (for
    example, changing only ``mode``) silently disabled request body checking. Pass
    False explicitly to turn it off.

    Returns the mutated ``instance``.
    """
    requested = (
        ('state', state),
        ('mode', mode),
        ('max_request_body_size_in_kb', max_request_body_size_in_kb),
        ('file_upload_limit_in_mb', file_upload_limit_in_mb),
        ('request_body_check', request_body_check),
    )
    for setting_name, setting_value in requested:
        if setting_value is not None:
            setattr(instance.policy_settings, setting_name, setting_value)
    return instance
1735
1736
def list_waf_policy_setting(cmd, client, resource_group_name, policy_name):
    """Return the ``policy_settings`` of the named WAF policy."""
    policy = client.get(resource_group_name, policy_name)
    return policy.policy_settings
1739# endregion
1740
1741
1742# region ApplicationGatewayWAFPolicyRules
def create_waf_custom_rule(cmd, client, resource_group_name, policy_name, rule_name, priority, rule_type, action):
    """
    Initialize custom rule for WAF policy

    The rule starts with no match conditions and is upserted by name into the policy's
    ``custom_rules``. The policy is saved, and the rule is returned as found on the
    object that ``client.create_or_update`` returns.
    """
    custom_rule_model = cmd.get_models('WebApplicationFirewallCustomRule')
    policy = client.get(resource_group_name, policy_name)
    fresh_rule = custom_rule_model(name=rule_name, action=action, match_conditions=[],
                                   priority=priority, rule_type=rule_type)
    upsert_to_collection(policy, 'custom_rules', fresh_rule, 'name')
    saved_policy = client.create_or_update(resource_group_name, policy_name, policy)
    return find_child_item(saved_policy, rule_name, path='custom_rules', key_path='name')
1759
1760
1761# pylint: disable=unused-argument
def update_waf_custom_rule(instance, parent, cmd, rule_name, priority=None, rule_type=None, action=None):
    """Apply ``priority``, ``rule_type`` and ``action`` to a custom rule via ``cmd.update_context``.

    ``rule_name`` is not used here; presumably it identifies ``instance`` for the generic
    update machinery (verify against the command registration). Returns ``parent``.
    """
    requested = (('priority', priority), ('rule_type', rule_type), ('action', action))
    with cmd.update_context(instance) as ctx:
        for field_name, field_value in requested:
            ctx.set_param(field_name, field_value)
    return parent
1768
1769
def show_waf_custom_rule(cmd, client, resource_group_name, policy_name, rule_name):
    """Return the custom rule named ``rule_name`` from the given WAF policy."""
    return find_child_item(client.get(resource_group_name, policy_name), rule_name,
                           path='custom_rules', key_path='name')
1773
1774
def list_waf_custom_rules(cmd, client, resource_group_name, policy_name):
    """Return the ``custom_rules`` collection of the named WAF policy."""
    policy = client.get(resource_group_name, policy_name)
    return policy.custom_rules
1777
1778
def delete_waf_custom_rule(cmd, client, resource_group_name, policy_name, rule_name, no_wait=None):
    """Remove the custom rule named ``rule_name`` from a WAF policy and save the policy.

    Returns None; the ``sdk_no_wait`` result of the save is not returned.
    """
    policy = client.get(resource_group_name, policy_name)
    doomed_rule = find_child_item(policy, rule_name, path='custom_rules', key_path='name')
    policy.custom_rules.remove(doomed_rule)
    sdk_no_wait(no_wait, client.create_or_update, resource_group_name, policy_name, policy)
1784# endregion
1785
1786
1787# region ApplicationGatewayWAFPolicyRuleMatchConditions
def add_waf_custom_rule_match_cond(cmd, client, resource_group_name, policy_name, rule_name,
                                   match_variables, operator, match_values, negation_condition=None, transforms=None):
    """Append a match condition to a custom rule of a WAF policy and save the policy.

    Returns the locally built condition object, not the copy returned by the service.
    """
    condition_model = cmd.get_models('MatchCondition')
    policy = client.get(resource_group_name, policy_name)
    target_rule = find_child_item(policy, rule_name, path='custom_rules', key_path='name')
    # NOTE(review): the keyword 'negation_conditon' (sic) appears to match the field name
    # used by the SDK's MatchCondition model; confirm before "correcting" the spelling.
    condition = condition_model(match_variables=match_variables, operator=operator,
                                match_values=match_values, negation_conditon=negation_condition,
                                transforms=transforms)
    target_rule.match_conditions.append(condition)
    upsert_to_collection(policy, 'custom_rules', target_rule, 'name', warn=False)
    client.create_or_update(resource_group_name, policy_name, policy)
    return condition
1804
1805
def list_waf_custom_rule_match_cond(cmd, client, resource_group_name, policy_name, rule_name):
    """Return the match conditions of the custom rule named ``rule_name``."""
    policy = client.get(resource_group_name, policy_name)
    custom_rule = find_child_item(policy, rule_name, path='custom_rules', key_path='name')
    return custom_rule.match_conditions
1809
1810
def remove_waf_custom_rule_match_cond(cmd, client, resource_group_name, policy_name, rule_name, index):
    """Delete the match condition at position ``index`` from a custom rule and save the policy.

    An out-of-range ``index`` raises IndexError before anything is saved.
    """
    policy = client.get(resource_group_name, policy_name)
    target_rule = find_child_item(policy, rule_name, path='custom_rules', key_path='name')
    del target_rule.match_conditions[index]
    client.create_or_update(resource_group_name, policy_name, policy)
1816# endregion
1817
1818
1819# region ApplicationGatewayWAFPolicy ManagedRule ManagedRuleSet
def add_waf_managed_rule_set(cmd, client, resource_group_name, policy_name,
                             rule_set_type, rule_set_version,
                             rule_group_name=None, rules=None):
    """
    Add managed rule set to the WAF policy managed rules.
    Visit: https://docs.microsoft.com/en-us/azure/web-application-firewall/ag/application-gateway-crs-rulegroups-rules

    Merge behaviour:
      * no rule set with this type and version yet -> a new rule set is appended;
      * rule set exists and already has ``rule_group_name`` -> ``rules`` are appended to that group;
      * rule set exists without that group -> the group is added (only when a group name is given).
    The policy is then saved and the result of ``client.create_or_update`` is returned.
    """
    ManagedRuleSet, ManagedRuleGroupOverride, ManagedRuleOverride = \
        cmd.get_models('ManagedRuleSet', 'ManagedRuleGroupOverride', 'ManagedRuleOverride')

    waf_policy = client.get(resource_group_name, policy_name)

    overrides = [] if rules is None else [ManagedRuleOverride(rule_id=rule_id) for rule_id in rules]

    group_override = None
    if rule_group_name is not None:
        group_override = ManagedRuleGroupOverride(rule_group_name=rule_group_name, rules=overrides)
    fresh_rule_set = ManagedRuleSet(rule_set_type=rule_set_type,
                                    rule_set_version=rule_set_version,
                                    rule_group_overrides=[] if group_override is None else [group_override])

    rule_sets = waf_policy.managed_rules.managed_rule_sets
    existing_set = next((candidate for candidate in rule_sets
                         if candidate.rule_set_type == rule_set_type
                         and candidate.rule_set_version == rule_set_version), None)

    if existing_set is None:
        # Add new rule set
        rule_sets.append(fresh_rule_set)
    else:
        existing_group = next((group for group in existing_set.rule_group_overrides
                               if group.rule_group_name == rule_group_name), None)
        if existing_group is not None:
            # Add rules to an existing rule group
            existing_group.rules.extend(overrides)
        elif group_override is not None:
            # Add one rule group
            existing_set.rule_group_overrides.append(group_override)

    return client.create_or_update(resource_group_name, policy_name, waf_policy)
1859
1860
def update_waf_managed_rule_set(cmd, instance, rule_set_type, rule_set_version, rule_group_name=None, rules=None):
    """
    Update(Override) existing rule set of a WAF policy managed rules.

    Mutates and returns ``instance`` (the WAF policy):

    * The first rule set with the same type but a different version stops the scan and is
      replaced by a fresh rule set of the requested type/version (carrying the requested group
      override, if any).
    * For a rule set with the same type and version:
        - without ``rule_group_name`` it is replaced by a fresh rule set;
        - with ``rule_group_name`` that group's rule overrides are replaced by ``rules``
          (``None`` when no rules are given); the group is created if it does not exist yet.

    Unlike ``add_waf_managed_rule_set``, a group's existing rule overrides are replaced, not extended.
    """
    ManagedRuleSet, ManagedRuleGroupOverride, ManagedRuleOverride = \
        cmd.get_models('ManagedRuleSet', 'ManagedRuleGroupOverride', 'ManagedRuleOverride')

    managed_rule_overrides = [ManagedRuleOverride(rule_id=r) for r in rules] if rules else None

    rule_group_override = ManagedRuleGroupOverride(rule_group_name=rule_group_name,
                                                   rules=managed_rule_overrides) if managed_rule_overrides else None

    new_managed_rule_set = ManagedRuleSet(rule_set_type=rule_set_type,
                                          rule_set_version=rule_set_version,
                                          rule_group_overrides=[rule_group_override] if rule_group_override is not None else [])  # pylint: disable=line-too-long

    updated_rule_set = None

    for rule_set in instance.managed_rules.managed_rule_sets:
        if rule_set.rule_set_type == rule_set_type and rule_set.rule_set_version != rule_set_version:
            updated_rule_set = rule_set
            break

        if rule_set.rule_set_type == rule_set_type and rule_set.rule_set_version == rule_set_version:
            if rule_group_name is None:
                updated_rule_set = rule_set
                break

            rg = next((rg for rg in rule_set.rule_group_overrides if rg.rule_group_name == rule_group_name), None)
            if rg:
                rg.rules = managed_rule_overrides   # differentiate with add_waf_managed_rule_set()
            else:
                # Never append ``None``: without rules there is no prebuilt override, so create the
                # group with ``rules=None``, mirroring what the existing-group branch above stores.
                rule_set.rule_group_overrides.append(
                    rule_group_override if rule_group_override is not None
                    else ManagedRuleGroupOverride(rule_group_name=rule_group_name, rules=managed_rule_overrides))

    if updated_rule_set:
        instance.managed_rules.managed_rule_sets.remove(updated_rule_set)
        instance.managed_rules.managed_rule_sets.append(new_managed_rule_set)

    return instance
1900
1901
def remove_waf_managed_rule_set(cmd, client, resource_group_name, policy_name,
                                rule_set_type, rule_set_version, rule_group_name=None):
    """
    Remove a managed rule set, or a single rule group override inside it, from a WAF policy.

    A rule set is identified by BOTH ``rule_set_type`` and ``rule_set_version``. When
    ``rule_group_name`` is given, only that rule group override is removed from the matching
    rule set; otherwise the whole matching rule set is removed. The policy is then saved and the
    result of ``client.create_or_update`` is returned (also when nothing matched).

    :raises CLIError: if ``rule_group_name`` is given but the matching rule set has no such group.
    """
    waf_policy = client.get(resource_group_name, policy_name)

    delete_rule_set = None

    for rule_set in waf_policy.managed_rules.managed_rule_sets:
        # Type AND version must both match; matching on either one alone would hit unrelated
        # rule sets (e.g. a different rule set type that happens to share a version number).
        if rule_set.rule_set_type != rule_set_type or rule_set.rule_set_version != rule_set_version:
            continue

        if rule_group_name is None:
            delete_rule_set = rule_set
            break

        # Remove one rule group from the rule set
        rg = next((rg for rg in rule_set.rule_group_overrides if rg.rule_group_name == rule_group_name), None)
        if rg is None:
            raise CLIError('Rule set group [ {} ] not found.'.format(rule_group_name))
        rule_set.rule_group_overrides.remove(rg)

    if delete_rule_set:
        waf_policy.managed_rules.managed_rule_sets.remove(delete_rule_set)

    return client.create_or_update(resource_group_name, policy_name, waf_policy)
1927
1928
def list_waf_managed_rule_set(cmd, client, resource_group_name, policy_name):
    """Return the ``managed_rules`` section of a WAF policy (its rule sets and exclusions)."""
    return client.get(resource_group_name, policy_name).managed_rules
1932# endregion
1933
1934
1935# region ApplicationGatewayWAFPolicy ManagedRule OwaspCrsExclusionEntry
def add_waf_managed_rule_exclusion(cmd, client, resource_group_name, policy_name,
                                   match_variable, selector_match_operator, selector):
    """
    Append one OWASP CRS exclusion entry to a WAF policy's managed rules and save the policy.

    Returns the result of ``client.create_or_update``.
    """
    entry_model = cmd.get_models('OwaspCrsExclusionEntry')
    new_entry = entry_model(match_variable=match_variable,
                            selector_match_operator=selector_match_operator,
                            selector=selector)

    policy = client.get(resource_group_name, policy_name)
    policy.managed_rules.exclusions.append(new_entry)
    return client.create_or_update(resource_group_name, policy_name, policy)
1949
1950
def remove_waf_managed_rule_exclusion(cmd, client, resource_group_name, policy_name):
    """Clear every managed-rule exclusion of a WAF policy and save the policy."""
    policy = client.get(resource_group_name, policy_name)
    managed = policy.managed_rules
    managed.exclusions = []
    return client.create_or_update(resource_group_name, policy_name, policy)
1955
1956
def list_waf_managed_rule_exclusion(cmd, client, resource_group_name, policy_name):
    """Return the ``managed_rules`` section of a WAF policy, which holds its exclusions."""
    policy = client.get(resource_group_name, policy_name)
    managed = policy.managed_rules
    return managed
1960# endregion
1961
1962
1963# region ApplicationSecurityGroups
def create_asg(cmd, client, resource_group_name, application_security_group_name, location=None, tags=None):
    """Create or update an application security group; returns ``begin_create_or_update``'s result."""
    asg_model = cmd.get_models('ApplicationSecurityGroup')
    return client.begin_create_or_update(resource_group_name,
                                         application_security_group_name,
                                         asg_model(location=location, tags=tags))
1968
1969
def update_asg(instance, tags=None):
    """Apply tag changes to an existing application security group object and return it."""
    if tags is None:
        return instance
    instance.tags = tags
    return instance
1974# endregion
1975
1976
1977# region DdosProtectionPlans
def create_ddos_plan(cmd, resource_group_name, ddos_plan_name, location=None, tags=None, vnets=None):
    """
    Create a DDoS protection plan, optionally attaching virtual networks to it.

    Without ``vnets`` the poller from ``begin_create_or_update`` is returned as-is. With ``vnets``
    the creation is waited on (the plan id is needed), each VNet is fetched, pointed at the plan
    and re-submitted, and the freshly read plan is returned.
    """
    from azure.cli.core.commands import LongRunningOperation

    plans = network_client_factory(cmd.cli_ctx).ddos_protection_plans
    plan = cmd.get_models('DdosProtectionPlan')()
    if location:
        plan.location = location
    if tags:
        plan.tags = tags

    if not vnets:
        # Nothing to attach: a single PUT is enough.
        return plans.begin_create_or_update(resource_group_name, ddos_plan_name, parameters=plan)

    # The plan must exist before VNets can reference it, so block until creation finishes.
    creation = plans.begin_create_or_update(resource_group_name, ddos_plan_name, parameters=plan)
    plan_id = LongRunningOperation(cmd.cli_ctx)(creation).id

    SubResource = cmd.get_models('SubResource')
    logger.info('Attempting to attach VNets to newly created DDoS protection plan.')
    networks = network_client_factory(cmd.cli_ctx).virtual_networks
    for vnet_ref in vnets:
        parts = parse_resource_id(vnet_ref.id)
        vnet_rg, vnet_name = parts['resource_group'], parts['name']
        vnet = networks.get(vnet_rg, vnet_name)
        vnet.ddos_protection_plan = SubResource(id=plan_id)
        # NOTE(review): the VNet update poller is not awaited here.
        networks.begin_create_or_update(vnet_rg, vnet_name, vnet)
    return plans.get(resource_group_name, ddos_plan_name)
2004
2005
def update_ddos_plan(cmd, instance, tags=None, vnets=None):
    """
    Apply tag and VNet-attachment changes to a DDoS protection plan object and return it.

    When ``vnets`` is given, attachments are reconciled against the plan's current
    ``virtual_networks``: listed VNets not yet on the plan are attached, and VNets on the plan
    that are not listed are detached. A list holding one empty value means "detach all".
    """
    SubResource = cmd.get_models('SubResource')

    if tags is not None:
        instance.tags = tags
    if vnets is None:
        return instance

    logger.info('Attempting to update the VNets attached to the DDoS protection plan.')
    detach_all = len(vnets) == 1 and not vnets[0]
    wanted = set() if detach_all else {ref.id for ref in vnets}
    current = {ref.id for ref in instance.virtual_networks} if instance.virtual_networks else set()
    networks = network_client_factory(cmd.cli_ctx).virtual_networks

    for vnet_id in wanted - current:
        logger.info("Adding VNet '%s' to plan.", vnet_id)
        parts = parse_resource_id(vnet_id)
        vnet = networks.get(parts['resource_group'], parts['name'])
        vnet.ddos_protection_plan = SubResource(id=instance.id)
        networks.begin_create_or_update(parts['resource_group'], parts['name'], vnet)

    for vnet_id in current - wanted:
        logger.info("Removing VNet '%s' from plan.", vnet_id)
        parts = parse_resource_id(vnet_id)
        vnet = networks.get(parts['resource_group'], parts['name'])
        vnet.ddos_protection_plan = None
        networks.begin_create_or_update(parts['resource_group'], parts['name'], vnet)

    return instance
2033
2034
def list_ddos_plans(cmd, resource_group_name=None):
    """List DDoS protection plans of a resource group, or all plans from ``list()`` when none is given."""
    plans = network_client_factory(cmd.cli_ctx).ddos_protection_plans
    if not resource_group_name:
        return plans.list()
    return plans.list_by_resource_group(resource_group_name)
2040# endregion
2041
2042
2043# region DNS Commands
# Add delegation name-server (NS) records for the created child zone in its parent zone.
def add_dns_delegation(cmd, child_zone, parent_zone, child_rg, child_zone_name):
    """
    Add NS delegation records for a newly created child zone to its parent zone.

    One NS record per name server of ``child_zone`` is added to the parent zone under the
    child's relative name (child 'sub.example.com' in parent 'example.com' -> record set 'sub').
    Nothing is done unless the child zone name is a proper subdomain of the parent zone name.
    Service errors are logged and reported on stderr rather than raised.

     :param child_zone: the zone object corresponding to the child that is created.
     :param parent_zone: the parent zone name, or the resource ID of the parent zone.
                         if parent zone name is mentioned, assume current subscription and resource group.
     :param child_rg: resource group of the child zone
     :param child_zone_name: name of the child zone
    """
    import sys
    from azure.core.exceptions import HttpResponseError
    parent_rg = child_rg
    parent_subscription_id = None
    parent_zone_name = parent_zone

    if is_valid_resource_id(parent_zone):
        id_parts = parse_resource_id(parent_zone)
        parent_rg = id_parts['resource_group']
        parent_subscription_id = id_parts['subscription']
        parent_zone_name = id_parts['name']

    if not all([parent_zone_name, parent_rg, child_zone_name, child_zone]):
        return

    # Require a label boundary: 'fooexample.com' is not a child of 'example.com', and a zone is
    # not its own child. Strip only the trailing parent suffix (str.replace would remove every
    # occurrence of it inside the child name).
    parent_suffix = '.' + parent_zone_name
    if len(child_zone_name) <= len(parent_suffix) or not child_zone_name.endswith(parent_suffix):
        return
    record_set_name = child_zone_name[:-len(parent_suffix)]

    try:
        for dname in child_zone.name_servers:
            add_dns_ns_record(cmd, parent_rg, parent_zone_name, record_set_name, dname, parent_subscription_id)
        print('Delegation added successfully in \'{}\'\n'.format(parent_zone_name), file=sys.stderr)
    except HttpResponseError as ex:
        logger.error(ex)
        print('Could not add delegation in \'{}\'\n'.format(parent_zone_name), file=sys.stderr)
2074
2075
def create_dns_zone(cmd, client, resource_group_name, zone_name, parent_zone_name=None, tags=None,
                    if_none_match=False, zone_type='Public', resolution_vnets=None, registration_vnets=None):
    """
    Create (or update) a DNS zone and optionally delegate it from a parent zone.

    ``if_none_match`` is sent as the '*' If-None-Match value. When ``parent_zone_name`` is given
    and the API version is at least 2016-04-01, ``add_dns_delegation`` adds NS records for the
    new zone to the parent. Returns the zone returned by ``client.create_or_update``.
    """
    zone_model = cmd.get_models('Zone', resource_type=ResourceType.MGMT_NETWORK_DNS)
    zone = zone_model(location='global', tags=tags)

    # Only populate the zone-type / VNet fields when this model version exposes them.
    if hasattr(zone, 'zone_type'):
        zone.zone_type = zone_type
        zone.registration_virtual_networks = registration_vnets
        zone.resolution_virtual_networks = resolution_vnets

    etag_condition = '*' if if_none_match else None
    created_zone = client.create_or_update(resource_group_name, zone_name, zone, if_none_match=etag_condition)

    should_delegate = cmd.supported_api_version(min_api='2016-04-01') and parent_zone_name is not None
    if should_delegate:
        logger.info('Attempting to add delegation in the parent zone')
        add_dns_delegation(cmd, created_zone, parent_zone_name, resource_group_name, zone_name)
    return created_zone
2093
2094
def update_dns_zone(instance, tags=None, zone_type=None, resolution_vnets=None, registration_vnets=None):
    """
    Apply tag, zone-type and VNet changes to a DNS zone object and return it.

    For each VNet list, ``['']`` clears the field (sets it to ``None``), a non-empty list replaces
    it, and ``None``/an empty list leaves it unchanged.
    """
    def _apply_vnets(values, attribute):
        if values == ['']:
            setattr(instance, attribute, None)
        elif values:
            setattr(instance, attribute, values)

    if tags is not None:
        instance.tags = tags
    if zone_type:
        instance.zone_type = zone_type
    _apply_vnets(resolution_vnets, 'resolution_virtual_networks')
    _apply_vnets(registration_vnets, 'registration_virtual_networks')
    return instance
2113
2114
def list_dns_zones(cmd, resource_group_name=None):
    """List DNS zones of a resource group, or all zones from ``list()`` when no group is given."""
    zones = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_NETWORK_DNS).zones
    return zones.list_by_resource_group(resource_group_name) if resource_group_name else zones.list()
2120
2121
def create_dns_record_set(cmd, resource_group_name, zone_name, record_set_name, record_set_type,
                          metadata=None, if_match=None, if_none_match=None, ttl=3600, target_resource=None):
    """
    Create (or overwrite) a DNS record set that carries no records yet.

    ``target_resource``, when given, is wrapped in a ``SubResource``; a truthy ``if_none_match``
    is sent as the '*' If-None-Match value. Returns the result of ``create_or_update``.
    """
    RecordSet = cmd.get_models('RecordSet', resource_type=ResourceType.MGMT_NETWORK_DNS)
    SubResource = cmd.get_models('SubResource', resource_type=ResourceType.MGMT_NETWORK)
    record_sets_client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_NETWORK_DNS).record_sets

    target = SubResource(id=target_resource) if target_resource else None
    new_record_set = RecordSet(ttl=ttl, metadata=metadata, target_resource=target)
    etag_condition = '*' if if_none_match else None
    return record_sets_client.create_or_update(resource_group_name, zone_name, record_set_name,
                                               record_set_type, new_record_set, if_match=if_match,
                                               if_none_match=etag_condition)
2136
2137
def list_dns_record_set(client, resource_group_name, zone_name, record_type=None):
    """List record sets of a zone, filtered to one record type when ``record_type`` is given."""
    if not record_type:
        return client.list_by_dns_zone(resource_group_name, zone_name)
    return client.list_by_type(resource_group_name, zone_name, record_type)
2143
2144
def update_dns_record_set(instance, cmd, metadata=None, target_resource=None):
    """
    Apply metadata / target-resource changes to a DNS record set object and return it.

    An empty-string ``target_resource`` clears the target; any other non-None value becomes
    the new target ``SubResource``.
    """
    if metadata is not None:
        instance.metadata = metadata
    if target_resource is None:
        return instance
    if target_resource == '':
        instance.target_resource = None
    else:
        instance.target_resource = cmd.get_models('SubResource')(id=target_resource)
    return instance
2154
2155
2156def _type_to_property_name(key):
2157    type_dict = {
2158        'a': 'a_records',
2159        'aaaa': 'aaaa_records',
2160        'caa': 'caa_records',
2161        'cname': 'cname_record',
2162        'mx': 'mx_records',
2163        'ns': 'ns_records',
2164        'ptr': 'ptr_records',
2165        'soa': 'soa_record',
2166        'spf': 'txt_records',
2167        'srv': 'srv_records',
2168        'txt': 'txt_records',
2169    }
2170    return type_dict[key.lower()]
2171
2172
def export_zone(cmd, resource_group_name, zone_name, file_name=None):
    """
    Export all record sets of a DNS zone in zone-file format.

    The zone file text built by ``make_zone_file`` is always printed to stdout and, when
    ``file_name`` is given, also written to that file.

    :param resource_group_name: resource group containing the zone.
    :param zone_name: name of the DNS zone to export.
    :param file_name: optional path of a file to write the zone file to.
    :raises CLIError: if the zone file cannot be written to ``file_name``.
    """
    from time import localtime, strftime

    client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_NETWORK_DNS)
    record_sets = client.record_sets.list_by_dns_zone(resource_group_name, zone_name)

    # Header entries for make_zone_file; '$origin' is the zone name with exactly one trailing '.'.
    zone_obj = OrderedDict({
        '$origin': zone_name.rstrip('.') + '.',
        'resource-group': resource_group_name,
        'zone-name': zone_name.rstrip('.'),
        'datetime': strftime('%a, %d %b %Y %X %z', localtime())
    })

    for record_set in record_sets:
        # The record set type is a '/'-separated path; its lowercased last segment is the DNS type.
        record_type = record_set.type.rsplit('/', 1)[1].lower()
        record_set_name = record_set.name
        # NOTE(review): _type_to_property_name raises KeyError for types it does not map
        # (anything other than a/aaaa/caa/cname/mx/ns/ptr/soa/spf/srv/txt).
        record_data = getattr(record_set, _type_to_property_name(record_type), None)

        # ignore empty record sets
        if not record_data:
            continue

        # Single-valued properties (cname_record, soa_record) are wrapped so both shapes share one loop.
        if not isinstance(record_data, list):
            record_data = [record_data]

        if record_set_name not in zone_obj:
            zone_obj[record_set_name] = OrderedDict()

        for record in record_data:

            # Every record carries the TTL of its record set.
            record_obj = {'ttl': record_set.ttl}

            if record_type not in zone_obj[record_set_name]:
                zone_obj[record_set_name][record_type] = []

            # Host names below are emitted fully qualified (exactly one trailing '.').
            if record_type == 'aaaa':
                record_obj.update({'ip': record.ipv6_address})
            elif record_type == 'a':
                record_obj.update({'ip': record.ipv4_address})
            elif record_type == 'caa':
                record_obj.update({'val': record.value, 'tag': record.tag, 'flags': record.flags})
            elif record_type == 'cname':
                record_obj.update({'alias': record.cname.rstrip('.') + '.'})
            elif record_type == 'mx':
                record_obj.update({'preference': record.preference, 'host': record.exchange.rstrip('.') + '.'})
            elif record_type == 'ns':
                record_obj.update({'host': record.nsdname.rstrip('.') + '.'})
            elif record_type == 'ptr':
                record_obj.update({'host': record.ptrdname.rstrip('.') + '.'})
            elif record_type == 'soa':
                record_obj.update({
                    'mname': record.host.rstrip('.') + '.',
                    'rname': record.email.rstrip('.') + '.',
                    'serial': int(record.serial_number), 'refresh': record.refresh_time,
                    'retry': record.retry_time, 'expire': record.expire_time,
                    'minimum': record.minimum_ttl
                })
                # The zone-wide $ttl is taken from the SOA minimum TTL.
                zone_obj['$ttl'] = record.minimum_ttl
            elif record_type == 'srv':
                record_obj.update({'priority': record.priority, 'weight': record.weight,
                                   'port': record.port, 'target': record.target.rstrip('.') + '.'})
            elif record_type == 'txt':
                # All strings of one TXT record are concatenated into a single value.
                record_obj.update({'txt': ''.join(record.value)})

            zone_obj[record_set_name][record_type].append(record_obj)

    zone_file_content = make_zone_file(zone_obj)
    print(zone_file_content)
    if file_name:
        try:
            with open(file_name, 'w') as f:
                f.write(zone_file_content)
        except IOError:
            raise CLIError('Unable to export to file: {}'.format(file_name))
2247
2248
2249# pylint: disable=too-many-return-statements, inconsistent-return-statements
def _build_record(cmd, data):
    """
    Build a DNS SDK record model from one parsed zone-file entry.

    ``data['delim']`` (case-insensitive) selects the record type and the other keys supply its
    fields. Returns ``None`` for unsupported types, and for CAA when the DNS API version is older
    than 2018-03-01-preview.

    :raises CLIError: when a field required by the record type is missing from ``data``.
    """
    AaaaRecord, ARecord, CaaRecord, CnameRecord, MxRecord, NsRecord, PtrRecord, SoaRecord, SrvRecord, TxtRecord = \
        cmd.get_models('AaaaRecord', 'ARecord', 'CaaRecord', 'CnameRecord', 'MxRecord', 'NsRecord',
                       'PtrRecord', 'SoaRecord', 'SrvRecord', 'TxtRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    record_type = data['delim'].lower()

    def _text_record():
        text_value = data['txt']
        return TxtRecord(value=text_value if isinstance(text_value, list) else [text_value])

    builders = {
        'aaaa': lambda: AaaaRecord(ipv6_address=data['ip']),
        'a': lambda: ARecord(ipv4_address=data['ip']),
        'caa': lambda: CaaRecord(value=data['val'], flags=int(data['flags']), tag=data['tag']),
        'cname': lambda: CnameRecord(cname=data['alias']),
        'mx': lambda: MxRecord(preference=data['preference'], exchange=data['host']),
        'ns': lambda: NsRecord(nsdname=data['host']),
        'ptr': lambda: PtrRecord(ptrdname=data['host']),
        'soa': lambda: SoaRecord(host=data['host'], email=data['email'], serial_number=data['serial'],
                                 refresh_time=data['refresh'], retry_time=data['retry'],
                                 expire_time=data['expire'], minimum_ttl=data['minimum']),
        'srv': lambda: SrvRecord(priority=int(data['priority']), weight=int(data['weight']),
                                 port=int(data['port']), target=data['target']),
        'txt': _text_record,
        'spf': _text_record,
    }

    build = builders.get(record_type)
    if build is None:
        return None
    try:
        if record_type == 'caa' and not supported_api_version(
                cmd.cli_ctx, ResourceType.MGMT_NETWORK_DNS, min_api='2018-03-01-preview'):
            return None
        return build()
    except KeyError as ke:
        raise CLIError("The {} record '{}' is missing a property.  {}"
                       .format(record_type, data['name'], ke))
2285
2286
2287# pylint: disable=too-many-statements
def import_zone(cmd, resource_group_name, zone_name, file_name):
    """
    Import DNS records from a zone file into a zone, creating or updating the zone first.

    The file is parsed with ``parse_zone_file``; records are grouped into ``RecordSet`` objects
    keyed by (lowercased record set name, type). The zone is then created/updated with location
    'global' and each record set is PUT in turn. Progress is printed to stderr; a record set whose
    PUT fails with ``HttpResponseError`` is logged and skipped.

    :param resource_group_name: resource group of the target zone.
    :param zone_name: name of the zone to import into.
    :param file_name: path of the zone file to read.
    :raises FileOperationError: if the file is missing, is a directory, or cannot be read due to permissions.
    :raises UnclassifiedUserFault: for any other OS error while reading the file.
    """
    from azure.cli.core.util import read_file_content
    from azure.core.exceptions import HttpResponseError
    import sys
    logger.warning("In the future, zone name will be case insensitive.")
    RecordSet = cmd.get_models('RecordSet', resource_type=ResourceType.MGMT_NETWORK_DNS)

    from azure.cli.core.azclierror import FileOperationError, UnclassifiedUserFault
    try:
        file_text = read_file_content(file_name)
    except FileNotFoundError:
        raise FileOperationError("No such file: " + str(file_name))
    except IsADirectoryError:
        raise FileOperationError("Is a directory: " + str(file_name))
    except PermissionError:
        raise FileOperationError("Permission denied: " + str(file_name))
    except OSError as e:
        raise UnclassifiedUserFault(e)

    zone_obj = parse_zone_file(file_text, zone_name)

    # The origin defaults to the zone name; an SOA entry in the file overrides it with its own name.
    origin = zone_name
    record_sets = {}
    for record_set_name in zone_obj:
        for record_set_type in zone_obj[record_set_name]:
            record_set_obj = zone_obj[record_set_name][record_set_type]

            if record_set_type == 'soa':
                origin = record_set_name.rstrip('.')

            if not isinstance(record_set_obj, list):
                record_set_obj = [record_set_obj]

            for entry in record_set_obj:

                record_set_ttl = entry['ttl']
                # Key is '<lowercased name><type>'. NOTE(review): presumably parse_zone_file yields
                # names ending in '.', which is what lets the later rsplit('.', 1) separate name and
                # type again - confirm.
                record_set_key = '{}{}'.format(record_set_name.lower(), record_set_type)

                record = _build_record(cmd, entry)
                if not record:
                    logger.warning('Cannot import %s. RecordType is not found. Skipping...', entry['delim'].lower())
                    continue

                record_set = record_sets.get(record_set_key, None)
                if not record_set:

                    # Workaround for issue #2824
                    relative_record_set_name = record_set_name.rstrip('.')
                    if not relative_record_set_name.endswith(origin):
                        logger.warning(
                            'Cannot import %s. Only records relative to origin may be '
                            'imported at this time. Skipping...', relative_record_set_name)
                        continue

                    # The first entry seen for a record set decides its TTL.
                    record_set = RecordSet(ttl=record_set_ttl)
                    record_sets[record_set_key] = record_set
                # SOA and CNAME are single-valued; every other type is appended to a list.
                _add_record(record_set, record, record_set_type,
                            is_list=record_set_type.lower() not in ['soa', 'cname'])

    # First pass: count records so progress can be reported as (done/total).
    total_records = 0
    for key, rs in record_sets.items():
        rs_name, rs_type = key.lower().rsplit('.', 1)
        # NOTE(review): rs_name computed here is not used in this counting loop.
        rs_name = rs_name[:-(len(origin) + 1)] if rs_name != origin else '@'
        try:
            record_count = len(getattr(rs, _type_to_property_name(rs_type)))
        except TypeError:
            # Single-valued properties (soa_record / cname_record) have no len(); count as one.
            record_count = 1
        total_records += record_count
    cum_records = 0

    client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_NETWORK_DNS)
    print('== BEGINNING ZONE IMPORT: {} ==\n'.format(zone_name), file=sys.stderr)

    Zone = cmd.get_models('Zone', resource_type=ResourceType.MGMT_NETWORK_DNS)
    client.zones.create_or_update(resource_group_name, zone_name, Zone(location='global'))
    for key, rs in record_sets.items():

        # Turn the fully qualified name into a name relative to the origin ('@' for the apex).
        # NOTE(review): rs_name is lowercased but origin is not, so a mixed-case origin may not
        # match here - confirm intended behaviour.
        rs_name, rs_type = key.lower().rsplit('.', 1)
        rs_name = '@' if rs_name == origin else rs_name
        if rs_name.endswith(origin):
            rs_name = rs_name[:-(len(origin) + 1)]

        try:
            record_count = len(getattr(rs, _type_to_property_name(rs_type)))
        except TypeError:
            record_count = 1
        if rs_name == '@' and rs_type == 'soa':
            # Keep the SOA host currently set on the zone; the other SOA fields come from the file.
            root_soa = client.record_sets.get(resource_group_name, zone_name, '@', 'SOA')
            rs.soa_record.host = root_soa.soa_record.host
            rs_name = '@'
        elif rs_name == '@' and rs_type == 'ns':
            # Keep the zone's existing apex NS records; only the TTL is taken from the file.
            root_ns = client.record_sets.get(resource_group_name, zone_name, '@', 'NS')
            root_ns.ttl = rs.ttl
            rs = root_ns
            rs_type = rs.type.rsplit('/', 1)[1]
        try:
            client.record_sets.create_or_update(
                resource_group_name, zone_name, rs_name, rs_type, rs)
            cum_records += record_count
            print("({}/{}) Imported {} records of type '{}' and name '{}'"
                  .format(cum_records, total_records, record_count, rs_type, rs_name), file=sys.stderr)
        except HttpResponseError as ex:
            # A failed record set is logged and the import continues with the next one.
            logger.error(ex)
    print("\n== {}/{} RECORDS IMPORTED SUCCESSFULLY: '{}' =="
          .format(cum_records, total_records, zone_name), file=sys.stderr)
2393
2394
def add_dns_aaaa_record(cmd, resource_group_name, zone_name, record_set_name, ipv6_address,
                        ttl=3600, if_none_match=None):
    """Build an ``AaaaRecord`` for ``ipv6_address`` and save it via ``_add_save_record``."""
    aaaa_model = cmd.get_models('AaaaRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    return _add_save_record(cmd, aaaa_model(ipv6_address=ipv6_address), 'aaaa', record_set_name,
                            resource_group_name, zone_name, ttl=ttl, if_none_match=if_none_match)
2402
2403
def add_dns_a_record(cmd, resource_group_name, zone_name, record_set_name, ipv4_address,
                     ttl=3600, if_none_match=None):
    """Build an ``ARecord`` for ``ipv4_address`` and save it via ``_add_save_record``."""
    a_model = cmd.get_models('ARecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    # NOTE(review): 'arecords' lands in _add_save_record's 7th positional slot, which sibling
    # calls pass as ``is_list=`` - presumably it only serves as a truthy is_list; confirm.
    return _add_save_record(cmd, a_model(ipv4_address=ipv4_address), 'a', record_set_name,
                            resource_group_name, zone_name, 'arecords',
                            ttl=ttl, if_none_match=if_none_match)
2411
2412
def add_dns_caa_record(cmd, resource_group_name, zone_name, record_set_name, value, flags, tag,
                       ttl=3600, if_none_match=None):
    """Build a ``CaaRecord`` from flags/tag/value and save it via ``_add_save_record``."""
    caa_model = cmd.get_models('CaaRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    return _add_save_record(cmd, caa_model(flags=flags, tag=tag, value=value), 'caa', record_set_name,
                            resource_group_name, zone_name, ttl=ttl, if_none_match=if_none_match)
2420
2421
def add_dns_cname_record(cmd, resource_group_name, zone_name, record_set_name, cname, ttl=3600, if_none_match=None):
    """Save ``cname`` as the (single-valued) CNAME record of a record set via ``_add_save_record``."""
    cname_model = cmd.get_models('CnameRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    return _add_save_record(cmd, cname_model(cname=cname), 'cname', record_set_name,
                            resource_group_name, zone_name,
                            is_list=False, ttl=ttl, if_none_match=if_none_match)
2428
2429
def add_dns_mx_record(cmd, resource_group_name, zone_name, record_set_name, preference, exchange,
                      ttl=3600, if_none_match=None):
    """Build an ``MxRecord`` (``preference`` coerced to int) and save it via ``_add_save_record``."""
    mx_model = cmd.get_models('MxRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    mx_record = mx_model(preference=int(preference), exchange=exchange)
    return _add_save_record(cmd, mx_record, 'mx', record_set_name, resource_group_name, zone_name,
                            ttl=ttl, if_none_match=if_none_match)
2437
2438
def add_dns_ns_record(cmd, resource_group_name, zone_name, record_set_name, dname,
                      subscription_id=None, ttl=3600, if_none_match=None):
    """Build an ``NsRecord`` for ``dname`` and save it via ``_add_save_record``, optionally in another subscription."""
    ns_model = cmd.get_models('NsRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    return _add_save_record(cmd, ns_model(nsdname=dname), 'ns', record_set_name,
                            resource_group_name, zone_name,
                            subscription_id=subscription_id, ttl=ttl, if_none_match=if_none_match)
2446
2447
def add_dns_ptr_record(cmd, resource_group_name, zone_name, record_set_name, dname, ttl=3600, if_none_match=None):
    """Build a ``PtrRecord`` for ``dname`` and save it via ``_add_save_record``."""
    ptr_model = cmd.get_models('PtrRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    return _add_save_record(cmd, ptr_model(ptrdname=dname), 'ptr', record_set_name,
                            resource_group_name, zone_name, ttl=ttl, if_none_match=if_none_match)
2454
2455
def update_dns_soa_record(cmd, resource_group_name, zone_name, host=None, email=None,
                          serial_number=None, refresh_time=None, retry_time=None, expire_time=None,
                          minimum_ttl=3600, if_none_match=None):
    """Update the zone's SOA record (record set '@'), keeping current values for omitted fields.

    Note that ``minimum_ttl`` defaults to 3600, so it is written on every call unless the
    caller explicitly passes None.

    :return: the result of saving the SOA record set via ``_add_save_record``.
    """
    record_set_name = '@'
    record_type = 'soa'

    ncf = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_NETWORK_DNS).record_sets
    record_set = ncf.get(resource_group_name, zone_name, record_set_name, record_type)
    record = record_set.soa_record

    # String fields: an empty or None value keeps the current host/email.
    record.host = host or record.host
    record.email = email or record.email

    # Numeric fields: 0 is a legitimate value, so only None means "keep the current value".
    # A plain `or` would silently discard an explicitly requested 0.
    numeric_updates = (
        ('serial_number', serial_number),
        ('refresh_time', refresh_time),
        ('retry_time', retry_time),
        ('expire_time', expire_time),
        ('minimum_ttl', minimum_ttl),
    )
    for attr_name, new_value in numeric_updates:
        if new_value is not None:
            setattr(record, attr_name, new_value)

    return _add_save_record(cmd, record, record_type, record_set_name, resource_group_name, zone_name,
                            is_list=False, if_none_match=if_none_match)
2476
2477
def add_dns_srv_record(cmd, resource_group_name, zone_name, record_set_name, priority, weight,
                       port, target, if_none_match=None):
    """Append an SRV record (priority, weight, port, target) to a record set, creating the set if needed.

    No TTL is passed, so an existing record set keeps its TTL.
    """
    srv_model = cmd.get_models('SrvRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    new_record = srv_model(priority=priority, weight=weight, port=port, target=target)
    return _add_save_record(cmd, new_record, 'srv', record_set_name, resource_group_name, zone_name,
                            if_none_match=if_none_match)
2485
2486
def add_dns_txt_record(cmd, resource_group_name, zone_name, record_set_name, value, if_none_match=None):
    """Append a TXT record to a record set, creating the set if needed.

    All supplied strings are concatenated and then re-split into consecutive chunks of at
    most 255 characters. This is presumably to respect the DNS per-string TXT length limit.
    An empty input yields a single empty string.
    """
    TxtRecord = cmd.get_models('TxtRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    record = TxtRecord(value=value)
    record_type = 'txt'

    max_chunk = 255
    long_text = ''.join(record.value)
    # Slice by index rather than repeatedly re-slicing the remainder (which was quadratic).
    record.value = [long_text[start:start + max_chunk]
                    for start in range(0, len(long_text), max_chunk)] or ['']

    return _add_save_record(cmd, record, record_type, record_set_name, resource_group_name, zone_name,
                            if_none_match=if_none_match)
2503
2504
def remove_dns_aaaa_record(cmd, resource_group_name, zone_name, record_set_name, ipv6_address,
                           keep_empty_record_set=False):
    """Remove the AAAA record for ``ipv6_address``.

    The record set is deleted once empty unless ``keep_empty_record_set`` is set.
    """
    aaaa_model = cmd.get_models('AaaaRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    target = aaaa_model(ipv6_address=ipv6_address)
    return _remove_record(cmd.cli_ctx, target, 'aaaa', record_set_name, resource_group_name, zone_name,
                          keep_empty_record_set=keep_empty_record_set)
2512
2513
def remove_dns_a_record(cmd, resource_group_name, zone_name, record_set_name, ipv4_address,
                        keep_empty_record_set=False):
    """Remove the A record for ``ipv4_address``.

    The record set is deleted once empty unless ``keep_empty_record_set`` is set.
    """
    a_model = cmd.get_models('ARecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    target = a_model(ipv4_address=ipv4_address)
    return _remove_record(cmd.cli_ctx, target, 'a', record_set_name, resource_group_name, zone_name,
                          keep_empty_record_set=keep_empty_record_set)
2521
2522
def remove_dns_caa_record(cmd, resource_group_name, zone_name, record_set_name, value,
                          flags, tag, keep_empty_record_set=False):
    """Remove the CAA record matching ``flags``/``tag``/``value``.

    The record set is deleted once empty unless ``keep_empty_record_set`` is set.
    """
    caa_model = cmd.get_models('CaaRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    target = caa_model(flags=flags, tag=tag, value=value)
    return _remove_record(cmd.cli_ctx, target, 'caa', record_set_name, resource_group_name, zone_name,
                          keep_empty_record_set=keep_empty_record_set)
2530
2531
def remove_dns_cname_record(cmd, resource_group_name, zone_name, record_set_name, cname,
                            keep_empty_record_set=False):
    """Clear the single CNAME record of a record set (``is_list=False``).

    The record set is then deleted unless ``keep_empty_record_set`` is set.
    """
    cname_model = cmd.get_models('CnameRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    target = cname_model(cname=cname)
    return _remove_record(cmd.cli_ctx, target, 'cname', record_set_name, resource_group_name, zone_name,
                          is_list=False, keep_empty_record_set=keep_empty_record_set)
2539
2540
def remove_dns_mx_record(cmd, resource_group_name, zone_name, record_set_name, preference, exchange,
                         keep_empty_record_set=False):
    """Remove the MX record matching ``preference`` and ``exchange``.

    The record set is deleted once empty unless ``keep_empty_record_set`` is set.
    """
    mx_model = cmd.get_models('MxRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    target = mx_model(preference=int(preference), exchange=exchange)
    return _remove_record(cmd.cli_ctx, target, 'mx', record_set_name, resource_group_name, zone_name,
                          keep_empty_record_set=keep_empty_record_set)
2548
2549
def remove_dns_ns_record(cmd, resource_group_name, zone_name, record_set_name, dname,
                         keep_empty_record_set=False):
    """Remove the NS record pointing at ``dname``.

    The record set is deleted once empty unless ``keep_empty_record_set`` is set.
    """
    ns_model = cmd.get_models('NsRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    target = ns_model(nsdname=dname)
    return _remove_record(cmd.cli_ctx, target, 'ns', record_set_name, resource_group_name, zone_name,
                          keep_empty_record_set=keep_empty_record_set)
2557
2558
def remove_dns_ptr_record(cmd, resource_group_name, zone_name, record_set_name, dname,
                          keep_empty_record_set=False):
    """Remove the PTR record targeting ``dname``.

    The record set is deleted once empty unless ``keep_empty_record_set`` is set.
    """
    ptr_model = cmd.get_models('PtrRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    target = ptr_model(ptrdname=dname)
    return _remove_record(cmd.cli_ctx, target, 'ptr', record_set_name, resource_group_name, zone_name,
                          keep_empty_record_set=keep_empty_record_set)
2566
2567
def remove_dns_srv_record(cmd, resource_group_name, zone_name, record_set_name, priority, weight,
                          port, target, keep_empty_record_set=False):
    """Remove the SRV record matching priority, weight, port and target.

    The record set is deleted once empty unless ``keep_empty_record_set`` is set.
    """
    srv_model = cmd.get_models('SrvRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    doomed = srv_model(priority=priority, weight=weight, port=port, target=target)
    return _remove_record(cmd.cli_ctx, doomed, 'srv', record_set_name, resource_group_name, zone_name,
                          keep_empty_record_set=keep_empty_record_set)
2575
2576
def remove_dns_txt_record(cmd, resource_group_name, zone_name, record_set_name, value,
                          keep_empty_record_set=False):
    """Remove the TXT record whose value matches ``value``.

    The record set is deleted once empty unless ``keep_empty_record_set`` is set.
    """
    txt_model = cmd.get_models('TxtRecord', resource_type=ResourceType.MGMT_NETWORK_DNS)
    target = txt_model(value=value)
    return _remove_record(cmd.cli_ctx, target, 'txt', record_set_name, resource_group_name, zone_name,
                          keep_empty_record_set=keep_empty_record_set)
2584
2585
2586def _check_a_record_exist(record, exist_list):
2587    for r in exist_list:
2588        if r.ipv4_address == record.ipv4_address:
2589            return True
2590    return False
2591
2592
2593def _check_aaaa_record_exist(record, exist_list):
2594    for r in exist_list:
2595        if r.ipv6_address == record.ipv6_address:
2596            return True
2597    return False
2598
2599
2600def _check_caa_record_exist(record, exist_list):
2601    for r in exist_list:
2602        if (r.flags == record.flags and
2603                r.tag == record.tag and
2604                r.value == record.value):
2605            return True
2606    return False
2607
2608
2609def _check_cname_record_exist(record, exist_list):
2610    for r in exist_list:
2611        if r.cname == record.cname:
2612            return True
2613    return False
2614
2615
2616def _check_mx_record_exist(record, exist_list):
2617    for r in exist_list:
2618        if (r.preference == record.preference and
2619                r.exchange == record.exchange):
2620            return True
2621    return False
2622
2623
2624def _check_ns_record_exist(record, exist_list):
2625    for r in exist_list:
2626        if r.nsdname == record.nsdname:
2627            return True
2628    return False
2629
2630
2631def _check_ptr_record_exist(record, exist_list):
2632    for r in exist_list:
2633        if r.ptrdname == record.ptrdname:
2634            return True
2635    return False
2636
2637
2638def _check_srv_record_exist(record, exist_list):
2639    for r in exist_list:
2640        if (r.priority == record.priority and
2641                r.weight == record.weight and
2642                r.port == record.port and
2643                r.target == record.target):
2644            return True
2645    return False
2646
2647
2648def _check_txt_record_exist(record, exist_list):
2649    for r in exist_list:
2650        if r.value == record.value:
2651            return True
2652    return False
2653
2654
2655def _record_exist_func(record_type):
2656    return globals()["_check_{}_record_exist".format(record_type)]
2657
2658
def _add_record(record_set, record, record_type, is_list=False):
    """Attach ``record`` to ``record_set`` under the property that stores ``record_type`` records.

    For single-valued types (``is_list`` false) the property is overwritten. For list types, the
    list is created if missing. The record is appended only when the type's duplicate checker
    finds no equal record.
    """
    record_property = _type_to_property_name(record_type)

    if not is_list:
        setattr(record_set, record_property, record)
        return

    if getattr(record_set, record_property) is None:
        setattr(record_set, record_property, [])
    existing_records = getattr(record_set, record_property)

    already_present = _record_exist_func(record_type)
    if not already_present(record, existing_records):
        existing_records.append(record)
2673
2674
def _add_save_record(cmd, record, record_type, record_set_name, resource_group_name, zone_name,
                     is_list=True, subscription_id=None, ttl=None, if_none_match=None):
    """Add ``record`` to its record set and save the set.

    If the record set does not exist (HTTP 404), a new one with TTL 3600 is started. When ``ttl``
    is given it overrides the set's TTL. When ``if_none_match`` is truthy, the save is sent with
    ``If-None-Match: *``.

    :return: the result of ``record_sets.create_or_update``.
    :raises HttpResponseError: if reading the existing record set fails for any reason other
        than the set being absent.
    """
    from azure.core.exceptions import HttpResponseError
    ncf = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_NETWORK_DNS,
                                  subscription_id=subscription_id).record_sets

    try:
        record_set = ncf.get(resource_group_name, zone_name, record_set_name, record_type)
    except HttpResponseError as ex:
        # Only a missing record set justifies starting from scratch. On any other failure
        # (throttling, auth, server error), saving a fresh set would replace the existing one
        # and discard every record it already held.
        if ex.status_code != 404:
            raise
        RecordSet = cmd.get_models('RecordSet', resource_type=ResourceType.MGMT_NETWORK_DNS)
        record_set = RecordSet(ttl=3600)

    if ttl is not None:
        record_set.ttl = ttl

    _add_record(record_set, record, record_type, is_list)

    return ncf.create_or_update(resource_group_name, zone_name, record_set_name,
                                record_type, record_set,
                                if_none_match='*' if if_none_match else None)
2695
2696
def _remove_record(cli_ctx, record, record_type, record_set_name, resource_group_name, zone_name,
                   keep_empty_record_set, is_list=True):
    """Remove ``record`` from its record set, then save or delete the set.

    For list types, every stored record matching ``record`` (per ``dict_matches_filter``) is
    dropped. For single-valued types the property is cleared. If nothing remains and
    ``keep_empty_record_set`` is false, the record set is deleted; otherwise it is saved.

    :raises CLIError: if a list-type record set holds no matching record.
    """
    ncf = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_NETWORK_DNS).record_sets
    record_set = ncf.get(resource_group_name, zone_name, record_set_name, record_type)
    record_property = _type_to_property_name(record_type)

    if is_list:
        # A None list simply means "no records"; previously it crashed later on len(None).
        record_list = getattr(record_set, record_property) or []
        keep_list = [r for r in record_list
                     if not dict_matches_filter(r.__dict__, record.__dict__)]
        if len(keep_list) == len(record_list):
            raise CLIError('Record {} not found.'.format(str(record)))
        setattr(record_set, record_property, keep_list)
        records_remaining = len(keep_list)
    else:
        # Single-valued record types hold one record; clearing it leaves nothing behind.
        setattr(record_set, record_property, None)
        records_remaining = 0

    if not records_remaining and not keep_empty_record_set:
        logger.info('Removing empty %s record set: %s', record_type, record_set_name)
        return ncf.delete(resource_group_name, zone_name, record_set_name, record_type)

    return ncf.create_or_update(resource_group_name, zone_name, record_set_name, record_type, record_set)
2724
2725
def _is_wildcard_filter_value(value):
    """Return True if a filter value places no constraint on the record being matched."""
    if value is None:
        return True
    if isinstance(value, (int, float)):
        # 0 / 0.0 / False are real values (e.g. MX preference 0, CAA flags 0, SRV weight 0).
        return False
    # Empty strings and empty containers (e.g. an unset additional_properties dict) match anything.
    return not value


def dict_matches_filter(d, filter_dict):
    """Return True if every constrained entry of ``filter_dict`` matches the same key in ``d``.

    Filter values that are None or empty strings/containers act as wildcards. Numeric values,
    including 0, must match: previously any falsy value was skipped, so a filter with
    preference 0 matched records of every preference. A value matches when the ``str()``
    forms are equal, or when both are iterables with the same elements in any order.
    """
    sentinel = object()
    return all(_is_wildcard_filter_value(value) or
               str(value) == str(d.get(key, sentinel)) or
               lists_match(value, d.get(key, []))
               for key, value in filter_dict.items())


def lists_match(l1, l2):
    """Return True if ``l1`` and ``l2`` contain the same elements with the same multiplicities.

    Returns False when either argument is not an iterable of hashable items.
    """
    try:
        return Counter(l1) == Counter(l2)  # pylint: disable=too-many-function-args
    except TypeError:
        return False
2739# endregion
2740
2741
2742# region ExpressRoutes
def create_express_route(cmd, circuit_name, resource_group_name, bandwidth_in_mbps, peering_location,
                         service_provider_name, location=None, tags=None, no_wait=False,
                         sku_family=None, sku_tier=None, allow_global_reach=None, express_route_port=None,
                         allow_classic_operations=None):
    """Create an ExpressRoute circuit.

    Without ``express_route_port``, the bandwidth (Mbps) goes into the service-provider
    properties. With a port, it is sent as ``bandwidth_in_gbps`` (Mbps / 1000). On API
    2018-08-01+ the port reference is also set and the provider properties are dropped.
    The SKU name is ``"<tier>_<family>"``.

    :return: the result of ``begin_create_or_update`` (routed through ``sdk_no_wait``).
    """
    (ExpressRouteCircuit, ExpressRouteCircuitSku,
     ExpressRouteCircuitServiceProviderProperties, SubResource) = cmd.get_models(
         'ExpressRouteCircuit', 'ExpressRouteCircuitSku', 'ExpressRouteCircuitServiceProviderProperties',
         'SubResource')
    client = network_client_factory(cmd.cli_ctx).express_route_circuits

    uses_port = bool(express_route_port)
    provider_properties = ExpressRouteCircuitServiceProviderProperties(
        service_provider_name=service_provider_name,
        peering_location=peering_location,
        bandwidth_in_mbps=None if uses_port else bandwidth_in_mbps)
    circuit_sku = ExpressRouteCircuitSku(name='{}_{}'.format(sku_tier, sku_family), tier=sku_tier, family=sku_family)
    circuit = ExpressRouteCircuit(
        location=location,
        tags=tags,
        service_provider_properties=provider_properties,
        sku=circuit_sku,
        allow_global_reach=allow_global_reach,
        bandwidth_in_gbps=int(bandwidth_in_mbps) / 1000 if uses_port else None)

    # NOTE(review): '2010-07-01' predates the Network API versions this module targets, so the
    # check is effectively always true — possibly a typo; confirm the intended minimum.
    if cmd.supported_api_version(min_api='2010-07-01') and allow_classic_operations is not None:
        circuit.allow_classic_operations = allow_classic_operations
    if cmd.supported_api_version(min_api='2018-08-01') and express_route_port:
        circuit.express_route_port = SubResource(id=express_route_port)
        circuit.service_provider_properties = None
    return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, circuit_name, circuit)
2769
2770
def update_express_route(instance, cmd, bandwidth_in_mbps=None, peering_location=None,
                         service_provider_name=None, sku_family=None, sku_tier=None, tags=None,
                         allow_global_reach=None, express_route_port=None,
                         allow_classic_operations=None):
    """Apply the supplied settings to an existing ExpressRoute circuit and return it.

    Settings left as None are not changed. When the SKU tier or family changes, the SKU name is
    rebuilt as ``"<tier>_<family>"``, the same format ``create_express_route`` uses. Setting
    ``express_route_port`` switches the circuit to that port and drops its provider properties.
    Bandwidth is then written as ``bandwidth_in_gbps`` (Mbps / 1000); otherwise it is written to
    the provider properties in Mbps.
    """
    with cmd.update_context(instance) as c:
        c.set_param('allow_classic_operations', allow_classic_operations)
        c.set_param('tags', tags)
        c.set_param('allow_global_reach', allow_global_reach)

    with cmd.update_context(instance.sku) as c:
        c.set_param('family', sku_family)
        c.set_param('tier', sku_tier)
    if sku_family is not None or sku_tier is not None:
        # Keep the SKU name consistent with the (possibly) new tier/family.
        instance.sku.name = '{}_{}'.format(instance.sku.tier, instance.sku.family)

    with cmd.update_context(instance.service_provider_properties) as c:
        c.set_param('peering_location', peering_location)
        c.set_param('service_provider_name', service_provider_name)

    if express_route_port is not None:
        SubResource = cmd.get_models('SubResource')
        instance.express_route_port = SubResource(id=express_route_port)
        instance.service_provider_properties = None

    if bandwidth_in_mbps is not None:
        if not instance.express_route_port:
            # Attribute name was previously misspelled ('bandwith_in_mbps'), so the update was lost.
            instance.service_provider_properties.bandwidth_in_mbps = float(bandwidth_in_mbps)
        else:
            instance.bandwidth_in_gbps = float(bandwidth_in_mbps) / 1000

    return instance
2801
2802
def create_express_route_peering_connection(cmd, resource_group_name, circuit_name, peering_name, connection_name,
                                            peer_circuit, address_prefix, authorization_key=None):
    """Create a circuit connection from peering ``peering_name`` of ``circuit_name`` to ``peer_circuit``.

    The local peering ID is built in the current subscription. ``peer_circuit`` is used as a
    resource ID verbatim.
    """
    client = network_client_factory(cmd.cli_ctx).express_route_circuit_connections
    ExpressRouteCircuitConnection, SubResource = cmd.get_models('ExpressRouteCircuitConnection', 'SubResource')

    local_peering_id = resource_id(
        subscription=get_subscription_id(cmd.cli_ctx),
        resource_group=resource_group_name,
        namespace='Microsoft.Network',
        type='expressRouteCircuits',
        name=circuit_name,
        child_type_1='peerings',
        child_name_1=peering_name)
    circuit_connection = ExpressRouteCircuitConnection(
        express_route_circuit_peering=SubResource(id=local_peering_id),
        peer_express_route_circuit_peering=SubResource(id=peer_circuit),
        address_prefix=address_prefix,
        authorization_key=authorization_key)
    return client.begin_create_or_update(resource_group_name, circuit_name, peering_name, connection_name,
                                         circuit_connection)
2823
2824
2825def _validate_ipv6_address_prefixes(prefixes):
2826    from ipaddress import ip_network, IPv6Network
2827    prefixes = prefixes if isinstance(prefixes, list) else [prefixes]
2828    version = None
2829    for prefix in prefixes:
2830        try:
2831            network = ip_network(prefix)
2832            if version is None:
2833                version = type(network)
2834            else:
2835                if not isinstance(network, version):  # pylint: disable=isinstance-second-argument-not-valid-type
2836                    raise CLIError("usage error: '{}' incompatible mix of IPv4 and IPv6 address prefixes."
2837                                   .format(prefixes))
2838        except ValueError:
2839            raise CLIError("usage error: prefix '{}' is not recognized as an IPv4 or IPv6 address prefix."
2840                           .format(prefix))
2841    return version == IPv6Network
2842
2843
def create_express_route_peering(
        cmd, client, resource_group_name, circuit_name, peering_type, peer_asn, vlan_id,
        primary_peer_address_prefix, secondary_peer_address_prefix, shared_key=None,
        advertised_public_prefixes=None, customer_asn=None, routing_registry_name=None,
        route_filter=None, legacy_mode=None, ip_version='IPv4'):
    """Create a peering on an ExpressRoute circuit.

    For ``ip_version == 'IPv6'`` on API 2020-08-01+, the addresses go into an IPv6 peering config.
    Otherwise they are set on the peering directly. Microsoft-peering settings (advertised
    prefixes, customer ASN, routing registry, legacy mode) apply only when ``peering_type`` is
    Microsoft peering. ``route_filter`` is a resource ID.

    :return: the result of ``client.begin_create_or_update``.
    :raises CLIError: if ``legacy_mode`` is given for a non-Microsoft IPv4 peering.
    """
    (ExpressRouteCircuitPeering, ExpressRouteCircuitPeeringConfig, RouteFilter) = \
        cmd.get_models('ExpressRouteCircuitPeering', 'ExpressRouteCircuitPeeringConfig', 'RouteFilter')

    if cmd.supported_api_version(min_api='2018-02-01'):
        ExpressRoutePeeringType = cmd.get_models('ExpressRoutePeeringType')
    else:
        ExpressRoutePeeringType = cmd.get_models('ExpressRouteCircuitPeeringType')
    is_microsoft_peering = peering_type == ExpressRoutePeeringType.microsoft_peering.value

    if ip_version == 'IPv6' and cmd.supported_api_version(min_api='2020-08-01'):
        Ipv6ExpressRouteCircuitPeeringConfig = cmd.get_models('Ipv6ExpressRouteCircuitPeeringConfig')
        microsoft_config = None
        if is_microsoft_peering:
            microsoft_config = ExpressRouteCircuitPeeringConfig(advertised_public_prefixes=advertised_public_prefixes,
                                                                customer_asn=customer_asn,
                                                                routing_registry_name=routing_registry_name)
        ipv6 = Ipv6ExpressRouteCircuitPeeringConfig(
            primary_peer_address_prefix=primary_peer_address_prefix,
            secondary_peer_address_prefix=secondary_peer_address_prefix,
            microsoft_peering_config=microsoft_config,
            # Wrap the ID in a model, as the IPv4 branch does, rather than passing a bare string.
            route_filter=RouteFilter(id=route_filter) if route_filter else None)
        peering = ExpressRouteCircuitPeering(peering_type=peering_type, ipv6_peering_config=ipv6, peer_asn=peer_asn,
                                             vlan_id=vlan_id)
    else:
        peering = ExpressRouteCircuitPeering(
            peering_type=peering_type, peer_asn=peer_asn, vlan_id=vlan_id,
            primary_peer_address_prefix=primary_peer_address_prefix,
            secondary_peer_address_prefix=secondary_peer_address_prefix,
            shared_key=shared_key)

        if is_microsoft_peering:
            peering.microsoft_peering_config = ExpressRouteCircuitPeeringConfig(
                advertised_public_prefixes=advertised_public_prefixes,
                customer_asn=customer_asn,
                routing_registry_name=routing_registry_name)
        if cmd.supported_api_version(min_api='2016-12-01') and route_filter:
            peering.route_filter = RouteFilter(id=route_filter)
        if cmd.supported_api_version(min_api='2017-10-01') and legacy_mode is not None:
            # Only Microsoft peering has a microsoft_peering_config to carry legacy_mode;
            # report a usage error instead of failing with AttributeError on None.
            if not is_microsoft_peering:
                raise CLIError('--legacy-mode is only applicable for Microsoft Peering.')
            peering.microsoft_peering_config.legacy_mode = legacy_mode

    return client.begin_create_or_update(resource_group_name, circuit_name, peering_type, peering)
2890
2891
2892def _create_or_update_ipv6_peering(cmd, config, primary_peer_address_prefix, secondary_peer_address_prefix,
2893                                   route_filter, advertised_public_prefixes, customer_asn, routing_registry_name):
2894    if config:
2895        # update scenario
2896        with cmd.update_context(config) as c:
2897            c.set_param('primary_peer_address_prefix', primary_peer_address_prefix)
2898            c.set_param('secondary_peer_address_prefix', secondary_peer_address_prefix)
2899            c.set_param('advertised_public_prefixes', advertised_public_prefixes)
2900            c.set_param('customer_asn', customer_asn)
2901            c.set_param('routing_registry_name', routing_registry_name)
2902
2903        if route_filter:
2904            RouteFilter = cmd.get_models('RouteFilter')
2905            config.route_filter = RouteFilter(id=route_filter)
2906    else:
2907        # create scenario
2908
2909        IPv6Config, MicrosoftPeeringConfig = cmd.get_models(
2910            'Ipv6ExpressRouteCircuitPeeringConfig', 'ExpressRouteCircuitPeeringConfig')
2911        microsoft_config = MicrosoftPeeringConfig(advertised_public_prefixes=advertised_public_prefixes,
2912                                                  customer_asn=customer_asn,
2913                                                  routing_registry_name=routing_registry_name)
2914        config = IPv6Config(primary_peer_address_prefix=primary_peer_address_prefix,
2915                            secondary_peer_address_prefix=secondary_peer_address_prefix,
2916                            microsoft_peering_config=microsoft_config,
2917                            route_filter=route_filter)
2918
2919    return config
2920
2921
def update_express_route_peering(cmd, instance, peer_asn=None, primary_peer_address_prefix=None,
                                 secondary_peer_address_prefix=None, vlan_id=None, shared_key=None,
                                 advertised_public_prefixes=None, customer_asn=None,
                                 routing_registry_name=None, route_filter=None, ip_version='IPv4',
                                 legacy_mode=None):
    """Apply the supplied settings to an existing ExpressRoute circuit peering and return it.

    Peer ASN, VLAN and shared key apply to every peering type. For ``ip_version == 'IPv6'``,
    the remaining address and Microsoft-peering settings are delegated to
    ``_create_or_update_ipv6_peering``. Otherwise they are set on the peering and its
    ``microsoft_peering_config``.

    :raises CLIError: if Microsoft-only options are given for a peering that has no
        ``microsoft_peering_config`` (i.e. a non-Microsoft peering).
    """
    # Settings shared by every peering type.
    with cmd.update_context(instance) as ctx:
        ctx.set_param('peer_asn', peer_asn)
        ctx.set_param('vlan_id', vlan_id)
        ctx.set_param('shared_key', shared_key)

    if ip_version == 'IPv6':
        # Update is the only way to add IPv6 peering options.
        instance.ipv6_peering_config = _create_or_update_ipv6_peering(
            cmd, instance.ipv6_peering_config, primary_peer_address_prefix, secondary_peer_address_prefix,
            route_filter, advertised_public_prefixes, customer_asn, routing_registry_name)
        return instance

    # IPv4: Microsoft peering or any other peering type.
    with cmd.update_context(instance) as ctx:
        ctx.set_param('primary_peer_address_prefix', primary_peer_address_prefix)
        ctx.set_param('secondary_peer_address_prefix', secondary_peer_address_prefix)

    if route_filter is not None:
        route_filter_model = cmd.get_models('RouteFilter')
        instance.route_filter = route_filter_model(id=route_filter)

    try:
        # Presumably None for non-Microsoft peerings; setting a value on it raises AttributeError.
        with cmd.update_context(instance.microsoft_peering_config) as ctx:
            ctx.set_param('advertised_public_prefixes', advertised_public_prefixes)
            ctx.set_param('customer_asn', customer_asn)
            ctx.set_param('routing_registry_name', routing_registry_name)
            ctx.set_param('legacy_mode', legacy_mode)
    except AttributeError:
        raise CLIError('--advertised-public-prefixes, --customer-asn, --routing-registry-name and '
                       '--legacy-mode are only applicable for Microsoft Peering.')
    return instance
2961# endregion
2962
2963
2964# region ExpressRoute Connection
2965# pylint: disable=unused-argument
def create_express_route_connection(cmd, resource_group_name, express_route_gateway_name, connection_name,
                                    peering, circuit_name=None, authorization_key=None, routing_weight=None,
                                    enable_internet_security=None, associated_route_table=None,
                                    propagated_route_tables=None, labels=None):
    """Create a connection between an ExpressRoute gateway and a circuit peering.

    ``peering``, ``associated_route_table`` and each entry of ``propagated_route_tables`` are
    resource IDs. ``circuit_name`` is accepted but not used here. ``enable_internet_security`` is
    only sent when truthy and the API is 2019-09-01 or later.

    :return: the result of ``begin_create_or_update``.
    """
    ExpressRouteConnection, SubResource, RoutingConfiguration, PropagatedRouteTable = cmd.get_models(
        'ExpressRouteConnection', 'SubResource', 'RoutingConfiguration', 'PropagatedRouteTable')
    client = network_client_factory(cmd.cli_ctx).express_route_connections

    propagated_ids = None
    if propagated_route_tables:
        propagated_ids = [SubResource(id=table_id) for table_id in propagated_route_tables]
    propagation = PropagatedRouteTable(labels=labels, ids=propagated_ids)
    routing = RoutingConfiguration(
        associated_route_table=SubResource(id=associated_route_table),
        propagated_route_tables=propagation)

    peering_ref = SubResource(id=peering) if peering else None
    new_connection = ExpressRouteConnection(
        name=connection_name,
        express_route_circuit_peering=peering_ref,
        authorization_key=authorization_key,
        routing_weight=routing_weight,
        routing_configuration=routing)

    if enable_internet_security and cmd.supported_api_version(min_api='2019-09-01'):
        new_connection.enable_internet_security = enable_internet_security

    return client.begin_create_or_update(resource_group_name, express_route_gateway_name, connection_name,
                                         new_connection)
2995
2996
2997# pylint: disable=unused-argument
def update_express_route_connection(instance, cmd, circuit_name=None, peering=None, authorization_key=None,
                                    routing_weight=None, enable_internet_security=None, associated_route_table=None,
                                    propagated_route_tables=None, labels=None):
    """Apply the supplied settings to an existing ExpressRoute connection and return it.

    Arguments left as None are not changed. ``peering``, ``associated_route_table`` and each
    entry of ``propagated_route_tables`` are resource IDs. Missing routing/propagation
    sub-objects are created on demand. ``circuit_name`` is accepted but not used here.
    """
    SubResource = cmd.get_models('SubResource')
    if peering is not None:
        # Same attribute create_express_route_connection sets; the previous target
        # ('express_route_connection_id') is not a field of the connection, so the change was lost.
        instance.express_route_circuit_peering = SubResource(id=peering)
    if authorization_key is not None:
        instance.authorization_key = authorization_key
    if routing_weight is not None:
        instance.routing_weight = routing_weight
    if enable_internet_security is not None and cmd.supported_api_version(min_api='2019-09-01'):
        instance.enable_internet_security = enable_internet_security
    if associated_route_table is not None or propagated_route_tables is not None or labels is not None:
        if instance.routing_configuration is None:
            RoutingConfiguration = cmd.get_models('RoutingConfiguration')
            instance.routing_configuration = RoutingConfiguration()
        if associated_route_table is not None:
            instance.routing_configuration.associated_route_table = SubResource(id=associated_route_table)
        if propagated_route_tables is not None or labels is not None:
            if instance.routing_configuration.propagated_route_tables is None:
                PropagatedRouteTable = cmd.get_models('PropagatedRouteTable')
                instance.routing_configuration.propagated_route_tables = PropagatedRouteTable()
            if propagated_route_tables is not None:
                instance.routing_configuration.propagated_route_tables.ids = [
                    SubResource(id=table_id) for table_id in propagated_route_tables]
            if labels is not None:
                instance.routing_configuration.propagated_route_tables.labels = labels

    return instance
3026# endregion
3027
3028
3029# region ExpressRoute Gateways
def create_express_route_gateway(cmd, resource_group_name, express_route_gateway_name, location=None, tags=None,
                                 min_val=2, max_val=None, virtual_hub=None):
    """Create an ExpressRoute gateway, optionally attached to a virtual hub.

    Autoscale bounds are sent when at least one of ``min_val``/``max_val`` is given. Note that
    ``min_val`` defaults to 2.

    :return: the result of ``begin_create_or_update``.
    """
    ExpressRouteGateway, SubResource = cmd.get_models('ExpressRouteGateway', 'SubResource')
    client = network_client_factory(cmd.cli_ctx).express_route_gateways
    gateway = ExpressRouteGateway(
        location=location,
        tags=tags,
        virtual_hub=SubResource(id=virtual_hub) if virtual_hub else None
    )
    # Test the arguments, not the builtins `min`/`max` (which are always truthy).
    if min_val is not None or max_val is not None:
        gateway.auto_scale_configuration = {'bounds': {'min': min_val, 'max': max_val}}
    return client.begin_create_or_update(resource_group_name, express_route_gateway_name, gateway)
3042
3043
def update_express_route_gateway(instance, cmd, tags=None, min_val=None, max_val=None):
    """Apply tag and autoscale-bound changes to an existing ExpressRoute gateway and return it.

    Arguments left as None are not changed. If the gateway has no autoscale configuration yet,
    one is created from the supplied bounds.
    """

    def _ensure_autoscale():
        if not instance.auto_scale_configuration:
            ExpressRouteGatewayPropertiesAutoScaleConfiguration, \
                ExpressRouteGatewayPropertiesAutoScaleConfigurationBounds = cmd.get_models(
                    'ExpressRouteGatewayPropertiesAutoScaleConfiguration',
                    'ExpressRouteGatewayPropertiesAutoScaleConfigurationBounds')
            # Seed with the requested bounds (previously the builtin min/max functions were passed).
            instance.auto_scale_configuration = ExpressRouteGatewayPropertiesAutoScaleConfiguration(
                bounds=ExpressRouteGatewayPropertiesAutoScaleConfigurationBounds(min=min_val, max=max_val))

    if tags is not None:
        instance.tags = tags
    # Check the arguments, not the builtins `min`/`max`: those are never None, so a tags-only
    # update used to overwrite both bounds with None.
    if min_val is not None:
        _ensure_autoscale()
        instance.auto_scale_configuration.bounds.min = min_val
    if max_val is not None:
        _ensure_autoscale()
        instance.auto_scale_configuration.bounds.max = max_val
    return instance
3064
3065
def list_express_route_gateways(cmd, resource_group_name=None):
    """List ExpressRoute gateways in a resource group, or in the whole subscription when none is given."""
    gateways = network_client_factory(cmd.cli_ctx).express_route_gateways
    if not resource_group_name:
        return gateways.list_by_subscription()
    return gateways.list_by_resource_group(resource_group_name)
3071# endregion
3072
3073
3074# region ExpressRoute ports
def create_express_route_port(cmd, resource_group_name, express_route_port_name, location=None, tags=None,
                              peering_location=None, bandwidth_in_gbps=None, encapsulation=None):
    """Create (or replace) an ExpressRoute port.

    ``bandwidth_in_gbps`` is converted to ``int`` when provided.

    :return: the poller returned by ``begin_create_or_update``.
    """
    client = network_client_factory(cmd.cli_ctx).express_route_ports
    ExpressRoutePort = cmd.get_models('ExpressRoutePort')
    bandwidth = None if bandwidth_in_gbps is None else int(bandwidth_in_gbps)
    port = ExpressRoutePort(location=location,
                            tags=tags,
                            peering_location=peering_location,
                            bandwidth_in_gbps=bandwidth,
                            encapsulation=encapsulation)
    return client.begin_create_or_update(resource_group_name, express_route_port_name, port)
3089
3090
def update_express_route_port(cmd, instance, tags=None):
    """Set the tags of an ExpressRoute port through the command's update context.

    :return: the (mutated) ``instance``.
    """
    with cmd.update_context(instance) as ctx:
        ctx.set_param('tags', tags, True)
    return instance
3095
3096
def download_generated_loa_as_pdf(cmd,
                                  resource_group_name,
                                  express_route_port_name,
                                  customer_name,
                                  file_path='loa.pdf'):
    """Generate the Letter of Authorization for an ExpressRoute port and write it to a PDF file.

    If ``file_path`` has no file-name component (e.g. it ends with a path separator), the file is
    written as ``loa.pdf`` in that directory. Otherwise a ``.pdf`` suffix is appended unless the
    name already ends with one (compared case-insensitively).

    :param customer_name: customer name placed in the generated LOA request.
    :param file_path: destination path for the PDF.
    :raises FileOperationError: if the file cannot be written.
    """
    import os
    import base64
    from azure.cli.core.azclierror import FileOperationError

    dirname, basename = os.path.split(file_path)
    if not basename:
        basename = 'loa.pdf'
    elif not basename.lower().endswith('.pdf'):
        basename += '.pdf'
    file_path = os.path.join(dirname, basename)

    generate_express_route_ports_loa_request = \
        cmd.get_models('GenerateExpressRoutePortsLOARequest')(customer_name=customer_name)
    client = network_client_factory(cmd.cli_ctx).express_route_ports
    response = client.generate_loa(resource_group_name, express_route_port_name,
                                   generate_express_route_ports_loa_request)

    # The service returns the PDF as base64-encoded content.
    pdf_bytes = base64.b64decode(response.encoded_content)

    try:
        with open(file_path, 'wb') as f:
            f.write(pdf_bytes)
    except OSError as ex:
        raise FileOperationError(ex) from ex

    logger.warning("The generated letter of authorization is saved at %s", file_path)
3129
3130
def list_express_route_ports(cmd, resource_group_name=None):
    """List ExpressRoute ports in a resource group, or across the subscription when none is given."""
    ports = network_client_factory(cmd.cli_ctx).express_route_ports
    if not resource_group_name:
        return ports.list()
    return ports.list_by_resource_group(resource_group_name)
3136
3137
def assign_express_route_port_identity(cmd, resource_group_name, express_route_port_name,
                                       user_assigned_identity, no_wait=False):
    """Replace the identity of an ExpressRoute port with a single user-assigned identity.

    :param user_assigned_identity: resource ID of the user-assigned identity to attach.
    :param no_wait: forwarded to ``sdk_no_wait``.
    """
    client = network_client_factory(cmd.cli_ctx).express_route_ports
    port = client.get(resource_group_name, express_route_port_name)

    ManagedServiceIdentity, UserAssignedIdentityValue = cmd.get_models(
        'ManagedServiceIdentity',
        'Components1Jq1T4ISchemasManagedserviceidentityPropertiesUserassignedidentitiesAdditionalproperties')

    identity_value = UserAssignedIdentityValue()
    port.identity = ManagedServiceIdentity(type="UserAssigned",
                                           user_assigned_identities={user_assigned_identity: identity_value})

    return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, express_route_port_name, port)
3155
3156
def remove_express_route_port_identity(cmd, resource_group_name, express_route_port_name, no_wait=False):
    """Clear the identity of an ExpressRoute port.

    When the port has no identity, a warning is logged and the fetched port is returned unchanged.
    """
    client = network_client_factory(cmd.cli_ctx).express_route_ports
    port = client.get(resource_group_name, express_route_port_name)

    if port.identity is not None:
        port.identity = None
        return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, express_route_port_name, port)

    logger.warning("The identity of the ExpressRoute Port doesn't exist.")
    return port
3168
3169
def show_express_route_port_identity(cmd, resource_group_name, express_route_port_name):
    """Return the ``identity`` attribute of the named ExpressRoute port."""
    port_client = network_client_factory(cmd.cli_ctx).express_route_ports
    return port_client.get(resource_group_name, express_route_port_name).identity
3174
3175
def update_express_route_port_link(cmd, instance, parent, express_route_port_name, link_name,
                                   macsec_cak_secret_identifier=None, macsec_ckn_secret_identifier=None,
                                   macsec_sci_state=None, macsec_cipher=None, admin_state=None):
    """Update MACsec settings and/or the admin state of one link of an ExpressRoute port.

    :param cmd: CLI command context (not read here).
    :param instance: the link object being updated. Its ``mac_sec_config`` and ``admin_state`` are
        written. (The previous docstring called this an ExpressRoutePort; ``parent`` appears to be
        the port. Verify against the command registration.)
    :param parent: the owning object, returned unchanged so the caller can persist it.
    :param express_route_port_name: not read here.
    :param link_name: not read here.
    :param macsec_cak_secret_identifier: CAK secret identifier for ``mac_sec_config``.
    :param macsec_ckn_secret_identifier: CKN secret identifier for ``mac_sec_config``.
    :param macsec_sci_state: SCI state for ``mac_sec_config``.
    :param macsec_cipher: cipher; 'gcm-aes-128' / 'gcm-aes-256' are translated to the SDK spelling.
    :param admin_state: new admin state, if given.
    :return: ``parent``.

    NOTE(review): when any MACsec option is truthy, all four MACsec fields are overwritten,
    so options that were not passed are reset to None. Confirm this is intended.
    """
    macsec_options = (macsec_cak_secret_identifier, macsec_ckn_secret_identifier, macsec_cipher, macsec_sci_state)
    if any(macsec_options):
        config = instance.mac_sec_config
        config.cak_secret_identifier = macsec_cak_secret_identifier
        config.ckn_secret_identifier = macsec_ckn_secret_identifier

        # TODO https://github.com/Azure/azure-rest-api-specs/issues/7569
        # Map the CLI cipher spelling to the SDK value; remove once the issue is fixed.
        if macsec_cipher is not None:
            cipher_aliases = {'gcm-aes-128': 'GcmAes128', 'gcm-aes-256': 'GcmAes256'}
            macsec_cipher = cipher_aliases.get(macsec_cipher, macsec_cipher)
        config.cipher = macsec_cipher
        config.sci_state = macsec_sci_state

    if admin_state is not None:
        instance.admin_state = admin_state

    return parent
3206# endregion
3207
3208
3209# region PrivateEndpoint
def create_private_endpoint(cmd, resource_group_name, private_endpoint_name, subnet,
                            private_connection_resource_id, connection_name, group_ids=None,
                            virtual_network_name=None, tags=None, location=None,
                            request_message=None, manual_request=None, edge_zone=None):
    """Create (or replace) a private endpoint with a single private link service connection.

    The connection goes into ``manual_private_link_service_connections`` when ``manual_request``
    is truthy, otherwise into ``private_link_service_connections``.

    :param subnet: resource ID of the subnet hosting the endpoint.
    :param private_connection_resource_id: resource ID of the target private link resource.
    :param edge_zone: optional edge zone name, attached as an extended location.
    :return: the poller returned by ``begin_create_or_update``.
    """
    client = network_client_factory(cmd.cli_ctx).private_endpoints
    PrivateEndpoint, Subnet, PrivateLinkServiceConnection = cmd.get_models(
        'PrivateEndpoint', 'Subnet', 'PrivateLinkServiceConnection')

    connection = PrivateLinkServiceConnection(private_link_service_id=private_connection_resource_id,
                                              group_ids=group_ids,
                                              request_message=request_message,
                                              name=connection_name)
    endpoint = PrivateEndpoint(location=location, tags=tags, subnet=Subnet(id=subnet))

    connection_attr = ('manual_private_link_service_connections' if manual_request
                       else 'private_link_service_connections')
    setattr(endpoint, connection_attr, [connection])

    if edge_zone:
        endpoint.extended_location = _edge_zone_model(cmd, edge_zone)
    return client.begin_create_or_update(resource_group_name, private_endpoint_name, endpoint)
3236
3237
def update_private_endpoint(instance, cmd, tags=None, request_message=None):
    """Update tags and/or the request message of a private endpoint.

    The request message is written to the first automatic connection if there is one. Otherwise it
    goes to the first manual connection.

    :return: the (mutated) ``instance``.
    """
    with cmd.update_context(instance) as ctx:
        ctx.set_param('tags', tags)

    if request_message is not None:
        connections = (instance.private_link_service_connections
                       or instance.manual_private_link_service_connections)
        connections[0].request_message = request_message

    return instance
3249
3250
def list_private_endpoints(cmd, resource_group_name=None):
    """List private endpoints in a resource group, or across the subscription when none is given."""
    endpoints = network_client_factory(cmd.cli_ctx).private_endpoints
    if not resource_group_name:
        return endpoints.list_by_subscription()
    return endpoints.list(resource_group_name)
3256
3257
def create_private_endpoint_private_dns_zone_group(cmd, resource_group_name, private_endpoint_name,
                                                   private_dns_zone_group_name,
                                                   private_dns_zone_name, private_dns_zone):
    """Create (or replace) a private DNS zone group that holds a single DNS zone config.

    :param private_dns_zone_name: name of the zone config entry inside the group.
    :param private_dns_zone: resource ID of the private DNS zone.
    :return: the poller returned by ``begin_create_or_update``.
    """
    client = network_client_factory(cmd.cli_ctx).private_dns_zone_groups
    PrivateDnsZoneGroup, PrivateDnsZoneConfig = cmd.get_models('PrivateDnsZoneGroup', 'PrivateDnsZoneConfig')
    zone_config = PrivateDnsZoneConfig(private_dns_zone_id=private_dns_zone, name=private_dns_zone_name)
    zone_group = PrivateDnsZoneGroup(name=private_dns_zone_group_name, private_dns_zone_configs=[zone_config])
    return client.begin_create_or_update(resource_group_name=resource_group_name,
                                         private_endpoint_name=private_endpoint_name,
                                         private_dns_zone_group_name=private_dns_zone_group_name,
                                         parameters=zone_group)
3270
3271
def add_private_endpoint_private_dns_zone(cmd, resource_group_name, private_endpoint_name,
                                          private_dns_zone_group_name,
                                          private_dns_zone_name, private_dns_zone):
    """Append a DNS zone config to an existing private DNS zone group and save the group.

    :param private_dns_zone_name: name of the new zone config entry.
    :param private_dns_zone: resource ID of the private DNS zone.
    :return: the poller returned by ``begin_create_or_update``.
    """
    client = network_client_factory(cmd.cli_ctx).private_dns_zone_groups
    PrivateDnsZoneConfig = cmd.get_models('PrivateDnsZoneConfig')
    zone_group = client.get(resource_group_name=resource_group_name,
                            private_endpoint_name=private_endpoint_name,
                            private_dns_zone_group_name=private_dns_zone_group_name)
    new_config = PrivateDnsZoneConfig(private_dns_zone_id=private_dns_zone, name=private_dns_zone_name)
    zone_group.private_dns_zone_configs.append(new_config)
    return client.begin_create_or_update(resource_group_name=resource_group_name,
                                         private_endpoint_name=private_endpoint_name,
                                         private_dns_zone_group_name=private_dns_zone_group_name,
                                         parameters=zone_group)
3286
3287
def remove_private_endpoint_private_dns_zone(cmd, resource_group_name, private_endpoint_name,
                                             private_dns_zone_group_name,
                                             private_dns_zone_name):
    """Drop every DNS zone config named ``private_dns_zone_name`` from a zone group and save the group.

    :return: the poller returned by ``begin_create_or_update``.
    """
    client = network_client_factory(cmd.cli_ctx).private_dns_zone_groups
    zone_group = client.get(resource_group_name=resource_group_name,
                            private_endpoint_name=private_endpoint_name,
                            private_dns_zone_group_name=private_dns_zone_group_name)
    kept = [config for config in zone_group.private_dns_zone_configs
            if config.name != private_dns_zone_name]
    zone_group.private_dns_zone_configs = kept
    return client.begin_create_or_update(resource_group_name=resource_group_name,
                                         private_endpoint_name=private_endpoint_name,
                                         private_dns_zone_group_name=private_dns_zone_group_name,
                                         parameters=zone_group)
3301# endregion
3302
3303
3304# region PrivateLinkService
def create_private_link_service(cmd, resource_group_name, service_name, subnet, frontend_ip_configurations,
                                private_ip_address=None, private_ip_allocation_method=None,
                                private_ip_address_version=None,
                                virtual_network_name=None, public_ip_address=None,
                                location=None, tags=None, load_balancer_name=None,
                                visibility=None, auto_approval=None, fqdns=None,
                                enable_proxy_protocol=None, edge_zone=None):
    """Create (or replace) a private link service with one initial IP configuration.

    The IP configuration is named ``<service_name>_ipconfig_0``.

    :param subnet: resource ID of the subnet for the IP configuration.
    :param frontend_ip_configurations: resource IDs of load balancer frontend IP configurations to attach.
    :param virtual_network_name: not read here (presumably consumed by argument validation — verify).
    :param load_balancer_name: not read here (presumably consumed by argument validation — verify).
    :param edge_zone: optional edge zone name, attached as an extended location.
    :return: the poller returned by ``begin_create_or_update``.
    """
    client = network_client_factory(cmd.cli_ctx).private_link_services
    FrontendIPConfiguration, PrivateLinkService, PrivateLinkServiceIpConfiguration, PublicIPAddress, Subnet = \
        cmd.get_models('FrontendIPConfiguration', 'PrivateLinkService', 'PrivateLinkServiceIpConfiguration',
                       'PublicIPAddress', 'Subnet')
    pls_ip_config = PrivateLinkServiceIpConfiguration(
        name='{}_ipconfig_0'.format(service_name),
        private_ip_address=private_ip_address,
        private_ip_allocation_method=private_ip_allocation_method,
        private_ip_address_version=private_ip_address_version,
        subnet=subnet and Subnet(id=subnet),
        public_ip_address=public_ip_address and PublicIPAddress(id=public_ip_address)
    )
    link_service = PrivateLinkService(
        location=location,
        load_balancer_frontend_ip_configurations=frontend_ip_configurations and [
            FrontendIPConfiguration(id=ip_config) for ip_config in frontend_ip_configurations
        ],
        ip_configurations=[pls_ip_config],
        # Previously misspelled as ``visbility``, so the value never reached the model.
        visibility=visibility,
        auto_approval=auto_approval,
        fqdns=fqdns,
        tags=tags,
        enable_proxy_protocol=enable_proxy_protocol
    )
    if edge_zone:
        link_service.extended_location = _edge_zone_model(cmd, edge_zone)
    return client.begin_create_or_update(resource_group_name, service_name, link_service)
3339
3340
def update_private_link_service(instance, cmd, tags=None, frontend_ip_configurations=None, load_balancer_name=None,
                                visibility=None, auto_approval=None, fqdns=None, enable_proxy_protocol=None):
    """Apply the supplied property changes to a private link service through the update context.

    :param frontend_ip_configurations: resource IDs of load balancer frontend IP configurations.
    :return: the (mutated) ``instance``.
    """
    FrontendIPConfiguration = cmd.get_models('FrontendIPConfiguration')
    frontend_refs = frontend_ip_configurations and [
        FrontendIPConfiguration(id=config_id) for config_id in frontend_ip_configurations
    ]
    changes = (
        ('tags', tags),
        ('load_balancer_frontend_ip_configurations', frontend_refs),
        ('visibility', visibility),
        ('auto_approval', auto_approval),
        ('fqdns', fqdns),
        ('enable_proxy_protocol', enable_proxy_protocol),
    )
    with cmd.update_context(instance) as ctx:
        for prop, value in changes:
            ctx.set_param(prop, value)
    return instance
3354
3355
def list_private_link_services(cmd, resource_group_name=None):
    """List private link services in a resource group, or across the subscription when none is given."""
    services = network_client_factory(cmd.cli_ctx).private_link_services
    if not resource_group_name:
        return services.list_by_subscription()
    return services.list(resource_group_name)
3361
3362
def update_private_endpoint_connection(cmd, resource_group_name, service_name, pe_connection_name,
                                       connection_status, description=None, action_required=None):
    """Set the connection state of a private endpoint connection on a private link service.

    :param connection_status: new status value for the connection state.
    :param action_required: stored as ``actions_required`` on the connection state.
    """
    client = network_client_factory(cmd.cli_ctx).private_link_services
    PrivateEndpointConnection, PrivateLinkServiceConnectionState = cmd.get_models(
        'PrivateEndpointConnection', 'PrivateLinkServiceConnectionState')
    state = PrivateLinkServiceConnectionState(status=connection_status,
                                              description=description,
                                              actions_required=action_required)
    connection = PrivateEndpointConnection(private_link_service_connection_state=state)
    return client.update_private_endpoint_connection(resource_group_name, service_name,
                                                     pe_connection_name, connection)
3377
3378
def add_private_link_services_ipconfig(cmd, resource_group_name, service_name,
                                       private_ip_address=None, private_ip_allocation_method=None,
                                       private_ip_address_version=None,
                                       subnet=None, virtual_network_name=None, public_ip_address=None):
    """Append a new IP configuration to an existing private link service and save it.

    The new configuration is named ``<service_name>_ipconfig_<n>``. ``n`` starts at the current
    number of IP configurations and is bumped until the name is unused, so earlier removals
    cannot cause a name collision.

    :raises CLIError: if the private link service lookup returns None.
    :return: the poller returned by ``begin_create_or_update``.
    """
    client = network_client_factory(cmd.cli_ctx).private_link_services
    PrivateLinkServiceIpConfiguration, PublicIPAddress, Subnet = cmd.get_models('PrivateLinkServiceIpConfiguration',
                                                                                'PublicIPAddress',
                                                                                'Subnet')
    link_service = client.get(resource_group_name, service_name)
    if link_service is None:
        raise CLIError("Private link service should be existed. Please create it first.")
    if link_service.ip_configurations is None:
        link_service.ip_configurations = []

    existing_names = {config.name for config in link_service.ip_configurations}
    ip_name_index = len(link_service.ip_configurations)
    ip_config_name = '{0}_ipconfig_{1}'.format(service_name, ip_name_index)
    while ip_config_name in existing_names:
        ip_name_index += 1
        ip_config_name = '{0}_ipconfig_{1}'.format(service_name, ip_name_index)

    ip_config = PrivateLinkServiceIpConfiguration(
        name=ip_config_name,
        private_ip_address=private_ip_address,
        private_ip_allocation_method=private_ip_allocation_method,
        private_ip_address_version=private_ip_address_version,
        subnet=subnet and Subnet(id=subnet),
        public_ip_address=public_ip_address and PublicIPAddress(id=public_ip_address)
    )
    link_service.ip_configurations.append(ip_config)
    return client.begin_create_or_update(resource_group_name, service_name, link_service)
3401
3402
def remove_private_link_services_ipconfig(cmd, resource_group_name, service_name, ip_config_name):
    """Remove the first IP configuration named ``ip_config_name`` from a private link service.

    When no configuration with that name exists, a warning is logged and the fetched service is
    returned without saving.

    :raises CLIError: if the private link service lookup returns None.
    """
    client = network_client_factory(cmd.cli_ctx).private_link_services
    link_service = client.get(resource_group_name, service_name)
    if link_service is None:
        raise CLIError("Private link service should be existed. Please create it first.")

    match = next((config for config in link_service.ip_configurations if config.name == ip_config_name), None)
    if match is None:
        logger.warning("%s ip configuration doesn't exist", ip_config_name)
        return link_service

    link_service.ip_configurations.remove(match)
    return client.begin_create_or_update(resource_group_name, service_name, link_service)
3419# endregion
3420
3421
3422def _edge_zone_model(cmd, edge_zone):
3423    ExtendedLocation, ExtendedLocationTypes = cmd.get_models('ExtendedLocation', 'ExtendedLocationTypes')
3424    return ExtendedLocation(name=edge_zone, type=ExtendedLocationTypes.EDGE_ZONE)
3425
3426
3427# region LoadBalancers
def create_load_balancer(cmd, load_balancer_name, resource_group_name, location=None, tags=None,
                         backend_pool_name=None, frontend_ip_name='LoadBalancerFrontEnd',
                         private_ip_address=None, public_ip_address=None,
                         public_ip_address_allocation=None,
                         public_ip_dns_name=None, subnet=None, subnet_address_prefix='10.0.0.0/24',
                         virtual_network_name=None, vnet_address_prefix='10.0.0.0/16',
                         public_ip_address_type=None, subnet_type=None, validate=False,
                         no_wait=False, sku=None, frontend_ip_zone=None, public_ip_zone=None,
                         private_ip_address_version=None, edge_zone=None):
    """Create a load balancer by building and deploying an ARM template.

    Besides the load balancer, the template may declare a new virtual network (when
    ``subnet_type == 'new'``) and a new public IP address (when ``public_ip_address_type == 'new'``).
    The load balancer resource depends on whichever of those are added.

    :param validate: when True, log the template and validate the deployment instead of creating it.
    :param no_wait: forwarded to ``sdk_no_wait`` for the create call.
    :return: the validation result when ``validate`` is True; otherwise what ``sdk_no_wait``
        returns for the deployment's ``begin_create_or_update``.
    """
    from azure.cli.core.util import random_string
    from azure.cli.core.commands.arm import ArmTemplateBuilder
    from azure.cli.command_modules.network._template_builder import (
        build_load_balancer_resource, build_public_ip_resource, build_vnet_resource)

    DeploymentProperties = cmd.get_models('DeploymentProperties', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES)
    IPAllocationMethod = cmd.get_models('IPAllocationMethod')

    tags = tags or {}
    # Derive default names from the load balancer name when they are not supplied.
    public_ip_address = public_ip_address or 'PublicIP{}'.format(load_balancer_name)
    backend_pool_name = backend_pool_name or '{}bepool'.format(load_balancer_name)
    # Default public IP allocation: static for the Standard SKU, dynamic otherwise.
    if not public_ip_address_allocation:
        public_ip_address_allocation = IPAllocationMethod.static.value if (sku and sku.lower() == 'standard') \
            else IPAllocationMethod.dynamic.value

    # Build up the ARM template
    master_template = ArmTemplateBuilder()
    lb_dependencies = []  # template resources the load balancer must be deployed after

    # Values given as full resource IDs are referenced directly. Otherwise the ID stays None
    # unless the resource is created in this template below.
    public_ip_id = public_ip_address if is_valid_resource_id(public_ip_address) else None
    subnet_id = subnet if is_valid_resource_id(subnet) else None
    # Static private IP allocation when an explicit address was given, dynamic otherwise.
    private_ip_allocation = IPAllocationMethod.static.value if private_ip_address \
        else IPAllocationMethod.dynamic.value

    # Prefix for Microsoft.Network resource IDs in the target resource group.
    network_id_template = resource_id(
        subscription=get_subscription_id(cmd.cli_ctx), resource_group=resource_group_name,
        namespace='Microsoft.Network')

    # The edge zone type is only passed on when the API version is at least 2020-08-01.
    if edge_zone and cmd.supported_api_version(min_api='2020-08-01'):
        edge_zone_type = 'EdgeZone'
    else:
        edge_zone_type = None

    # Create a new virtual network + subnet in the same deployment.
    if subnet_type == 'new':
        lb_dependencies.append('Microsoft.Network/virtualNetworks/{}'.format(virtual_network_name))
        vnet = build_vnet_resource(
            cmd, virtual_network_name, location, tags, vnet_address_prefix, subnet,
            subnet_address_prefix)
        master_template.add_resource(vnet)
        subnet_id = '{}/virtualNetworks/{}/subnets/{}'.format(
            network_id_template, virtual_network_name, subnet)

    # Create a new public IP address in the same deployment.
    if public_ip_address_type == 'new':
        lb_dependencies.append('Microsoft.Network/publicIpAddresses/{}'.format(public_ip_address))
        master_template.add_resource(build_public_ip_resource(cmd, public_ip_address, location,
                                                              tags,
                                                              public_ip_address_allocation,
                                                              public_ip_dns_name,
                                                              sku, public_ip_zone, None, edge_zone, edge_zone_type))
        public_ip_id = '{}/publicIPAddresses/{}'.format(network_id_template,
                                                        public_ip_address)

    load_balancer_resource = build_load_balancer_resource(
        cmd, load_balancer_name, location, tags, backend_pool_name, frontend_ip_name,
        public_ip_id, subnet_id, private_ip_address, private_ip_allocation, sku,
        frontend_ip_zone, private_ip_address_version, None, edge_zone, edge_zone_type)
    load_balancer_resource['dependsOn'] = lb_dependencies
    master_template.add_resource(load_balancer_resource)
    master_template.add_output('loadBalancer', load_balancer_name, output_type='object')

    template = master_template.build()

    # deploy ARM template
    deployment_name = 'lb_deploy_' + random_string(32)
    client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES).deployments
    properties = DeploymentProperties(template=template, parameters={}, mode='incremental')
    Deployment = cmd.get_models('Deployment', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES)
    deployment = Deployment(properties=properties)

    if validate:
        _log_pprint_template(template)
        # From resources API 2019-10-01 on, validation is a long-running operation that is waited on here.
        if cmd.supported_api_version(min_api='2019-10-01', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES):
            from azure.cli.core.commands import LongRunningOperation
            validation_poller = client.begin_validate(resource_group_name, deployment_name, deployment)
            return LongRunningOperation(cmd.cli_ctx)(validation_poller)

        return client.validate(resource_group_name, deployment_name, deployment)

    return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, deployment_name, deployment)
3516
3517
def list_load_balancer_nic(cmd, resource_group_name, load_balancer_name):
    """List the network interfaces attached to a load balancer."""
    nic_client = network_client_factory(cmd.cli_ctx).load_balancer_network_interfaces
    nics = nic_client.list(resource_group_name, load_balancer_name)
    return nics
3521
3522
def create_lb_inbound_nat_rule(
        cmd, resource_group_name, load_balancer_name, item_name, protocol, frontend_port,
        backend_port, frontend_ip_name=None, floating_ip=None, idle_timeout=None, enable_tcp_reset=None):
    """Add (or replace by name) an inbound NAT rule on a load balancer and return the saved rule.

    When ``frontend_ip_name`` is not given, the load balancer's default frontend IP configuration is used.
    """
    InboundNatRule = cmd.get_models('InboundNatRule')
    lb_client = network_client_factory(cmd.cli_ctx).load_balancers
    lb = lb_get(lb_client, resource_group_name, load_balancer_name)

    frontend_ip_name = frontend_ip_name or _get_default_name(lb, 'frontend_ip_configurations', '--frontend-ip-name')
    frontend_ip = get_property(lb.frontend_ip_configurations, frontend_ip_name)  # pylint: disable=no-member

    rule = InboundNatRule(name=item_name,
                          protocol=protocol,
                          frontend_port=frontend_port,
                          backend_port=backend_port,
                          frontend_ip_configuration=frontend_ip,
                          enable_floating_ip=floating_ip,
                          idle_timeout_in_minutes=idle_timeout,
                          enable_tcp_reset=enable_tcp_reset)
    upsert_to_collection(lb, 'inbound_nat_rules', rule, 'name')

    saved_lb = lb_client.begin_create_or_update(resource_group_name, load_balancer_name, lb).result()
    return get_property(saved_lb.inbound_nat_rules, item_name)
3542
3543
3544# workaround for : https://github.com/Azure/azure-cli/issues/17071
def lb_get(client, resource_group_name, load_balancer_name):
    """Fetch a load balancer and pass it through ``lb_get_operation`` before returning it."""
    return lb_get_operation(client.get(resource_group_name, load_balancer_name))
3548
3549
3550# workaround for : https://github.com/Azure/azure-cli/issues/17071
def lb_get_operation(lb):
    """Clear ``zones`` on subnet-less frontend IP configurations that list three or more zones.

    Mutates ``lb`` in place and returns it.
    """
    affected = (config for config in lb.frontend_ip_configurations
                if config.zones is not None and len(config.zones) >= 3 and config.subnet is None)
    for config in affected:
        config.zones = None
    return lb
3557
3558
def set_lb_inbound_nat_rule(
        cmd, instance, parent, item_name, protocol=None, frontend_port=None,
        frontend_ip_name=None, backend_port=None, floating_ip=None, idle_timeout=None, enable_tcp_reset=None):
    """Apply the supplied changes to an inbound NAT rule and return the owning load balancer."""
    if frontend_ip_name:
        instance.frontend_ip_configuration = get_property(parent.frontend_ip_configurations, frontend_ip_name)

    if enable_tcp_reset is not None:
        instance.enable_tcp_reset = enable_tcp_reset

    changes = (
        ('protocol', protocol),
        ('frontend_port', frontend_port),
        ('backend_port', backend_port),
        ('idle_timeout_in_minutes', idle_timeout),
        ('enable_floating_ip', floating_ip),
    )
    with cmd.update_context(instance) as ctx:
        for prop, value in changes:
            ctx.set_param(prop, value)

    return parent
3577
3578
def create_lb_inbound_nat_pool(
        cmd, resource_group_name, load_balancer_name, item_name, protocol, frontend_port_range_start,
        frontend_port_range_end, backend_port, frontend_ip_name=None, enable_tcp_reset=None,
        floating_ip=None, idle_timeout=None):
    """Add (or replace by name) an inbound NAT pool on a load balancer and return the saved pool.

    When ``frontend_ip_name`` is not given, the load balancer's default frontend IP configuration is used.
    """
    InboundNatPool = cmd.get_models('InboundNatPool')
    lb_client = network_client_factory(cmd.cli_ctx).load_balancers
    lb = lb_get(lb_client, resource_group_name, load_balancer_name)

    frontend_ip_name = frontend_ip_name or _get_default_name(lb, 'frontend_ip_configurations', '--frontend-ip-name')
    frontend_ip = None
    if frontend_ip_name:
        frontend_ip = get_property(lb.frontend_ip_configurations, frontend_ip_name)

    pool = InboundNatPool(name=item_name,
                          protocol=protocol,
                          frontend_ip_configuration=frontend_ip,
                          frontend_port_range_start=frontend_port_range_start,
                          frontend_port_range_end=frontend_port_range_end,
                          backend_port=backend_port,
                          enable_tcp_reset=enable_tcp_reset,
                          enable_floating_ip=floating_ip,
                          idle_timeout_in_minutes=idle_timeout)
    upsert_to_collection(lb, 'inbound_nat_pools', pool, 'name')

    saved_lb = lb_client.begin_create_or_update(resource_group_name, load_balancer_name, lb).result()
    return get_property(saved_lb.inbound_nat_pools, item_name)
3603
3604
def set_lb_inbound_nat_pool(
        cmd, instance, parent, item_name, protocol=None,
        frontend_port_range_start=None, frontend_port_range_end=None, backend_port=None,
        frontend_ip_name=None, enable_tcp_reset=None, floating_ip=None, idle_timeout=None):
    """Apply the supplied changes to an inbound NAT pool and return the owning load balancer.

    Passing ``frontend_ip_name=''`` detaches the pool from its frontend IP configuration.
    """
    changes = (
        ('protocol', protocol),
        ('frontend_port_range_start', frontend_port_range_start),
        ('frontend_port_range_end', frontend_port_range_end),
        ('backend_port', backend_port),
        ('enable_floating_ip', floating_ip),
        ('idle_timeout_in_minutes', idle_timeout),
    )
    with cmd.update_context(instance) as ctx:
        for prop, value in changes:
            ctx.set_param(prop, value)

    if enable_tcp_reset is not None:
        instance.enable_tcp_reset = enable_tcp_reset

    if frontend_ip_name is not None:
        if frontend_ip_name == '':
            instance.frontend_ip_configuration = None
        else:
            instance.frontend_ip_configuration = get_property(parent.frontend_ip_configurations, frontend_ip_name)

    return parent
3627
3628
def create_lb_frontend_ip_configuration(
        cmd, resource_group_name, load_balancer_name, item_name, public_ip_address=None,
        public_ip_prefix=None, subnet=None, virtual_network_name=None, private_ip_address=None,
        private_ip_address_version=None, private_ip_address_allocation=None, zone=None):
    """Add (or replace by name) a frontend IP configuration on a load balancer and return the saved config.

    Without an explicit allocation method, allocation is 'static' when a private IP address is
    given and 'dynamic' otherwise. ``zone`` is only applied for API versions 2017-06-01 and later.
    """
    FrontendIPConfiguration, SubResource, Subnet = cmd.get_models(
        'FrontendIPConfiguration', 'SubResource', 'Subnet')
    lb_client = network_client_factory(cmd.cli_ctx).load_balancers
    lb = lb_get(lb_client, resource_group_name, load_balancer_name)

    allocation = private_ip_address_allocation
    if allocation is None:
        allocation = 'static' if private_ip_address else 'dynamic'

    config = FrontendIPConfiguration(
        name=item_name,
        private_ip_address=private_ip_address,
        private_ip_address_version=private_ip_address_version,
        private_ip_allocation_method=allocation,
        public_ip_address=None if not public_ip_address else SubResource(id=public_ip_address),
        public_ip_prefix=None if not public_ip_prefix else SubResource(id=public_ip_prefix),
        subnet=None if not subnet else Subnet(id=subnet))

    if zone and cmd.supported_api_version(min_api='2017-06-01'):
        config.zones = zone

    upsert_to_collection(lb, 'frontend_ip_configurations', config, 'name')
    saved_lb = lb_client.begin_create_or_update(resource_group_name, load_balancer_name, lb).result()
    return get_property(saved_lb.frontend_ip_configurations, item_name)
3656
3657
def update_lb_frontend_ip_configuration_setter(cmd, resource_group_name, load_balancer_name, parameters, gateway_lb):
    """Persist a load balancer whose frontend IP configuration was updated.

    When ``gateway_lb`` is a valid resource ID, its subscription is passed to the
    client factory as an auxiliary subscription (presumably so a gateway load
    balancer in another subscription can be referenced).

    :returns: the load balancer create-or-update poller.
    """
    aux_subscriptions = (
        [parse_resource_id(gateway_lb)['subscription']] if is_valid_resource_id(gateway_lb) else [])
    lb_client = network_client_factory(cmd.cli_ctx, aux_subscriptions=aux_subscriptions).load_balancers
    return lb_client.begin_create_or_update(resource_group_name, load_balancer_name, parameters)
3664
3665
def set_lb_frontend_ip_configuration(
        cmd, instance, parent, item_name, private_ip_address=None,
        private_ip_address_allocation=None, public_ip_address=None,
        subnet=None, virtual_network_name=None, public_ip_prefix=None, gateway_lb=None):
    """Apply user-supplied changes to a load balancer frontend IP configuration.

    For most arguments ``None`` means "leave unchanged" and ``''`` means "clear".

    :param instance: the frontend IP configuration being updated (mutated in place).
    :param parent: the load balancer that owns ``instance``.
    :param private_ip_address: a new private IP switches allocation to static;
        ``''`` removes the address and switches allocation to dynamic.
    :param public_ip_address: public IP resource ID, or ``''`` to detach it.
    :param subnet: subnet resource ID, or ``''`` to detach it.
    :param public_ip_prefix: public IP prefix resource ID (cannot be cleared here).
    :param gateway_lb: resource ID for ``gateway_load_balancer``, or ``''`` to detach it.
    :returns: ``parent``.

    ``private_ip_address_allocation`` and ``virtual_network_name`` are accepted for
    interface compatibility but are not read here.
    """
    PublicIPAddress, Subnet, SubResource = cmd.get_models('PublicIPAddress', 'Subnet', 'SubResource')

    # Only touch the private IP when the argument was actually supplied: treating None
    # like '' would silently reset an existing static private IP to dynamic whenever an
    # unrelated property (subnet, public IP, ...) is updated.
    if private_ip_address == '':
        instance.private_ip_allocation_method = 'dynamic'
        instance.private_ip_address = None
    elif private_ip_address is not None:
        instance.private_ip_allocation_method = 'static'
        instance.private_ip_address = private_ip_address

    # private_ip_address_version is intentionally not updated: the update operation
    # does not support changing it for now.

    if subnet == '':
        instance.subnet = None
    elif subnet is not None:
        instance.subnet = Subnet(id=subnet)

    if public_ip_address == '':
        instance.public_ip_address = None
    elif public_ip_address is not None:
        instance.public_ip_address = PublicIPAddress(id=public_ip_address)

    if public_ip_prefix:
        instance.public_ip_prefix = SubResource(id=public_ip_prefix)

    if gateway_lb is not None:
        instance.gateway_load_balancer = None if gateway_lb == '' else SubResource(id=gateway_lb)

    return parent
3698
3699
3700def _process_vnet_name_and_id(vnet, cmd, resource_group_name):
3701    if vnet and not is_valid_resource_id(vnet):
3702        vnet = resource_id(
3703            subscription=get_subscription_id(cmd.cli_ctx),
3704            resource_group=resource_group_name,
3705            namespace='Microsoft.Network',
3706            type='virtualNetworks',
3707            name=vnet)
3708    return vnet
3709
3710
3711def _process_subnet_name_and_id(subnet, vnet, cmd, resource_group_name):
3712    if subnet and not is_valid_resource_id(subnet):
3713        vnet = _process_vnet_name_and_id(vnet, cmd, resource_group_name)
3714        if vnet is None:
3715            raise UnrecognizedArgumentError('vnet should be provided when input subnet name instead of subnet id')
3716
3717        subnet = vnet + f'/subnets/{subnet}'
3718    return subnet
3719
3720
# pylint: disable=too-many-branches
def create_lb_backend_address_pool(cmd, resource_group_name, load_balancer_name, backend_address_pool_name,
                                   vnet=None, backend_addresses=None, backend_addresses_config_file=None):
    """Create a backend address pool on a load balancer, optionally pre-populated with addresses.

    Addresses come from ``backend_addresses`` or from ``backend_addresses_config_file``
    (which must be a list of dicts), never both. Each address dict is read for the keys
    ``name`` and ``ip_address`` plus, depending on the API version, ``virtual_network``
    and/or ``subnet``. ``vnet`` fills in ``virtual_network`` for addresses lacking it.

    Two write paths exist:

    * API profiles matching ``max_api='2020-03-01'``, or Basic-SKU load balancers: the
      pool is upserted into the load balancer and the whole load balancer is PUT.
      Supplied addresses are NOT attached to the pool in this path.
    * Otherwise: the pool (with addresses) is written through the dedicated
      backend-address-pool create-or-update operation.

    :returns: the created pool (legacy path) or the backend address pool poller.
    :raises CLIError: both address sources are given, or the config file is not a list of dicts.
    :raises UnrecognizedArgumentError: an address is missing required information.
    """
    if backend_addresses and backend_addresses_config_file:
        raise CLIError('usage error: Only one of --backend-address and --backend-addresses-config-file can be provided at the same time.')  # pylint: disable=line-too-long
    if backend_addresses_config_file:
        if not isinstance(backend_addresses_config_file, list):
            raise CLIError('Config file must be a list. Please see example as a reference.')
        for addr in backend_addresses_config_file:
            if not isinstance(addr, dict):
                raise CLIError('Each address in config file must be a dictionary. Please see example as a reference.')
    ncf = network_client_factory(cmd.cli_ctx)
    lb = lb_get(ncf.load_balancers, resource_group_name, load_balancer_name)
    (BackendAddressPool,
     LoadBalancerBackendAddress,
     Subnet,
     VirtualNetwork) = cmd.get_models('BackendAddressPool',
                                      'LoadBalancerBackendAddress',
                                      'Subnet',
                                      'VirtualNetwork')
    # For API profiles up to 2020-03-01 the service does not support the dedicated
    # backend-address-pool REST operation, so the legacy whole-load-balancer update is
    # kept for backward compatibility. The service likewise refuses that operation for
    # Basic-SKU load balancers.
    if cmd.supported_api_version(max_api='2020-03-01') or lb.sku.name.lower() == 'basic':
        new_pool = BackendAddressPool(name=backend_address_pool_name)
        upsert_to_collection(lb, 'backend_address_pools', new_pool, 'name')
        poller = ncf.load_balancers.begin_create_or_update(resource_group_name, load_balancer_name, lb)
        return get_property(poller.result().backend_address_pools, backend_address_pool_name)

    addresses_pool = []
    if backend_addresses:
        addresses_pool.extend(backend_addresses)
    if backend_addresses_config_file:
        addresses_pool.extend(backend_addresses_config_file)
    # NOTE: this writes the default vnet into the caller-supplied address dicts in place.
    for addr in addresses_pool:
        if 'virtual_network' not in addr and vnet:
            addr['virtual_network'] = vnet

    # pylint: disable=line-too-long
    if cmd.supported_api_version(min_api='2020-11-01'):  # pylint: disable=too-many-nested-blocks
        # From 2020-11-01 an address may be located by vnet (+ optional subnet) or by subnet ID alone.
        try:
            if addresses_pool:
                new_addresses = []
                for addr in addresses_pool:
                    # vnet      | subnet        |  status
                    # name/id   | name/id/null  |    ok
                    # null      | id            |    ok
                    if 'virtual_network' in addr:
                        address = LoadBalancerBackendAddress(name=addr['name'],
                                                             virtual_network=VirtualNetwork(id=_process_vnet_name_and_id(addr['virtual_network'], cmd, resource_group_name)),
                                                             subnet=Subnet(id=_process_subnet_name_and_id(addr['subnet'], addr['virtual_network'], cmd, resource_group_name)) if 'subnet' in addr else None,
                                                             ip_address=addr['ip_address'])
                    elif 'subnet' in addr and is_valid_resource_id(addr['subnet']):
                        address = LoadBalancerBackendAddress(name=addr['name'],
                                                             subnet=Subnet(id=addr['subnet']),
                                                             ip_address=addr['ip_address'])
                    else:
                        # Neither a vnet nor a subnet ID: reuse the KeyError handler below
                        # to report the missing location information.
                        raise KeyError

                    new_addresses.append(address)
            else:
                new_addresses = None
        except KeyError:
            raise UnrecognizedArgumentError('Each backend address must have name, ip-address, (vnet name and subnet '
                                            'name | subnet id) information.')
    else:
        # Older API versions only support locating an address by virtual network.
        try:
            new_addresses = [LoadBalancerBackendAddress(name=addr['name'],
                                                        virtual_network=VirtualNetwork(id=_process_vnet_name_and_id(addr['virtual_network'], cmd, resource_group_name)),
                                                        ip_address=addr['ip_address']) for addr in addresses_pool] if addresses_pool else None
        except KeyError:
            raise UnrecognizedArgumentError('Each backend address must have name, vnet and ip-address information.')

    new_pool = BackendAddressPool(name=backend_address_pool_name,
                                  load_balancer_backend_addresses=new_addresses)

    # when sku is 'gateway', 'tunnelInterfaces' can't be None. Otherwise service will response error
    if cmd.supported_api_version(min_api='2021-02-01') and lb.sku.name.lower() == 'gateway':
        GatewayLoadBalancerTunnelInterface = cmd.get_models('GatewayLoadBalancerTunnelInterface')
        new_pool.tunnel_interfaces = [
            GatewayLoadBalancerTunnelInterface(type='Internal', protocol='VXLAN', identifier=900)]
    return ncf.load_balancer_backend_address_pools.begin_create_or_update(resource_group_name,
                                                                          load_balancer_name,
                                                                          backend_address_pool_name,
                                                                          new_pool)
3806
3807
def delete_lb_backend_address_pool(cmd, resource_group_name, load_balancer_name, backend_address_pool_name):
    """Delete a backend address pool from a load balancer.

    Basic-SKU load balancers are updated by writing the whole load balancer back
    without the pool (names compared case-insensitively), then checking that the pool
    is gone. Other SKUs use the dedicated backend-address-pool delete operation.

    :returns: ``None`` for Basic SKU, otherwise the delete poller.
    :raises CLIError: the pool is still present after the Basic-SKU update.
    """
    from azure.cli.core.commands import LongRunningOperation
    ncf = network_client_factory(cmd.cli_ctx)
    lb = lb_get(ncf.load_balancers, resource_group_name, load_balancer_name)

    if lb.sku.name.lower() != 'basic':
        return ncf.load_balancer_backend_address_pools.begin_delete(resource_group_name,
                                                                    load_balancer_name,
                                                                    backend_address_pool_name)

    def _is_target(pool):
        return pool.name.lower() == backend_address_pool_name.lower()

    lb.backend_address_pools = [pool for pool in lb.backend_address_pools if not _is_target(pool)]
    poller = ncf.load_balancers.begin_create_or_update(resource_group_name, load_balancer_name, lb)
    remaining_pools = LongRunningOperation(cmd.cli_ctx)(poller).backend_address_pools
    if any(_is_target(pool) for pool in remaining_pools):
        raise CLIError("Failed to delete '{}' on '{}'".format(backend_address_pool_name, load_balancer_name))
    return None
3829
3830
# region cross-region lb
def create_cross_region_load_balancer(cmd, load_balancer_name, resource_group_name, location=None, tags=None,
                                      backend_pool_name=None, frontend_ip_name='LoadBalancerFrontEnd',
                                      public_ip_address=None, public_ip_address_allocation=None,
                                      public_ip_dns_name=None, public_ip_address_type=None, validate=False,
                                      no_wait=False, frontend_ip_zone=None, public_ip_zone=None):
    """Create a cross-region load balancer (Standard SKU, Global tier) through an ARM template deployment.

    The template always contains the load balancer and, when ``public_ip_address_type``
    is ``'new'``, a new public IP the load balancer depends on. Unspecified names default
    to ``PublicIP<load_balancer_name>`` and ``<load_balancer_name>bepool``.

    :param validate: log and validate the template instead of deploying it.
    :param no_wait: forwarded to ``sdk_no_wait`` for the deployment call.
    :returns: the validation result when ``validate`` is set, otherwise the result of
        the deployment create-or-update call.
    """
    from azure.cli.core.util import random_string
    from azure.cli.core.commands.arm import ArmTemplateBuilder
    from azure.cli.command_modules.network._template_builder import (
        build_load_balancer_resource, build_public_ip_resource)

    DeploymentProperties = cmd.get_models('DeploymentProperties', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES)
    IPAllocationMethod = cmd.get_models('IPAllocationMethod')

    # Cross-region load balancers are always created as Standard SKU in the Global tier.
    sku = 'standard'
    tier = 'Global'

    tags = tags or {}
    public_ip_address = public_ip_address or 'PublicIP{}'.format(load_balancer_name)
    backend_pool_name = backend_pool_name or '{}bepool'.format(load_balancer_name)
    # sku is fixed to 'standard' above, so an unspecified allocation always resolves to static.
    if not public_ip_address_allocation:
        public_ip_address_allocation = IPAllocationMethod.static.value if (sku and sku.lower() == 'standard') \
            else IPAllocationMethod.dynamic.value

    # Build up the ARM template
    master_template = ArmTemplateBuilder()
    lb_dependencies = []

    # A public IP supplied as a full resource ID is referenced as-is.
    public_ip_id = public_ip_address if is_valid_resource_id(public_ip_address) else None

    # Prefix for Microsoft.Network resource IDs in the target resource group.
    network_id_template = resource_id(
        subscription=get_subscription_id(cmd.cli_ctx), resource_group=resource_group_name,
        namespace='Microsoft.Network')

    # A new public IP is added to the same template; the load balancer depends on it.
    if public_ip_address_type == 'new':
        lb_dependencies.append('Microsoft.Network/publicIpAddresses/{}'.format(public_ip_address))
        master_template.add_resource(build_public_ip_resource(cmd, public_ip_address, location,
                                                              tags,
                                                              public_ip_address_allocation,
                                                              public_ip_dns_name,
                                                              sku, public_ip_zone, tier))
        public_ip_id = '{}/publicIPAddresses/{}'.format(network_id_template,
                                                        public_ip_address)

    load_balancer_resource = build_load_balancer_resource(
        cmd, load_balancer_name, location, tags, backend_pool_name, frontend_ip_name,
        public_ip_id, None, None, None, sku, frontend_ip_zone, None, tier)
    load_balancer_resource['dependsOn'] = lb_dependencies
    master_template.add_resource(load_balancer_resource)
    master_template.add_output('loadBalancer', load_balancer_name, output_type='object')

    template = master_template.build()

    # deploy ARM template
    deployment_name = 'lb_deploy_' + random_string(32)
    client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES).deployments
    properties = DeploymentProperties(template=template, parameters={}, mode='incremental')
    Deployment = cmd.get_models('Deployment', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES)
    deployment = Deployment(properties=properties)

    if validate:
        _log_pprint_template(template)
        # From resource API 2019-10-01 validation is a long-running operation that must be polled.
        if cmd.supported_api_version(min_api='2019-10-01', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES):
            from azure.cli.core.commands import LongRunningOperation
            validation_poller = client.begin_validate(resource_group_name, deployment_name, deployment)
            return LongRunningOperation(cmd.cli_ctx)(validation_poller)

        return client.validate(resource_group_name, deployment_name, deployment)

    return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, deployment_name, deployment)
3901
3902
def create_cross_region_lb_frontend_ip_configuration(
        cmd, resource_group_name, load_balancer_name, item_name, public_ip_address=None,
        public_ip_prefix=None, zone=None):
    """Add (or replace, by name) a public frontend IP configuration on a cross-region load balancer.

    The load balancer is fetched, the configuration upserted into
    ``frontend_ip_configurations`` and the whole load balancer written back.

    :returns: the frontend IP configuration named ``item_name`` from the PUT result.
    """
    FrontendIPConfiguration, SubResource = cmd.get_models('FrontendIPConfiguration', 'SubResource')
    lb_client = network_client_factory(cmd.cli_ctx).load_balancers
    lb = lb_get(lb_client, resource_group_name, load_balancer_name)

    public_ip_ref = SubResource(id=public_ip_address) if public_ip_address else None
    prefix_ref = SubResource(id=public_ip_prefix) if public_ip_prefix else None
    frontend = FrontendIPConfiguration(name=item_name, public_ip_address=public_ip_ref, public_ip_prefix=prefix_ref)

    # Zones are only set when the active API profile is 2017-06-01 or newer.
    if zone and cmd.supported_api_version(min_api='2017-06-01'):
        frontend.zones = zone

    upsert_to_collection(lb, 'frontend_ip_configurations', frontend, 'name')
    poller = lb_client.begin_create_or_update(resource_group_name, load_balancer_name, lb)
    return get_property(poller.result().frontend_ip_configurations, item_name)
3922
3923
def set_cross_region_lb_frontend_ip_configuration(
        cmd, instance, parent, item_name, public_ip_address=None, public_ip_prefix=None):
    """Apply user-supplied changes to a cross-region frontend IP configuration.

    :param public_ip_address: public IP resource ID, ``''`` to detach, ``None`` to keep.
    :param public_ip_prefix: public IP prefix resource ID; falsy values leave it unchanged.
    :returns: ``parent``.
    """
    PublicIPAddress, SubResource = cmd.get_models('PublicIPAddress', 'SubResource')

    if public_ip_address is not None:
        instance.public_ip_address = None if public_ip_address == '' else PublicIPAddress(id=public_ip_address)

    if public_ip_prefix:
        instance.public_ip_prefix = SubResource(id=public_ip_prefix)

    return parent
3937
3938
def create_cross_region_lb_backend_address_pool(cmd, resource_group_name, load_balancer_name, backend_address_pool_name,
                                                backend_addresses=None, backend_addresses_config_file=None):
    """Create a backend address pool on a cross-region load balancer.

    Addresses come from ``backend_addresses`` or ``backend_addresses_config_file``
    (a list of dicts), never both. Each address dict must provide ``name`` and
    ``frontend_ip_address`` (the ID of a regional load balancer frontend IP configuration).

    :returns: the backend address pool create-or-update poller.
    :raises CLIError: both address sources are given, the config file is not a list of
        dicts, or an address lacks ``name``/``frontend_ip_address``.
    """
    if backend_addresses and backend_addresses_config_file:
        raise CLIError('usage error: Only one of --backend-address and --backend-addresses-config-file can be provided at the same time.')  # pylint: disable=line-too-long
    if backend_addresses_config_file:
        if not isinstance(backend_addresses_config_file, list):
            raise CLIError('Config file must be a list. Please see example as a reference.')
        if not all(isinstance(addr, dict) for addr in backend_addresses_config_file):
            raise CLIError('Each address in config file must be a dictionary. Please see example as a reference.')
    ncf = network_client_factory(cmd.cli_ctx)
    (BackendAddressPool,
     LoadBalancerBackendAddress,
     FrontendIPConfiguration) = cmd.get_models('BackendAddressPool',
                                               'LoadBalancerBackendAddress',
                                               'FrontendIPConfiguration')

    addresses_pool = []
    if backend_addresses:
        addresses_pool.extend(backend_addresses)
    if backend_addresses_config_file:
        addresses_pool.extend(backend_addresses_config_file)

    # pylint: disable=line-too-long
    try:
        new_addresses = [LoadBalancerBackendAddress(name=addr['name'],
                                                    load_balancer_frontend_ip_configuration=FrontendIPConfiguration(id=addr['frontend_ip_address']))
                         for addr in addresses_pool] if addresses_pool else None
    except KeyError as err:
        # The message names the key actually read above (frontend_ip_address), not
        # frontend_ip_configuration, so users know which field is missing.
        raise CLIError('Each backend address must have name and frontend-ip-address information.') from err
    new_pool = BackendAddressPool(name=backend_address_pool_name,
                                  load_balancer_backend_addresses=new_addresses)
    return ncf.load_balancer_backend_address_pools.begin_create_or_update(resource_group_name,
                                                                          load_balancer_name,
                                                                          backend_address_pool_name,
                                                                          new_pool)
3974
3975
def delete_cross_region_lb_backend_address_pool(cmd, resource_group_name, load_balancer_name, backend_address_pool_name):  # pylint: disable=line-too-long
    """Delete a backend address pool from a cross-region load balancer.

    :returns: the backend address pool delete poller.
    """
    pools_client = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    return pools_client.begin_delete(resource_group_name, load_balancer_name, backend_address_pool_name)
3982
3983
def add_cross_region_lb_backend_address_pool_address(cmd, resource_group_name, load_balancer_name,
                                                     backend_address_pool_name, address_name, frontend_ip_address):
    """Append a frontend-IP-based backend address to a cross-region backend address pool.

    The pool is read, the new address is appended (creating the list if the pool has
    none yet), and the whole pool is written back.

    :returns: the backend address pool create-or-update poller.
    """
    pools_client = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    address_pool = pools_client.get(resource_group_name, load_balancer_name, backend_address_pool_name)
    LoadBalancerBackendAddress, FrontendIPConfiguration = cmd.get_models('LoadBalancerBackendAddress',
                                                                         'FrontendIPConfiguration')
    frontend_ref = FrontendIPConfiguration(id=frontend_ip_address) if frontend_ip_address else None
    new_address = LoadBalancerBackendAddress(name=address_name,
                                             load_balancer_frontend_ip_configuration=frontend_ref)

    existing_addresses = address_pool.load_balancer_backend_addresses
    if existing_addresses is None:
        existing_addresses = address_pool.load_balancer_backend_addresses = []
    existing_addresses.append(new_address)
    return pools_client.begin_create_or_update(resource_group_name, load_balancer_name,
                                               backend_address_pool_name, address_pool)
3997
3998
def create_cross_region_lb_rule(
        cmd, resource_group_name, load_balancer_name, item_name,
        protocol, frontend_port, backend_port, frontend_ip_name=None,
        backend_address_pool_name=None, probe_name=None, load_distribution='default',
        floating_ip=None, idle_timeout=None, enable_tcp_reset=None, backend_pools_name=None):
    """Create (or replace, by name) a load-balancing rule on a cross-region load balancer.

    Missing frontend IP / backend pool names are filled in by ``_get_default_name``.
    When only ``backend_pools_name`` is given, its first entry becomes the primary
    ``backend_address_pool`` (as ``create_lb_rule`` does) instead of falling back to
    ``_get_default_name``, which may demand ``--backend-pool-name`` when several pools exist.

    :returns: the rule named ``item_name`` from the PUT result.
    """
    LoadBalancingRule = cmd.get_models('LoadBalancingRule')
    ncf = network_client_factory(cmd.cli_ctx)
    lb = cached_get(cmd, ncf.load_balancers.get, resource_group_name, load_balancer_name)
    lb = lb_get_operation(lb)
    if not frontend_ip_name:
        frontend_ip_name = _get_default_name(lb, 'frontend_ip_configurations', '--frontend-ip-name')
    # Avoid breaking when backend_address_pool_name is None but backend_pools_name is given.
    if not backend_address_pool_name and backend_pools_name:
        backend_address_pool_name = backend_pools_name[0]
    if not backend_address_pool_name:
        backend_address_pool_name = _get_default_name(lb, 'backend_address_pools', '--backend-pool-name')
    new_rule = LoadBalancingRule(
        name=item_name,
        protocol=protocol,
        frontend_port=frontend_port,
        backend_port=backend_port,
        frontend_ip_configuration=get_property(lb.frontend_ip_configurations,
                                               frontend_ip_name),
        backend_address_pool=get_property(lb.backend_address_pools,
                                          backend_address_pool_name),
        probe=get_property(lb.probes, probe_name) if probe_name else None,
        load_distribution=load_distribution,
        enable_floating_ip=floating_ip,
        idle_timeout_in_minutes=idle_timeout,
        enable_tcp_reset=enable_tcp_reset)
    if backend_pools_name:
        new_rule.backend_address_pools = [get_property(lb.backend_address_pools, i) for i in backend_pools_name]
    upsert_to_collection(lb, 'load_balancing_rules', new_rule, 'name')
    poller = cached_put(cmd, ncf.load_balancers.begin_create_or_update, lb, resource_group_name, load_balancer_name)
    return get_property(poller.result().load_balancing_rules, item_name)
4031
4032
def set_cross_region_lb_rule(
        cmd, instance, parent, item_name, protocol=None, frontend_port=None,
        frontend_ip_name=None, backend_port=None, backend_address_pool_name=None, probe_name=None,
        load_distribution=None, floating_ip=None, idle_timeout=None, enable_tcp_reset=None, backend_pools_name=None):
    """Apply user-supplied changes to a load-balancing rule of a cross-region load balancer.

    Scalar settings go through ``cmd.update_context(...).set_param``. Named references
    (frontend IP, backend pools, probe) are resolved against ``parent``; an empty
    ``probe_name`` detaches the probe.

    :returns: ``parent``.
    """
    simple_settings = [
        ('protocol', protocol),
        ('frontend_port', frontend_port),
        ('backend_port', backend_port),
        ('idle_timeout_in_minutes', idle_timeout),
        ('load_distribution', load_distribution),
        ('enable_tcp_reset', enable_tcp_reset),
        ('enable_floating_ip', floating_ip),
    ]
    with cmd.update_context(instance) as updater:
        for prop, value in simple_settings:
            updater.set_param(prop, value)

    if frontend_ip_name is not None:
        instance.frontend_ip_configuration = get_property(parent.frontend_ip_configurations, frontend_ip_name)

    if backend_address_pool_name is not None:
        instance.backend_address_pool = get_property(parent.backend_address_pools, backend_address_pool_name)
        # Keep the plural pool list in sync when moving from API '2020-11-01' to '2021-02-01'
        # unless an explicit list was given: https://github.com/Azure/azure-rest-api-specs/issues/14430
        if cmd.supported_api_version(min_api='2021-02-01') and not backend_pools_name:
            instance.backend_address_pools = [instance.backend_address_pool]

    if backend_pools_name is not None:
        instance.backend_address_pools = [get_property(parent.backend_address_pools, pool_name)
                                          for pool_name in backend_pools_name]

    if probe_name is not None:
        instance.probe = None if probe_name == '' else get_property(parent.probes, probe_name)

    return parent
# endregion
4067
4068
# pylint: disable=line-too-long
def add_lb_backend_address_pool_address(cmd, resource_group_name, load_balancer_name, backend_address_pool_name,
                                        address_name, ip_address, vnet=None, subnet=None):
    """Append a backend address to an existing backend address pool.

    With API 2020-11-01 or newer, an address is located either by a virtual network
    (name or ID, optionally with a subnet name/ID) or by a subnet resource ID alone;
    older APIs ignore ``subnet``. A vnet given by name is expanded to a resource ID in
    ``resource_group_name`` (as ``create_lb_backend_address_pool`` does), so the model
    never receives a bare name as its ``id``.

    :returns: the backend address pool create-or-update poller.
    :raises UnrecognizedArgumentError: on API >= 2020-11-01 when neither a vnet nor a
        subnet resource ID is given.
    """
    client = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    address_pool = client.get(resource_group_name, load_balancer_name, backend_address_pool_name)
    (LoadBalancerBackendAddress,
     Subnet,
     VirtualNetwork) = cmd.get_models('LoadBalancerBackendAddress',
                                      'Subnet',
                                      'VirtualNetwork')
    # Names are expanded to IDs; IDs and falsy values pass through unchanged.
    vnet_id = _process_vnet_name_and_id(vnet, cmd, resource_group_name)
    if cmd.supported_api_version(min_api='2020-11-01'):
        if vnet:
            new_address = LoadBalancerBackendAddress(name=address_name,
                                                     subnet=Subnet(id=_process_subnet_name_and_id(subnet, vnet, cmd, resource_group_name)) if subnet else None,
                                                     virtual_network=VirtualNetwork(id=vnet_id),
                                                     ip_address=ip_address if ip_address else None)
        elif is_valid_resource_id(subnet):
            new_address = LoadBalancerBackendAddress(name=address_name,
                                                     subnet=Subnet(id=subnet),
                                                     ip_address=ip_address if ip_address else None)
        else:
            raise UnrecognizedArgumentError('Each backend address must have name, ip-address, (vnet name and subnet name | subnet id) information.')

    else:
        new_address = LoadBalancerBackendAddress(name=address_name,
                                                 virtual_network=VirtualNetwork(id=vnet_id) if vnet_id else None,
                                                 ip_address=ip_address if ip_address else None)
    if address_pool.load_balancer_backend_addresses is None:
        address_pool.load_balancer_backend_addresses = []
    address_pool.load_balancer_backend_addresses.append(new_address)
    return client.begin_create_or_update(resource_group_name, load_balancer_name,
                                         backend_address_pool_name, address_pool)
4101
4102
def remove_lb_backend_address_pool_address(cmd, resource_group_name, load_balancer_name,
                                           backend_address_pool_name, address_name):
    """Remove every backend address named ``address_name`` from a backend address pool.

    The pool is written back even if no address matched.

    :returns: the backend address pool create-or-update poller.
    """
    client = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    address_pool = client.get(resource_group_name, load_balancer_name, backend_address_pool_name)
    current_addresses = address_pool.load_balancer_backend_addresses or []
    address_pool.load_balancer_backend_addresses = [
        entry for entry in current_addresses if entry.name != address_name]
    return client.begin_create_or_update(resource_group_name, load_balancer_name,
                                         backend_address_pool_name, address_pool)
4113
4114
def list_lb_backend_address_pool_address(cmd, resource_group_name, load_balancer_name,
                                         backend_address_pool_name):
    """Return the backend addresses of a backend address pool (may be ``None``)."""
    pools_client = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    pool = pools_client.get(resource_group_name, load_balancer_name, backend_address_pool_name)
    return pool.load_balancer_backend_addresses
4120
4121
def create_lb_outbound_rule(cmd, resource_group_name, load_balancer_name, item_name,
                            backend_address_pool, frontend_ip_configurations, protocol,
                            outbound_ports=None, enable_tcp_reset=None, idle_timeout=None):
    """Add (or replace, by name) an outbound rule on a load balancer.

    ``backend_address_pool`` and each entry of ``frontend_ip_configurations`` are
    wrapped as ``SubResource`` references by ID.

    :returns: the outbound rule named ``item_name`` from the PUT result.
    """
    OutboundRule, SubResource = cmd.get_models('OutboundRule', 'SubResource')
    lb_client = network_client_factory(cmd.cli_ctx).load_balancers
    lb = lb_get(lb_client, resource_group_name, load_balancer_name)

    frontend_refs = None
    if frontend_ip_configurations:
        frontend_refs = [SubResource(id=config_id) for config_id in frontend_ip_configurations]

    outbound_rule = OutboundRule(
        name=item_name,
        protocol=protocol,
        enable_tcp_reset=enable_tcp_reset,
        idle_timeout_in_minutes=idle_timeout,
        backend_address_pool=SubResource(id=backend_address_pool),
        frontend_ip_configurations=frontend_refs,
        allocated_outbound_ports=outbound_ports)
    upsert_to_collection(lb, 'outbound_rules', outbound_rule, 'name')
    poller = lb_client.begin_create_or_update(resource_group_name, load_balancer_name, lb)
    return get_property(poller.result().outbound_rules, item_name)
4137
4138
def set_lb_outbound_rule(instance, cmd, parent, item_name, protocol=None, outbound_ports=None,
                         idle_timeout=None, frontend_ip_configurations=None, enable_tcp_reset=None,
                         backend_address_pool=None):
    """Apply user-supplied changes to a load balancer outbound rule.

    Resource IDs are wrapped as ``SubResource`` references; a falsy pool ID or
    frontend list is passed to ``set_param`` as ``None``. Note the unusual
    ``(instance, cmd, ...)`` parameter order, kept for callers.

    :returns: ``parent``.
    """
    SubResource = cmd.get_models('SubResource')
    pool_ref = SubResource(id=backend_address_pool) if backend_address_pool else None
    frontend_refs = ([SubResource(id=config_id) for config_id in frontend_ip_configurations]
                     if frontend_ip_configurations else None)
    updates = (
        ('protocol', protocol),
        ('allocated_outbound_ports', outbound_ports),
        ('idle_timeout_in_minutes', idle_timeout),
        ('enable_tcp_reset', enable_tcp_reset),
        ('backend_address_pool', pool_ref),
        ('frontend_ip_configurations', frontend_refs),
    )
    with cmd.update_context(instance) as ctx:
        for prop, value in updates:
            ctx.set_param(prop, value)
    return parent
4153
4154
def create_lb_probe(cmd, resource_group_name, load_balancer_name, item_name, protocol, port,
                    path=None, interval=None, threshold=None):
    """Add (or replace, by name) a health probe on a load balancer.

    :param path: request path (stored as ``request_path``).
    :param interval: seconds between probes (``interval_in_seconds``).
    :param threshold: number of probes (``number_of_probes``).
    :returns: the probe named ``item_name`` from the PUT result.
    """
    Probe = cmd.get_models('Probe')
    lb_client = network_client_factory(cmd.cli_ctx).load_balancers
    lb = lb_get(lb_client, resource_group_name, load_balancer_name)
    probe = Probe(name=item_name, protocol=protocol, port=port, request_path=path,
                  interval_in_seconds=interval, number_of_probes=threshold)
    upsert_to_collection(lb, 'probes', probe, 'name')
    poller = lb_client.begin_create_or_update(resource_group_name, load_balancer_name, lb)
    return get_property(poller.result().probes, item_name)
4166
4167
def set_lb_probe(cmd, instance, parent, item_name, protocol=None, port=None,
                 path=None, interval=None, threshold=None):
    """Apply user-supplied changes to a load balancer health probe.

    Each value is forwarded to ``cmd.update_context(...).set_param`` under the probe
    model's property name; ``item_name`` is not read.

    :returns: ``parent``.
    """
    changes = {
        'protocol': protocol,
        'port': port,
        'request_path': path,
        'interval_in_seconds': interval,
        'number_of_probes': threshold,
    }
    with cmd.update_context(instance) as ctx:
        for prop, value in changes.items():
            ctx.set_param(prop, value)
    return parent
4177
4178
def create_lb_rule(
        cmd, resource_group_name, load_balancer_name, item_name,
        protocol, frontend_port, backend_port, frontend_ip_name=None,
        backend_address_pool_name=None, probe_name=None, load_distribution='default',
        floating_ip=None, idle_timeout=None, enable_tcp_reset=None, disable_outbound_snat=None, backend_pools_name=None):
    """Create (or replace by name) a load-balancing rule and return the rule as stored by the service.

    When ``frontend_ip_name`` or ``backend_address_pool_name`` is omitted, a default is resolved
    from the load balancer with ``_get_default_name``. If ``backend_pools_name`` is given, its first
    entry stands in for a missing ``backend_address_pool_name``. The rule is then bound to the whole
    list through ``backend_address_pools``, and the single ``backend_address_pool`` field is cleared.
    """
    LoadBalancingRule = cmd.get_models('LoadBalancingRule')
    lb_ops = network_client_factory(cmd.cli_ctx).load_balancers
    lb = lb_get_operation(cached_get(cmd, lb_ops.get, resource_group_name, load_balancer_name))

    frontend_ip_name = frontend_ip_name or _get_default_name(
        lb, 'frontend_ip_configurations', '--frontend-ip-name')
    # Fall back to the first plural pool so the singular lookup below has a name to resolve.
    if backend_pools_name and not backend_address_pool_name:
        backend_address_pool_name = backend_pools_name[0]
    backend_address_pool_name = backend_address_pool_name or _get_default_name(
        lb, 'backend_address_pools', '--backend-pool-name')

    rule = LoadBalancingRule(
        name=item_name,
        protocol=protocol,
        frontend_port=frontend_port,
        backend_port=backend_port,
        frontend_ip_configuration=get_property(lb.frontend_ip_configurations, frontend_ip_name),
        backend_address_pool=get_property(lb.backend_address_pools, backend_address_pool_name),
        probe=get_property(lb.probes, probe_name) if probe_name else None,
        load_distribution=load_distribution,
        enable_floating_ip=floating_ip,
        idle_timeout_in_minutes=idle_timeout,
        enable_tcp_reset=enable_tcp_reset,
        disable_outbound_snat=disable_outbound_snat)

    if backend_pools_name:
        rule.backend_address_pools = [get_property(lb.backend_address_pools, pool_name)
                                      for pool_name in backend_pools_name]
        # Otherwise the service responds with: (LoadBalancingRuleBackendAdressPoolAndBackendAddressPoolsCannotBeSetAtTheSameTimeWithDifferentValue)
        # BackendAddressPool and BackendAddressPools[] cannot be set at the same time with different values.
        rule.backend_address_pool = None

    upsert_to_collection(lb, 'load_balancing_rules', rule, 'name')
    poller = cached_put(cmd, lb_ops.begin_create_or_update, lb, resource_group_name, load_balancer_name)
    return get_property(poller.result().load_balancing_rules, item_name)
4219
4220
def set_lb_rule(
        cmd, instance, parent, item_name, protocol=None, frontend_port=None,
        frontend_ip_name=None, backend_port=None, backend_address_pool_name=None, probe_name=None,
        load_distribution='default', floating_ip=None, idle_timeout=None, enable_tcp_reset=None,
        disable_outbound_snat=None, backend_pools_name=None):
    """Apply updates to the load-balancing rule ``instance`` and return the parent load balancer.

    Scalar settings go through ``cmd.update_context``. Frontend, backend-pool and probe names are
    resolved against ``parent``. An empty-string ``probe_name`` detaches the probe.
    """
    # NOTE(review): load_distribution defaults to 'default' and is passed to set_param like the
    # other settings, so an update that omits it may reset the rule's load distribution.
    # Confirm whether the command registration overrides this default.
    scalar_updates = (
        ('protocol', protocol),
        ('frontend_port', frontend_port),
        ('backend_port', backend_port),
        ('idle_timeout_in_minutes', idle_timeout),
        ('load_distribution', load_distribution),
        ('disable_outbound_snat', disable_outbound_snat),
        ('enable_tcp_reset', enable_tcp_reset),
        ('enable_floating_ip', floating_ip),
    )
    with cmd.update_context(instance) as ctx:
        for attr, value in scalar_updates:
            ctx.set_param(attr, value)

    if frontend_ip_name is not None:
        instance.frontend_ip_configuration = get_property(parent.frontend_ip_configurations, frontend_ip_name)

    if backend_address_pool_name is not None:
        pool = get_property(parent.backend_address_pools, backend_address_pool_name)
        instance.backend_address_pool = pool
        # Keep compatibility after bumping the API version from '2020-11-01' to '2021-02-01'.
        # https://github.com/Azure/azure-rest-api-specs/issues/14430
        if cmd.supported_api_version(min_api='2021-02-01') and not backend_pools_name:
            instance.backend_address_pools = [pool]

    if backend_pools_name is not None:
        instance.backend_address_pools = [get_property(parent.backend_address_pools, pool_name)
                                          for pool_name in backend_pools_name]
        # Otherwise the service responds with: (LoadBalancingRuleBackendAdressPoolAndBackendAddressPoolsCannotBeSetAtTheSameTimeWithDifferentValue)
        # BackendAddressPool and BackendAddressPools[] cannot be set at the same time with different values.
        instance.backend_address_pool = None

    if probe_name is not None:
        instance.probe = None if probe_name == '' else get_property(parent.probes, probe_name)

    return parent
4258
4259
def add_lb_backend_address_pool_tunnel_interface(cmd, resource_group_name, load_balancer_name,
                                                 backend_address_pool_name, protocol, identifier, traffic_type, port=None):
    """Append a gateway load balancer tunnel interface to a backend pool and start the update.

    Returns the poller from ``begin_create_or_update``.
    """
    pool_ops = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    pool = pool_ops.get(resource_group_name, load_balancer_name, backend_address_pool_name)
    TunnelInterface = cmd.get_models('GatewayLoadBalancerTunnelInterface')
    new_interface = TunnelInterface(port=port, identifier=identifier, protocol=protocol, type=traffic_type)
    # The pool may not have a tunnel interface list yet; start an empty one in that case.
    interfaces = pool.tunnel_interfaces or []
    interfaces.append(new_interface)
    pool.tunnel_interfaces = interfaces
    return pool_ops.begin_create_or_update(resource_group_name, load_balancer_name,
                                           backend_address_pool_name, pool)
4271
4272
def update_lb_backend_address_pool_tunnel_interface(cmd, resource_group_name, load_balancer_name,
                                                    backend_address_pool_name, index, protocol=None, identifier=None, traffic_type=None, port=None):
    """Modify the tunnel interface at position ``index`` of a backend address pool.

    Only truthy arguments overwrite the interface's fields. Returns the poller from
    ``begin_create_or_update``.

    :raises UnrecognizedArgumentError: if ``index`` does not address an existing tunnel interface
        (negative, past the end, or the pool has no tunnel interfaces at all).
    """
    client = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    address_pool = client.get(resource_group_name, load_balancer_name, backend_address_pool_name)
    # tunnel_interfaces may be None on a pool without tunnel interfaces. Treat it as empty so the
    # caller gets the out-of-scope error instead of a TypeError from len(None).
    tunnel_interfaces = address_pool.tunnel_interfaces or []
    # Reject negative values too; Python indexing would otherwise silently select an item
    # counted from the end of the list.
    if not 0 <= index < len(tunnel_interfaces):
        raise UnrecognizedArgumentError(f'{index} is out of scope, please input proper index')

    item = tunnel_interfaces[index]
    if protocol:
        item.protocol = protocol
    if identifier:
        item.identifier = identifier
    if port:
        item.port = port
    if traffic_type:
        item.type = traffic_type
    return client.begin_create_or_update(resource_group_name, load_balancer_name,
                                         backend_address_pool_name, address_pool)
4291
4292
def remove_lb_backend_address_pool_tunnel_interface(cmd, resource_group_name, load_balancer_name,
                                                    backend_address_pool_name, index):
    """Remove the tunnel interface at position ``index`` from a backend address pool.

    Returns the poller from ``begin_create_or_update``.

    :raises UnrecognizedArgumentError: if ``index`` does not address an existing tunnel interface
        (negative, past the end, or the pool has no tunnel interfaces at all).
    """
    client = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    address_pool = client.get(resource_group_name, load_balancer_name, backend_address_pool_name)
    # tunnel_interfaces may be None; treat it as empty so len() does not raise TypeError.
    tunnel_interfaces = address_pool.tunnel_interfaces or []
    # Negative indexes are rejected so that pop() cannot silently remove an item counted from the end.
    if not 0 <= index < len(tunnel_interfaces):
        raise UnrecognizedArgumentError(f'{index} is out of scope, please input proper index')
    tunnel_interfaces.pop(index)
    return client.begin_create_or_update(resource_group_name, load_balancer_name,
                                         backend_address_pool_name, address_pool)
4302
4303
def list_lb_backend_address_pool_tunnel_interface(cmd, resource_group_name, load_balancer_name,
                                                  backend_address_pool_name):
    """Return the ``tunnel_interfaces`` of the named backend address pool (possibly None)."""
    pool_ops = network_client_factory(cmd.cli_ctx).load_balancer_backend_address_pools
    pool = pool_ops.get(resource_group_name, load_balancer_name, backend_address_pool_name)
    return pool.tunnel_interfaces
4309# endregion
4310
4311
4312# region LocalGateways
4313def _validate_bgp_peering(cmd, instance, asn, bgp_peering_address, peer_weight):
4314    if any([asn, bgp_peering_address, peer_weight]):
4315        if instance.bgp_settings is not None:
4316            # update existing parameters selectively
4317            if asn is not None:
4318                instance.bgp_settings.asn = asn
4319            if peer_weight is not None:
4320                instance.bgp_settings.peer_weight = peer_weight
4321            if bgp_peering_address is not None:
4322                instance.bgp_settings.bgp_peering_address = bgp_peering_address
4323        elif asn:
4324            BgpSettings = cmd.get_models('BgpSettings')
4325            instance.bgp_settings = BgpSettings(asn, bgp_peering_address, peer_weight)
4326        else:
4327            raise CLIError(
4328                'incorrect usage: --asn ASN [--peer-weight WEIGHT --bgp-peering-address IP]')
4329
4330
def create_local_gateway(cmd, resource_group_name, local_network_gateway_name, gateway_ip_address,
                         location=None, tags=None, local_address_prefix=None, asn=None,
                         bgp_peering_address=None, peer_weight=None, no_wait=False):
    """Create a local network gateway, optionally with BGP settings.

    BGP settings are attached only when at least one of ``asn``, ``bgp_peering_address`` or
    ``peer_weight`` is truthy. The create call is dispatched through ``sdk_no_wait``.
    """
    AddressSpace, LocalNetworkGateway, BgpSettings = cmd.get_models(
        'AddressSpace', 'LocalNetworkGateway', 'BgpSettings')
    gateway_ops = network_client_factory(cmd.cli_ctx).local_network_gateways

    address_space = AddressSpace(address_prefixes=local_address_prefix or [])
    gateway = LocalNetworkGateway(location=location, tags=tags, gateway_ip_address=gateway_ip_address,
                                  local_network_address_space=address_space)
    if any((bgp_peering_address, asn, peer_weight)):
        gateway.bgp_settings = BgpSettings(asn=asn, bgp_peering_address=bgp_peering_address,
                                           peer_weight=peer_weight)
    return sdk_no_wait(no_wait, gateway_ops.begin_create_or_update,
                       resource_group_name, local_network_gateway_name, gateway)
4345
4346
def update_local_gateway(cmd, instance, gateway_ip_address=None, local_address_prefix=None, asn=None,
                         bgp_peering_address=None, peer_weight=None, tags=None):
    """Apply updates to an existing local network gateway model and return it.

    BGP options are merged by ``_validate_bgp_peering``. Every other option that is not None
    overwrites the matching field.
    """
    _validate_bgp_peering(cmd, instance, asn, bgp_peering_address, peer_weight)

    if gateway_ip_address is not None:
        instance.gateway_ip_address = gateway_ip_address
    if local_address_prefix is not None:
        if instance.local_network_address_space is None:
            # Presumably a fetched gateway can lack an address space object. Create one instead
            # of failing with AttributeError on None.address_prefixes.
            AddressSpace = cmd.get_models('AddressSpace')
            instance.local_network_address_space = AddressSpace(address_prefixes=local_address_prefix)
        else:
            instance.local_network_address_space.address_prefixes = local_address_prefix
    if tags is not None:
        instance.tags = tags
    return instance
4358# endregion
4359
4360
4361# region NetworkInterfaces (NIC)
def create_nic(cmd, resource_group_name, network_interface_name, subnet, location=None, tags=None,
               internal_dns_name_label=None, dns_servers=None, enable_ip_forwarding=False,
               load_balancer_backend_address_pool_ids=None,
               load_balancer_inbound_nat_rule_ids=None,
               load_balancer_name=None, network_security_group=None,
               private_ip_address=None, private_ip_address_version=None,
               public_ip_address=None, virtual_network_name=None, enable_accelerated_networking=None,
               application_security_groups=None, no_wait=False,
               app_gateway_backend_address_pools=None, edge_zone=None):
    """Create a network interface with a single IP configuration named ``ipconfig1``.

    The private IP allocation is 'Static' when ``private_ip_address`` is given, otherwise 'Dynamic'.
    Properties that need newer API versions are only set when ``cmd`` supports them. The create
    call is dispatched through ``sdk_no_wait``.
    """
    client = network_client_factory(cmd.cli_ctx).network_interfaces
    (NetworkInterface, NetworkInterfaceDnsSettings, NetworkInterfaceIPConfiguration, NetworkSecurityGroup,
     PublicIPAddress, Subnet, SubResource) = cmd.get_models(
         'NetworkInterface', 'NetworkInterfaceDnsSettings', 'NetworkInterfaceIPConfiguration',
         'NetworkSecurityGroup', 'PublicIPAddress', 'Subnet', 'SubResource')

    nic = NetworkInterface(
        location=location, tags=tags, enable_ip_forwarding=enable_ip_forwarding,
        dns_settings=NetworkInterfaceDnsSettings(internal_dns_name_label=internal_dns_name_label,
                                                 dns_servers=dns_servers or []))

    supports_2016_09 = cmd.supported_api_version(min_api='2016-09-01')
    if supports_2016_09:
        nic.enable_accelerated_networking = enable_accelerated_networking
    if network_security_group:
        nic.network_security_group = NetworkSecurityGroup(id=network_security_group)

    app_gw_pool_refs = None
    if app_gateway_backend_address_pools:
        app_gw_pool_refs = [SubResource(id=pool_id) for pool_id in app_gateway_backend_address_pools]

    ip_config_kwargs = dict(
        name='ipconfig1',
        load_balancer_backend_address_pools=load_balancer_backend_address_pool_ids,
        load_balancer_inbound_nat_rules=load_balancer_inbound_nat_rule_ids,
        private_ip_allocation_method='Static' if private_ip_address else 'Dynamic',
        private_ip_address=private_ip_address,
        subnet=Subnet(id=subnet),
        application_gateway_backend_address_pools=app_gw_pool_refs,
    )
    if supports_2016_09:
        ip_config_kwargs['private_ip_address_version'] = private_ip_address_version
    if cmd.supported_api_version(min_api='2017-09-01'):
        ip_config_kwargs['application_security_groups'] = application_security_groups

    primary_config = NetworkInterfaceIPConfiguration(**ip_config_kwargs)
    if public_ip_address:
        primary_config.public_ip_address = PublicIPAddress(id=public_ip_address)
    nic.ip_configurations = [primary_config]

    if edge_zone:
        nic.extended_location = _edge_zone_model(cmd, edge_zone)
    return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, network_interface_name, nic)
4412
4413
def update_nic(cmd, instance, network_security_group=None, enable_ip_forwarding=None,
               internal_dns_name_label=None, dns_servers=None, enable_accelerated_networking=None):
    """Apply updates to a network interface model in place and return it.

    An empty string for ``network_security_group`` or ``internal_dns_name_label`` clears that
    field, and ``['']`` for ``dns_servers`` clears the DNS server list. An empty ``dns_servers``
    list leaves the servers unchanged.
    """
    if enable_ip_forwarding is not None:
        instance.enable_ip_forwarding = enable_ip_forwarding

    if network_security_group is not None:
        if network_security_group == '':
            instance.network_security_group = None
        else:
            NetworkSecurityGroup = cmd.get_models('NetworkSecurityGroup')
            instance.network_security_group = NetworkSecurityGroup(id=network_security_group)

    if internal_dns_name_label is not None:
        instance.dns_settings.internal_dns_name_label = internal_dns_name_label or None

    if dns_servers:
        instance.dns_settings.dns_servers = None if dns_servers == [''] else dns_servers

    if enable_accelerated_networking is not None:
        instance.enable_accelerated_networking = enable_accelerated_networking

    return instance
4438
4439
def create_nic_ip_config(cmd, resource_group_name, network_interface_name, ip_config_name, subnet=None,
                         virtual_network_name=None, public_ip_address=None, load_balancer_name=None,
                         load_balancer_backend_address_pool_ids=None,
                         load_balancer_inbound_nat_rule_ids=None,
                         private_ip_address=None,
                         private_ip_address_version=None,
                         make_primary=False,
                         application_security_groups=None,
                         app_gateway_backend_address_pools=None):
    """Add (or replace by name) an IP configuration on a NIC and return it as stored by the service.

    On API versions that know about IP versions, an IPv4 configuration without an explicit
    ``subnet`` reuses the subnet of the NIC's current primary configuration. When ``make_primary``
    is set, every existing configuration is demoted first.
    """
    NetworkInterfaceIPConfiguration, PublicIPAddress, Subnet, SubResource = cmd.get_models(
        'NetworkInterfaceIPConfiguration', 'PublicIPAddress', 'Subnet', 'SubResource')
    nic_ops = network_client_factory(cmd.cli_ctx).network_interfaces
    nic = nic_ops.get(resource_group_name, network_interface_name)

    supports_ip_version = cmd.supported_api_version(min_api='2016-09-01')
    if supports_ip_version:
        ipv4 = cmd.get_models('IPVersion').I_PV4.value
        private_ip_address_version = private_ip_address_version or ipv4
        if not subnet and private_ip_address_version == ipv4:
            current_primary = next(cfg for cfg in nic.ip_configurations if cfg.primary)
            subnet = current_primary.subnet.id
        if make_primary:
            for cfg in nic.ip_configurations:
                cfg.primary = False

    config_kwargs = dict(
        name=ip_config_name,
        subnet=Subnet(id=subnet) if subnet else None,
        public_ip_address=PublicIPAddress(id=public_ip_address) if public_ip_address else None,
        load_balancer_backend_address_pools=load_balancer_backend_address_pool_ids,
        load_balancer_inbound_nat_rules=load_balancer_inbound_nat_rule_ids,
        private_ip_address=private_ip_address,
        private_ip_allocation_method='Static' if private_ip_address else 'Dynamic',
    )
    if supports_ip_version:
        config_kwargs.update(private_ip_address_version=private_ip_address_version, primary=make_primary)
    if cmd.supported_api_version(min_api='2017-09-01'):
        config_kwargs['application_security_groups'] = application_security_groups
    if cmd.supported_api_version(min_api='2018-08-01'):
        app_gw_refs = None
        if app_gateway_backend_address_pools:
            app_gw_refs = [SubResource(id=pool_id) for pool_id in app_gateway_backend_address_pools]
        config_kwargs['application_gateway_backend_address_pools'] = app_gw_refs

    upsert_to_collection(nic, 'ip_configurations', NetworkInterfaceIPConfiguration(**config_kwargs), 'name')
    saved_nic = nic_ops.begin_create_or_update(resource_group_name, network_interface_name, nic).result()
    return get_property(saved_nic.ip_configurations, ip_config_name)
4489
4490
def update_nic_ip_config_setter(cmd, resource_group_name, network_interface_name, parameters, gateway_lb):
    """Persist an updated NIC, authorizing against the gateway load balancer's subscription too.

    When ``gateway_lb`` is a valid resource ID, its subscription is passed to the client factory
    as an auxiliary subscription.
    """
    if is_valid_resource_id(gateway_lb):
        aux_subscriptions = [parse_resource_id(gateway_lb)['subscription']]
    else:
        aux_subscriptions = []
    nic_ops = network_client_factory(cmd.cli_ctx, aux_subscriptions=aux_subscriptions).network_interfaces
    return nic_ops.begin_create_or_update(resource_group_name, network_interface_name, parameters)
4497
4498
def set_nic_ip_config(cmd, instance, parent, ip_config_name, subnet=None,
                      virtual_network_name=None, public_ip_address=None, load_balancer_name=None,
                      load_balancer_backend_address_pool_ids=None,
                      load_balancer_inbound_nat_rule_ids=None,
                      private_ip_address=None,
                      private_ip_address_version=None, make_primary=False,
                      application_security_groups=None,
                      app_gateway_backend_address_pools=None, gateway_lb=None):
    """Apply updates to the NIC IP configuration ``instance`` and return the parent NIC.

    An empty string for a scalar option (or ``['']`` for a list option) clears that field.
    An empty ``private_ip_address`` switches allocation back to dynamic; a concrete address makes
    it static.
    """
    PublicIPAddress, Subnet, SubResource = cmd.get_models('PublicIPAddress', 'Subnet', 'SubResource')

    if make_primary:
        # Demote every configuration on the parent NIC, then promote this one.
        for sibling in parent.ip_configurations:
            sibling.primary = False
        instance.primary = True

    if private_ip_address is not None:
        if private_ip_address == '':
            # Switch private IP address allocation to dynamic when an empty string is used.
            instance.private_ip_address = None
            instance.private_ip_allocation_method = 'dynamic'
            if cmd.supported_api_version(min_api='2016-09-01'):
                instance.private_ip_address_version = 'ipv4'
        else:
            # A specific address implies static allocation.
            instance.private_ip_address = private_ip_address
            instance.private_ip_allocation_method = 'static'
            if private_ip_address_version is not None:
                instance.private_ip_address_version = private_ip_address_version

    if subnet is not None:
        instance.subnet = None if subnet == '' else Subnet(id=subnet)

    if public_ip_address is not None:
        instance.public_ip_address = None if public_ip_address == '' else PublicIPAddress(id=public_ip_address)

    if load_balancer_backend_address_pool_ids is not None:
        instance.load_balancer_backend_address_pools = (
            None if load_balancer_backend_address_pool_ids == '' else load_balancer_backend_address_pool_ids)

    if load_balancer_inbound_nat_rule_ids is not None:
        instance.load_balancer_inbound_nat_rules = (
            None if load_balancer_inbound_nat_rule_ids == '' else load_balancer_inbound_nat_rule_ids)

    if application_security_groups:
        instance.application_security_groups = (
            None if application_security_groups == [''] else application_security_groups)

    if app_gateway_backend_address_pools:
        if app_gateway_backend_address_pools == ['']:
            instance.application_gateway_backend_address_pools = None
        else:
            instance.application_gateway_backend_address_pools = [
                SubResource(id=pool_id) for pool_id in app_gateway_backend_address_pools]

    if gateway_lb is not None:
        instance.gateway_load_balancer = SubResource(id=gateway_lb) if gateway_lb != '' else None
    return parent
4560
4561
4562def _get_nic_ip_config(nic, name):
4563    if nic.ip_configurations:
4564        ip_config = next(
4565            (x for x in nic.ip_configurations if x.name.lower() == name.lower()), None)
4566    else:
4567        ip_config = None
4568    if not ip_config:
4569        raise CLIError('IP configuration {} not found.'.format(name))
4570    return ip_config
4571
4572
def add_nic_ip_config_address_pool(
        cmd, resource_group_name, network_interface_name, ip_config_name, backend_address_pool,
        load_balancer_name=None, application_gateway_name=None):
    """Attach a backend address pool (by ID) to a NIC IP configuration and return it.

    The pool is added to the load balancer pool list when ``load_balancer_name`` is set,
    otherwise to the application gateway pool list when ``application_gateway_name`` is set.
    """
    BackendAddressPool = cmd.get_models('BackendAddressPool')
    nic_ops = network_client_factory(cmd.cli_ctx).network_interfaces
    nic = nic_ops.get(resource_group_name, network_interface_name)
    ip_config = _get_nic_ip_config(nic, ip_config_name)

    target_collection = None
    if load_balancer_name:
        target_collection = 'load_balancer_backend_address_pools'
    elif application_gateway_name:
        target_collection = 'application_gateway_backend_address_pools'
    if target_collection:
        upsert_to_collection(ip_config, target_collection, BackendAddressPool(id=backend_address_pool), 'id')

    saved_nic = nic_ops.begin_create_or_update(resource_group_name, network_interface_name, nic).result()
    return get_property(saved_nic.ip_configurations, ip_config_name)
4590
4591
def remove_nic_ip_config_address_pool(
        cmd, resource_group_name, network_interface_name, ip_config_name, backend_address_pool,
        load_balancer_name=None, application_gateway_name=None):
    """Detach a backend address pool (by ID) from a NIC IP configuration and return it.

    The pool is removed from the load balancer pool list when ``load_balancer_name`` is set,
    otherwise from the application gateway pool list when ``application_gateway_name`` is set.
    """
    client = network_client_factory(cmd.cli_ctx).network_interfaces
    nic = client.get(resource_group_name, network_interface_name)
    ip_config = _get_nic_ip_config(nic, ip_config_name)

    # ARM resource IDs are case-insensitive, and the service may return a reference with
    # different casing than the ID the caller built. An exact match would then silently
    # keep the pool attached.
    target_id = (backend_address_pool or '').lower()

    def _keep(ref):
        return (ref.id or '').lower() != target_id

    if load_balancer_name:
        ip_config.load_balancer_backend_address_pools = [
            x for x in ip_config.load_balancer_backend_address_pools or [] if _keep(x)]
    elif application_gateway_name:
        ip_config.application_gateway_backend_address_pools = [
            x for x in ip_config.application_gateway_backend_address_pools or [] if _keep(x)]
    poller = client.begin_create_or_update(resource_group_name, network_interface_name, nic)
    return get_property(poller.result().ip_configurations, ip_config_name)
4607
4608
def add_nic_ip_config_inbound_nat_rule(
        cmd, resource_group_name, network_interface_name, ip_config_name, inbound_nat_rule,
        load_balancer_name=None):
    """Attach an inbound NAT rule (by ID) to a NIC IP configuration and return it."""
    InboundNatRule = cmd.get_models('InboundNatRule')
    nic_ops = network_client_factory(cmd.cli_ctx).network_interfaces
    nic = nic_ops.get(resource_group_name, network_interface_name)
    ip_config = _get_nic_ip_config(nic, ip_config_name)
    rule_ref = InboundNatRule(id=inbound_nat_rule)
    upsert_to_collection(ip_config, 'load_balancer_inbound_nat_rules', rule_ref, 'id')
    saved_nic = nic_ops.begin_create_or_update(resource_group_name, network_interface_name, nic).result()
    return get_property(saved_nic.ip_configurations, ip_config_name)
4621
4622
def remove_nic_ip_config_inbound_nat_rule(
        cmd, resource_group_name, network_interface_name, ip_config_name, inbound_nat_rule,
        load_balancer_name=None):
    """Detach an inbound NAT rule (by ID) from a NIC IP configuration and return it."""
    client = network_client_factory(cmd.cli_ctx).network_interfaces
    nic = client.get(resource_group_name, network_interface_name)
    ip_config = _get_nic_ip_config(nic, ip_config_name)
    # ARM resource IDs are case-insensitive; compare them that way so a casing difference
    # between the caller's ID and the service's reference does not silently keep the rule.
    target_id = (inbound_nat_rule or '').lower()
    ip_config.load_balancer_inbound_nat_rules = [
        x for x in ip_config.load_balancer_inbound_nat_rules or []
        if (x.id or '').lower() != target_id]
    poller = client.begin_create_or_update(resource_group_name, network_interface_name, nic)
    return get_property(poller.result().ip_configurations, ip_config_name)
4634# endregion
4635
4636
4637# region NetworkSecurityGroups
def create_nsg(cmd, resource_group_name, network_security_group_name, location=None, tags=None):
    """Start creating an empty network security group and return the poller."""
    NetworkSecurityGroup = cmd.get_models('NetworkSecurityGroup')
    nsg_ops = network_client_factory(cmd.cli_ctx).network_security_groups
    return nsg_ops.begin_create_or_update(resource_group_name, network_security_group_name,
                                          NetworkSecurityGroup(location=location, tags=tags))
4643
4644
4645def _create_singular_or_plural_property(kwargs, val, singular_name, plural_name):
4646
4647    if not val:
4648        return
4649    if not isinstance(val, list):
4650        val = [val]
4651    if len(val) > 1:
4652        kwargs[plural_name] = val
4653        kwargs[singular_name] = None
4654    else:
4655        kwargs[singular_name] = val[0]
4656        kwargs[plural_name] = None
4657
4658
4659def _handle_asg_property(kwargs, key, asgs):
4660    prefix = key.split('_', 1)[0] + '_'
4661    if asgs:
4662        kwargs[key] = asgs
4663        if kwargs[prefix + 'address_prefix'].is_default:
4664            kwargs[prefix + 'address_prefix'] = ''
4665
4666
def create_nsg_rule_2017_06_01(cmd, resource_group_name, network_security_group_name, security_rule_name,
                               priority, description=None, protocol=None, access=None, direction=None,
                               source_port_ranges='*', source_address_prefixes='*',
                               destination_port_ranges=80, destination_address_prefixes='*',
                               source_asgs=None, destination_asgs=None):
    """Start creating an NSG security rule (API 2017-06-01 and later) and return the poller.

    Address prefixes and port ranges are split into singular/plural model fields by
    ``_create_singular_or_plural_property``. Application security groups are applied only on
    API 2017-09-01 and later.
    """
    kwargs = dict(name=security_rule_name, priority=priority, protocol=protocol,
                  access=access, direction=direction, description=description)
    for value, singular, plural in (
            (source_address_prefixes, 'source_address_prefix', 'source_address_prefixes'),
            (destination_address_prefixes, 'destination_address_prefix', 'destination_address_prefixes'),
            (source_port_ranges, 'source_port_range', 'source_port_ranges'),
            (destination_port_ranges, 'destination_port_range', 'destination_port_ranges')):
        _create_singular_or_plural_property(kwargs, value, singular, plural)

    # workaround for issue https://github.com/Azure/azure-rest-api-specs/issues/1591
    for prefix_key in ('source_address_prefix', 'destination_address_prefix'):
        kwargs[prefix_key] = kwargs[prefix_key] or ''

    if cmd.supported_api_version(min_api='2017-09-01'):
        _handle_asg_property(kwargs, 'source_application_security_groups', source_asgs)
        _handle_asg_property(kwargs, 'destination_application_security_groups', destination_asgs)

    SecurityRule = cmd.get_models('SecurityRule')
    rule = SecurityRule(**kwargs)
    rule_ops = network_client_factory(cmd.cli_ctx).security_rules
    return rule_ops.begin_create_or_update(
        resource_group_name, network_security_group_name, security_rule_name, rule)
4702
4703
def create_nsg_rule_2017_03_01(cmd, resource_group_name, network_security_group_name, security_rule_name,
                               priority, description=None, protocol=None, access=None, direction=None,
                               source_port_range='*', source_address_prefix='*',
                               destination_port_range=80, destination_address_prefix='*'):
    """Start creating an NSG security rule with single-valued prefixes/ranges; return the poller."""
    SecurityRule = cmd.get_models('SecurityRule')
    rule = SecurityRule(
        name=security_rule_name,
        priority=priority,
        protocol=protocol,
        access=access,
        direction=direction,
        description=description,
        source_address_prefix=source_address_prefix,
        source_port_range=source_port_range,
        destination_address_prefix=destination_address_prefix,
        destination_port_range=destination_port_range)

    rule_ops = network_client_factory(cmd.cli_ctx).security_rules
    return rule_ops.begin_create_or_update(
        resource_group_name, network_security_group_name, security_rule_name, rule)
4719
4720
4721def _update_singular_or_plural_property(instance, val, singular_name, plural_name):
4722
4723    if val is None:
4724        return
4725    if not isinstance(val, list):
4726        val = [val]
4727    if len(val) > 1:
4728        setattr(instance, plural_name, val)
4729        setattr(instance, singular_name, None)
4730    else:
4731        setattr(instance, plural_name, None)
4732        setattr(instance, singular_name, val[0])
4733
4734
def update_nsg_rule_2017_06_01(instance, protocol=None, source_address_prefixes=None,
                               destination_address_prefixes=None, access=None, direction=None, description=None,
                               source_port_ranges=None, destination_port_ranges=None, priority=None,
                               source_asgs=None, destination_asgs=None):
    """Apply updates to an NSG security rule model in place and return it.

    Options that are None keep the current value. ``['']`` for ``source_asgs`` or
    ``destination_asgs`` clears that application security group list.
    """
    # No client validation as server side returns pretty good errors
    for attr, value in (('protocol', protocol), ('access', access), ('direction', direction),
                        ('description', description), ('priority', priority)):
        if value is not None:
            setattr(instance, attr, value)

    for value, singular, plural in (
            (source_address_prefixes, 'source_address_prefix', 'source_address_prefixes'),
            (destination_address_prefixes, 'destination_address_prefix', 'destination_address_prefixes'),
            (source_port_ranges, 'source_port_range', 'source_port_ranges'),
            (destination_port_ranges, 'destination_port_range', 'destination_port_ranges')):
        _update_singular_or_plural_property(instance, value, singular, plural)

    # workaround for issue https://github.com/Azure/azure-rest-api-specs/issues/1591
    instance.source_address_prefix = instance.source_address_prefix or ''
    instance.destination_address_prefix = instance.destination_address_prefix or ''

    if source_asgs:
        instance.source_application_security_groups = None if source_asgs == [''] else source_asgs
    if destination_asgs:
        instance.destination_application_security_groups = (
            None if destination_asgs == [''] else destination_asgs)

    return instance
4770
4771
def update_nsg_rule_2017_03_01(instance, protocol=None, source_address_prefix=None,
                               destination_address_prefix=None, access=None, direction=None, description=None,
                               source_port_range=None, destination_port_range=None, priority=None):
    """Patch an existing security rule model in place and return it.

    Each keyword argument that is not ``None`` overwrites the attribute of the same name;
    ``None`` keeps the rule's current value.
    """
    # No client validation as server side returns pretty good errors
    overrides = (
        ('protocol', protocol),
        ('source_address_prefix', source_address_prefix),
        ('destination_address_prefix', destination_address_prefix),
        ('access', access),
        ('direction', direction),
        ('description', description),
        ('source_port_range', source_port_range),
        ('destination_port_range', destination_port_range),
        ('priority', priority),
    )
    for attr_name, override in overrides:
        if override is None:
            override = getattr(instance, attr_name)
        setattr(instance, attr_name, override)
    return instance
4790# endregion
4791
4792
4793# region NetworkProfiles
def list_network_profiles(cmd, resource_group_name=None):
    """List network profiles in ``resource_group_name``, or in the whole subscription when it is not given."""
    profiles = network_client_factory(cmd.cli_ctx).network_profiles
    return profiles.list(resource_group_name) if resource_group_name else profiles.list_all()
4799# endregion
4800
4801
4802# region NetworkWatchers
def _create_network_watchers(cmd, client, resource_group_name, locations, tags):
    """Create a ``<location>-watcher`` Network Watcher in ``resource_group_name`` for each location.

    :raises CLIError: when ``resource_group_name`` is ``None``.
    """
    if resource_group_name is None:
        raise CLIError("usage error: '--resource-group' required when enabling new regions")

    NetworkWatcher = cmd.get_models('NetworkWatcher')
    for region in locations:
        watcher_name = '{}-watcher'.format(region)
        client.create_or_update(resource_group_name, watcher_name,
                                NetworkWatcher(location=region, tags=tags))
4812
4813
def _update_network_watchers(cmd, client, watchers, tags):
    """Re-put every watcher in ``watchers``, replacing its tags when ``tags`` is given.

    The resource group and name come from each watcher's resource ID and its location is
    kept. When ``tags`` is ``None`` the watcher's current tags are sent back unchanged.
    """
    NetworkWatcher = cmd.get_models('NetworkWatcher')
    for existing in watchers:
        parts = parse_resource_id(existing.id)
        new_tags = existing.tags if tags is None else tags
        client.create_or_update(parts['resource_group'], parts['name'],
                                NetworkWatcher(location=existing.location, tags=new_tags))
4824
4825
def _delete_network_watchers(cmd, client, watchers):
    """Delete every watcher in ``watchers``, one after another.

    A warning naming the region and resource ID is logged before each deletion. Each
    ``client.begin_delete`` poller is passed to ``LongRunningOperation`` before the next
    watcher is processed.
    """
    # Loop-invariant: import once instead of re-executing the import statement per watcher.
    from azure.cli.core.commands import LongRunningOperation
    for watcher in watchers:
        id_parts = parse_resource_id(watcher.id)
        watcher_rg = id_parts['resource_group']
        watcher_name = id_parts['name']
        logger.warning(
            "Disabling Network Watcher for region '%s' by deleting resource '%s'",
            watcher.location, watcher.id)
        LongRunningOperation(cmd.cli_ctx)(client.begin_delete(watcher_rg, watcher_name))
4836
4837
def configure_network_watcher(cmd, client, locations, resource_group_name=None, enabled=None, tags=None):
    """Enable, disable or re-tag Network Watcher in the given regions.

    Behaviour depends on ``enabled``:

    * ``None``  -- existing watchers in ``locations`` are updated; regions without a watcher
      are reported with a warning and skipped, and ``resource_group_name`` is ignored.
    * truthy    -- a watcher is created (in ``resource_group_name``) for every requested
      region that lacks one, and existing watchers are updated.
    * falsy     -- existing watchers in ``locations`` are deleted.

    Region names are compared case-insensitively, so ``WestUS`` matches an existing
    ``westus`` watcher instead of also being treated as a region that still needs one.

    :returns: the result of ``client.list_all()`` after the changes are applied.
    :raises CLIError: when ``tags`` is passed while disabling, or when enabling new regions
        without ``resource_group_name``.
    """
    watcher_list = list(client.list_all())
    # Normalise and de-duplicate (keeping the caller's order) once, so that matching and the
    # "not yet enabled" computation both happen in the same lower-case space. Previously the
    # latter used the raw input, which re-created watchers for differently-cased regions.
    requested_regions = list(dict.fromkeys(location.lower() for location in locations))
    existing_watchers = [w for w in watcher_list if (w.location or '').lower() in requested_regions]
    enabled_regions = {(w.location or '').lower() for w in existing_watchers}
    nonenabled_regions = [region for region in requested_regions if region not in enabled_regions]

    if enabled is None:
        if resource_group_name is not None:
            logger.warning(
                "Resource group '%s' is only used when enabling new regions and will be ignored.",
                resource_group_name)
        for location in nonenabled_regions:
            logger.warning(
                "Region '%s' is not enabled for Network Watcher and will be ignored.", location)
        _update_network_watchers(cmd, client, existing_watchers, tags)

    elif enabled:
        _create_network_watchers(cmd, client, resource_group_name, nonenabled_regions, tags)
        _update_network_watchers(cmd, client, existing_watchers, tags)

    else:
        if tags is not None:
            raise CLIError("usage error: '--tags' cannot be used when disabling regions")
        _delete_network_watchers(cmd, client, existing_watchers)

    return client.list_all()
4864
4865
def create_nw_connection_monitor(cmd,
                                 client,
                                 connection_monitor_name,
                                 watcher_rg,
                                 watcher_name,
                                 resource_group_name=None,
                                 location=None,
                                 source_resource=None,
                                 source_port=None,
                                 dest_resource=None,
                                 dest_port=None,
                                 dest_address=None,
                                 tags=None,
                                 do_not_start=None,
                                 monitoring_interval=None,
                                 endpoint_source_name=None,
                                 endpoint_source_resource_id=None,
                                 endpoint_source_address=None,
                                 endpoint_source_type=None,
                                 endpoint_source_coverage_level=None,
                                 endpoint_dest_name=None,
                                 endpoint_dest_resource_id=None,
                                 endpoint_dest_address=None,
                                 endpoint_dest_type=None,
                                 endpoint_dest_coverage_level=None,
                                 test_config_name=None,
                                 test_config_frequency=None,
                                 test_config_protocol=None,
                                 test_config_preferred_ip_version=None,
                                 test_config_threshold_failed_percent=None,
                                 test_config_threshold_round_trip_time=None,
                                 test_config_tcp_disable_trace_route=None,
                                 test_config_tcp_port=None,
                                 test_config_tcp_port_behavior=None,
                                 test_config_icmp_disable_trace_route=None,
                                 test_config_http_port=None,
                                 test_config_http_method=None,
                                 test_config_http_path=None,
                                 test_config_http_valid_status_codes=None,
                                 test_config_http_prefer_https=None,
                                 test_group_name=None,
                                 test_group_disable=None,
                                 output_type=None,
                                 workspace_ids=None,
                                 notes=None):
    """Create or replace a connection monitor from either the V1 or the V2 argument set.

    * V1 is selected when any argument in ``v1_trigger_args`` is set. The model is built
      by ``_create_nw_connection_monitor_v1``, and ``client`` is replaced by a
      ``connection_monitors`` client pinned to the ad-hoc API version registered under
      ``'nw_connection_monitor'``.
    * Otherwise V2 is selected when any argument in ``v2_trigger_args`` is set. The model
      is built by ``_create_nw_connection_monitor_v2`` and the supplied ``client`` is used.
    * If neither set is used, a CLIError is raised.

    Returns whatever ``client.begin_create_or_update`` returns.
    """
    v1_trigger_args = (source_resource, source_port, dest_resource, dest_address, dest_port)
    v2_trigger_args = (
        endpoint_source_name, endpoint_source_resource_id, endpoint_source_type, endpoint_source_coverage_level,
        endpoint_dest_name, endpoint_dest_address, endpoint_dest_type, endpoint_dest_coverage_level,
        test_config_name, test_config_protocol,
        output_type, workspace_ids,
    )

    if any(v1_trigger_args):  # V1 creation
        connection_monitor = _create_nw_connection_monitor_v1(
            cmd,
            connection_monitor_name=connection_monitor_name,
            watcher_rg=watcher_rg,
            watcher_name=watcher_name,
            source_resource=source_resource,
            resource_group_name=resource_group_name,
            source_port=source_port,
            location=location,
            dest_resource=dest_resource,
            dest_port=dest_port,
            dest_address=dest_address,
            tags=tags,
            do_not_start=do_not_start,
            monitoring_interval=monitoring_interval)
        from azure.cli.core.profiles._shared import AD_HOC_API_VERSIONS
        pinned_api_version = AD_HOC_API_VERSIONS[ResourceType.MGMT_NETWORK]['nw_connection_monitor']
        client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_NETWORK,
                                         api_version=pinned_api_version).connection_monitors
    elif any(v2_trigger_args):  # V2 creation
        connection_monitor = _create_nw_connection_monitor_v2(
            cmd,
            location=location,
            tags=tags,
            endpoint_source_name=endpoint_source_name,
            endpoint_source_resource_id=endpoint_source_resource_id,
            endpoint_source_address=endpoint_source_address,
            endpoint_source_type=endpoint_source_type,
            endpoint_source_coverage_level=endpoint_source_coverage_level,
            endpoint_dest_name=endpoint_dest_name,
            endpoint_dest_resource_id=endpoint_dest_resource_id,
            endpoint_dest_address=endpoint_dest_address,
            endpoint_dest_type=endpoint_dest_type,
            endpoint_dest_coverage_level=endpoint_dest_coverage_level,
            test_config_name=test_config_name,
            test_config_frequency=test_config_frequency,
            test_config_protocol=test_config_protocol,
            test_config_preferred_ip_version=test_config_preferred_ip_version,
            test_config_threshold_failed_percent=test_config_threshold_failed_percent,
            test_config_threshold_round_trip_time=test_config_threshold_round_trip_time,
            test_config_tcp_port=test_config_tcp_port,
            test_config_tcp_port_behavior=test_config_tcp_port_behavior,
            test_config_tcp_disable_trace_route=test_config_tcp_disable_trace_route,
            test_config_icmp_disable_trace_route=test_config_icmp_disable_trace_route,
            test_config_http_port=test_config_http_port,
            test_config_http_method=test_config_http_method,
            test_config_http_path=test_config_http_path,
            test_config_http_valid_status_codes=test_config_http_valid_status_codes,
            test_config_http_prefer_https=test_config_http_prefer_https,
            test_group_name=test_group_name,
            test_group_disable=test_group_disable,
            output_type=output_type,
            workspace_ids=workspace_ids,
            notes=notes)
    else:
        raise CLIError('Unknown operation')

    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, connection_monitor)
4982
4983
def _create_nw_connection_monitor_v1(cmd,
                                     connection_monitor_name,
                                     watcher_rg,
                                     watcher_name,
                                     source_resource,
                                     resource_group_name=None,
                                     source_port=None,
                                     location=None,
                                     dest_resource=None,
                                     dest_port=None,
                                     dest_address=None,
                                     tags=None,
                                     do_not_start=None,
                                     monitoring_interval=60):
    """Build (but do not send) a V1-style ConnectionMonitor model.

    The model has a single source and a single destination. ``auto_start`` is the
    negation of ``do_not_start``, and the V2-only fields (endpoints, test configurations,
    test groups, outputs, notes) are set to ``None``. ``connection_monitor_name``,
    ``watcher_rg``, ``watcher_name`` and ``resource_group_name`` are accepted but not used.
    """
    ConnectionMonitor, ConnectionMonitorSource, ConnectionMonitorDestination = cmd.get_models(
        'ConnectionMonitor', 'ConnectionMonitorSource', 'ConnectionMonitorDestination')

    source = ConnectionMonitorSource(resource_id=source_resource, port=source_port)
    destination = ConnectionMonitorDestination(resource_id=dest_resource, port=dest_port, address=dest_address)

    return ConnectionMonitor(location=location,
                             tags=tags,
                             source=source,
                             destination=destination,
                             auto_start=not do_not_start,
                             monitoring_interval_in_seconds=monitoring_interval,
                             endpoints=None,
                             test_configurations=None,
                             test_groups=None,
                             outputs=None,
                             notes=None)
5023
5024
def _create_nw_connection_monitor_v2(cmd,
                                     location=None,
                                     tags=None,
                                     endpoint_source_name=None,
                                     endpoint_source_resource_id=None,
                                     endpoint_source_address=None,
                                     endpoint_source_type=None,
                                     endpoint_source_coverage_level=None,
                                     endpoint_dest_name=None,
                                     endpoint_dest_resource_id=None,
                                     endpoint_dest_address=None,
                                     endpoint_dest_type=None,
                                     endpoint_dest_coverage_level=None,
                                     test_config_name=None,
                                     test_config_frequency=None,
                                     test_config_protocol=None,
                                     test_config_preferred_ip_version=None,
                                     test_config_threshold_failed_percent=None,
                                     test_config_threshold_round_trip_time=None,
                                     test_config_tcp_port=None,
                                     test_config_tcp_port_behavior=None,
                                     test_config_tcp_disable_trace_route=False,
                                     test_config_icmp_disable_trace_route=False,
                                     test_config_http_port=None,
                                     test_config_http_method=None,
                                     test_config_http_path=None,
                                     test_config_http_valid_status_codes=None,
                                     test_config_http_prefer_https=None,
                                     test_group_name=None,
                                     test_group_disable=False,
                                     output_type=None,
                                     workspace_ids=None,
                                     notes=None):
    """Assemble (but do not send) a V2 ConnectionMonitor model.

    The monitor has one source endpoint, one destination endpoint, one test configuration
    and one test group that ties the three together. One output is created per entry in
    ``workspace_ids``, but only when ``output_type`` is also given. ``auto_start`` and
    ``monitoring_interval_in_seconds`` are set to ``None``.
    """
    source_endpoint = _create_nw_connection_monitor_v2_endpoint(
        cmd,
        endpoint_source_name,
        endpoint_resource_id=endpoint_source_resource_id,
        address=endpoint_source_address,
        endpoint_type=endpoint_source_type,
        coverage_level=endpoint_source_coverage_level)
    destination_endpoint = _create_nw_connection_monitor_v2_endpoint(
        cmd,
        endpoint_dest_name,
        endpoint_resource_id=endpoint_dest_resource_id,
        address=endpoint_dest_address,
        endpoint_type=endpoint_dest_type,
        coverage_level=endpoint_dest_coverage_level)
    test_configuration = _create_nw_connection_monitor_v2_test_configuration(
        cmd,
        name=test_config_name,
        test_frequency=test_config_frequency,
        protocol=test_config_protocol,
        threshold_failed_percent=test_config_threshold_failed_percent,
        threshold_round_trip_time=test_config_threshold_round_trip_time,
        preferred_ip_version=test_config_preferred_ip_version,
        tcp_port=test_config_tcp_port,
        tcp_port_behavior=test_config_tcp_port_behavior,
        tcp_disable_trace_route=test_config_tcp_disable_trace_route,
        icmp_disable_trace_route=test_config_icmp_disable_trace_route,
        http_port=test_config_http_port,
        http_method=test_config_http_method,
        http_path=test_config_http_path,
        http_valid_status_codes=test_config_http_valid_status_codes,
        http_prefer_https=test_config_http_prefer_https)
    test_group = _create_nw_connection_monitor_v2_test_group(
        cmd,
        name=test_group_name,
        disable=test_group_disable,
        test_configurations=[test_configuration],
        source_endpoints=[source_endpoint],
        destination_endpoints=[destination_endpoint])

    outputs = []
    if output_type and workspace_ids:
        outputs = [_create_nw_connection_monitor_v2_output(cmd, output_type, ws_id) for ws_id in workspace_ids]

    ConnectionMonitor = cmd.get_models('ConnectionMonitor')
    return ConnectionMonitor(location=location,
                             tags=tags,
                             auto_start=None,
                             monitoring_interval_in_seconds=None,
                             endpoints=[source_endpoint, destination_endpoint],
                             test_configurations=[test_configuration],
                             test_groups=[test_group],
                             outputs=outputs,
                             notes=notes)
5112
5113
def _create_nw_connection_monitor_v2_endpoint(cmd,
                                              name,
                                              endpoint_resource_id=None,
                                              address=None,
                                              filter_type=None,
                                              filter_items=None,
                                              endpoint_type=None,
                                              coverage_level=None):
    """Build a ConnectionMonitorEndpoint model, attaching a filter when one is requested.

    :raises CLIError: when exactly one of ``filter_type`` / ``filter_items`` is supplied.
    """
    # The two filter arguments must be both set or both unset (truthiness XOR).
    if bool(filter_type) != bool(filter_items):
        raise CLIError('usage error: '
                       '--filter-type and --filter-item for endpoint filter must be present at the same time.')

    ConnectionMonitorEndpoint, ConnectionMonitorEndpointFilter = cmd.get_models(
        'ConnectionMonitorEndpoint', 'ConnectionMonitorEndpointFilter')

    endpoint = ConnectionMonitorEndpoint(name=name, resource_id=endpoint_resource_id, address=address,
                                         type=endpoint_type, coverage_level=coverage_level)
    if filter_type and filter_items:
        endpoint.filter = ConnectionMonitorEndpointFilter(type=filter_type, items=filter_items)
    return endpoint
5140
5141
def _create_nw_connection_monitor_v2_test_configuration(cmd,
                                                        name,
                                                        test_frequency,
                                                        protocol,
                                                        threshold_failed_percent,
                                                        threshold_round_trip_time,
                                                        preferred_ip_version,
                                                        tcp_port=None,
                                                        tcp_port_behavior=None,
                                                        tcp_disable_trace_route=None,
                                                        icmp_disable_trace_route=None,
                                                        http_port=None,
                                                        http_method=None,
                                                        http_path=None,
                                                        http_valid_status_codes=None,
                                                        http_prefer_https=None,
                                                        http_request_headers=None):
    """Build a ConnectionMonitorTestConfiguration for ``protocol``.

    A success threshold is attached whenever either threshold value is supplied, including
    an explicit ``0``. Exactly one protocol-specific sub-configuration is attached (TCP,
    ICMP or HTTP, chosen by ``protocol``); arguments for the other protocols are ignored.

    :raises CLIError: when ``protocol`` is not TCP, ICMP or HTTP.
    """
    (ConnectionMonitorTestConfigurationProtocol,
     ConnectionMonitorTestConfiguration, ConnectionMonitorSuccessThreshold) = cmd.get_models(
         'ConnectionMonitorTestConfigurationProtocol',
         'ConnectionMonitorTestConfiguration', 'ConnectionMonitorSuccessThreshold')

    test_config = ConnectionMonitorTestConfiguration(name=name,
                                                     test_frequency_sec=test_frequency,
                                                     protocol=protocol,
                                                     preferred_ip_version=preferred_ip_version)

    # Compare against None rather than truthiness: a truthiness check silently dropped an
    # explicit 0 (e.g. "0 percent of checks may fail").
    if threshold_failed_percent is not None or threshold_round_trip_time is not None:
        test_config.success_threshold = ConnectionMonitorSuccessThreshold(
            checks_failed_percent=threshold_failed_percent,
            round_trip_time_ms=threshold_round_trip_time)

    if protocol == ConnectionMonitorTestConfigurationProtocol.tcp:
        ConnectionMonitorTcpConfiguration = cmd.get_models('ConnectionMonitorTcpConfiguration')
        test_config.tcp_configuration = ConnectionMonitorTcpConfiguration(
            port=tcp_port,
            destination_port_behavior=tcp_port_behavior,
            disable_trace_route=tcp_disable_trace_route)
    elif protocol == ConnectionMonitorTestConfigurationProtocol.icmp:
        ConnectionMonitorIcmpConfiguration = cmd.get_models('ConnectionMonitorIcmpConfiguration')
        test_config.icmp_configuration = ConnectionMonitorIcmpConfiguration(
            disable_trace_route=icmp_disable_trace_route)
    elif protocol == ConnectionMonitorTestConfigurationProtocol.http:
        ConnectionMonitorHttpConfiguration = cmd.get_models('ConnectionMonitorHttpConfiguration')
        test_config.http_configuration = ConnectionMonitorHttpConfiguration(
            port=http_port,
            method=http_method,
            path=http_path,
            request_headers=http_request_headers,
            valid_status_code_ranges=http_valid_status_codes,
            prefer_https=http_prefer_https)
    else:
        raise CLIError('Unsupported protocol: "{}" for test configuration'.format(protocol))

    return test_config
5200
5201
def _create_nw_connection_monitor_v2_test_group(cmd,
                                                name,
                                                disable,
                                                test_configurations,
                                                source_endpoints,
                                                destination_endpoints):
    """Build a ConnectionMonitorTestGroup that references its members by name.

    The test configurations and endpoints are passed as model objects; only their
    ``name`` attributes are stored on the group.
    """
    ConnectionMonitorTestGroup = cmd.get_models('ConnectionMonitorTestGroup')

    def _names(items):
        return [item.name for item in items]

    return ConnectionMonitorTestGroup(name=name,
                                      disable=disable,
                                      test_configurations=_names(test_configurations),
                                      sources=_names(source_endpoints),
                                      destinations=_names(destination_endpoints))
5216
5217
def _create_nw_connection_monitor_v2_output(cmd,
                                            output_type,
                                            workspace_id=None):
    """Build a ConnectionMonitorOutput pointing at a Log Analytics workspace.

    Only the workspace output type is handled. Its settings carry ``workspace_id`` as
    the workspace resource ID.

    :raises CLIError: for any other ``output_type``.
    """
    ConnectionMonitorOutput, OutputType = cmd.get_models('ConnectionMonitorOutput', 'OutputType')
    output = ConnectionMonitorOutput(type=output_type)

    if output_type != OutputType.workspace:
        raise CLIError('Unsupported output type: "{}"'.format(output_type))

    ConnectionMonitorWorkspaceSettings = cmd.get_models('ConnectionMonitorWorkspaceSettings')
    output.workspace_settings = ConnectionMonitorWorkspaceSettings(workspace_resource_id=workspace_id)
    return output
5232
5233
def add_nw_connection_monitor_v2_endpoint(cmd,
                                          client,
                                          watcher_rg,
                                          watcher_name,
                                          connection_monitor_name,
                                          location,
                                          name,
                                          coverage_level=None,
                                          endpoint_type=None,
                                          source_test_groups=None,
                                          dest_test_groups=None,
                                          endpoint_resource_id=None,
                                          address=None,
                                          filter_type=None,
                                          filter_items=None,
                                          address_include=None,
                                          address_exclude=None):
    """Add an endpoint to an existing connection monitor and wire it into test groups.

    The endpoint is appended to the monitor's endpoints. Its name is appended to the
    ``sources`` of every test group named in ``source_test_groups``, and to the
    ``destinations`` of every test group named in ``dest_test_groups``. Names that match no
    test group are ignored. The updated monitor is then sent with ``begin_create_or_update``.

    ``address_include`` / ``address_exclude`` become the endpoint scope; the scope is
    ``None`` when neither is given. ``location`` is accepted but not used here.

    :raises CLIError: when exactly one of ``filter_type`` / ``filter_items`` is supplied.
    :returns: the result of ``client.begin_create_or_update``.
    """
    # A lone --filter-type or --filter-item would otherwise be silently dropped; reject it
    # with the same usage error that _create_nw_connection_monitor_v2_endpoint raises.
    if bool(filter_type) != bool(filter_items):
        raise CLIError('usage error: '
                       '--filter-type and --filter-item for endpoint filter must be present at the same time.')

    (ConnectionMonitorEndpoint, ConnectionMonitorEndpointFilter,
     ConnectionMonitorEndpointScope, ConnectionMonitorEndpointScopeItem) = cmd.get_models(
         'ConnectionMonitorEndpoint', 'ConnectionMonitorEndpointFilter',
         'ConnectionMonitorEndpointScope', 'ConnectionMonitorEndpointScopeItem')

    endpoint_scope = ConnectionMonitorEndpointScope(include=[], exclude=[])
    for ip in address_include or []:
        endpoint_scope.include.append(ConnectionMonitorEndpointScopeItem(address=ip))
    for ip in address_exclude or []:
        endpoint_scope.exclude.append(ConnectionMonitorEndpointScopeItem(address=ip))

    endpoint = ConnectionMonitorEndpoint(name=name,
                                         resource_id=endpoint_resource_id,
                                         address=address,
                                         type=endpoint_type,
                                         coverage_level=coverage_level,
                                         scope=endpoint_scope if address_include or address_exclude else None)

    if filter_type and filter_items:
        endpoint.filter = ConnectionMonitorEndpointFilter(type=filter_type, items=filter_items)

    connection_monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)
    # V1-shaped monitors carry endpoints/test_groups of None (see _create_nw_connection_monitor_v1);
    # guard so appending and iterating do not crash.
    if connection_monitor.endpoints is None:
        connection_monitor.endpoints = []
    connection_monitor.endpoints.append(endpoint)

    src_test_groups, dst_test_groups = set(source_test_groups or []), set(dest_test_groups or [])
    for test_group in connection_monitor.test_groups or []:
        if test_group.name in src_test_groups:
            test_group.sources.append(endpoint.name)
        if test_group.name in dst_test_groups:
            test_group.destinations.append(endpoint.name)

    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, connection_monitor)
5286
5287
def remove_nw_connection_monitor_v2_endpoint(client,
                                             watcher_rg,
                                             watcher_name,
                                             connection_monitor_name,
                                             location,
                                             name,
                                             test_groups=None):
    """Remove an endpoint from a connection monitor and unlink it from test groups.

    The endpoint called ``name`` is dropped from the monitor's endpoint list. Its name is then
    removed from the ``sources`` and ``destinations`` of every test group (when ``test_groups``
    is None) or only of the test groups whose names are listed in ``test_groups``.

    :return: the poller returned by ``client.begin_create_or_update`` for the updated monitor.
    """
    monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)

    monitor.endpoints = [ep for ep in monitor.endpoints if ep.name != name]

    if test_groups is None:
        affected_groups = monitor.test_groups
    else:
        affected_groups = [grp for grp in monitor.test_groups if grp.name in test_groups]

    for grp in affected_groups:
        # Sources first, then destinations; only the first occurrence in each list is removed.
        for members in (grp.sources, grp.destinations):
            if name in members:
                members.remove(name)

    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, monitor)
5314
5315
def show_nw_connection_monitor_v2_endpoint(client,
                                           watcher_rg,
                                           watcher_name,
                                           connection_monitor_name,
                                           location,
                                           name):
    """Return the first endpoint of a connection monitor whose name equals ``name``.

    :raises CLIError: when the monitor has no endpoint with that name.
    """
    monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)

    found = next((ep for ep in monitor.endpoints if ep.name == name), None)
    if found is None:
        raise CLIError('unknown endpoint: {}'.format(name))
    return found
5329
5330
def list_nw_connection_monitor_v2_endpoint(client,
                                           watcher_rg,
                                           watcher_name,
                                           connection_monitor_name,
                                           location):
    """Return the ``endpoints`` collection of the given connection monitor as-is."""
    return client.get(watcher_rg, watcher_name, connection_monitor_name).endpoints
5338
5339
def add_nw_connection_monitor_v2_test_configuration(cmd,
                                                    client,
                                                    watcher_rg,
                                                    watcher_name,
                                                    connection_monitor_name,
                                                    location,
                                                    name,
                                                    protocol,
                                                    test_groups,
                                                    frequency=None,
                                                    threshold_failed_percent=None,
                                                    threshold_round_trip_time=None,
                                                    preferred_ip_version=None,
                                                    tcp_port=None,
                                                    tcp_port_behavior=None,
                                                    tcp_disable_trace_route=None,
                                                    icmp_disable_trace_route=None,
                                                    http_port=None,
                                                    http_method=None,
                                                    http_path=None,
                                                    http_valid_status_codes=None,
                                                    http_prefer_https=None,
                                                    http_request_headers=None):
    """Create a test configuration and attach it to a connection monitor.

    The configuration is built by ``_create_nw_connection_monitor_v2_test_configuration`` and
    appended to the monitor's ``test_configurations``. Its name is also appended to every test
    group whose name appears in ``test_groups``.

    :return: the poller returned by ``client.begin_create_or_update`` for the updated monitor.
    """
    # The helper is called positionally, so this tuple's order is significant.
    helper_args = (name, frequency, protocol,
                   threshold_failed_percent, threshold_round_trip_time, preferred_ip_version,
                   tcp_port, tcp_port_behavior, tcp_disable_trace_route,
                   icmp_disable_trace_route,
                   http_port, http_method, http_path, http_valid_status_codes,
                   http_prefer_https, http_request_headers)
    test_config = _create_nw_connection_monitor_v2_test_configuration(cmd, *helper_args)

    monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)
    monitor.test_configurations.append(test_config)

    for group in monitor.test_groups:
        if group.name not in test_groups:
            continue
        group.test_configurations.append(test_config.name)

    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, monitor)
5389
5390
def remove_nw_connection_monitor_v2_test_configuration(client,
                                                       watcher_rg,
                                                       watcher_name,
                                                       connection_monitor_name,
                                                       location,
                                                       name,
                                                       test_groups=None):
    """Remove a test configuration from a connection monitor and unlink it from test groups.

    The configuration called ``name`` is dropped from the monitor's ``test_configurations``. Its
    name is then removed from the ``test_configurations`` of every test group (when
    ``test_groups`` is None) or only of the test groups whose names appear in ``test_groups``.
    Test groups that do not reference the configuration are left untouched.

    :return: the poller returned by ``client.begin_create_or_update`` for the updated monitor.
    """
    connection_monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)

    # refresh test configurations
    connection_monitor.test_configurations = [t for t in connection_monitor.test_configurations
                                              if t.name != name]

    if test_groups is not None:
        temp_test_groups = [t for t in connection_monitor.test_groups if t.name in test_groups]
    else:
        temp_test_groups = connection_monitor.test_groups

    # refresh test groups
    for test_group in temp_test_groups:
        # Not every selected group references this configuration; an unguarded list.remove()
        # would raise ValueError for those (mirrors the guard used when removing endpoints).
        if name in test_group.test_configurations:
            test_group.test_configurations.remove(name)

    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, connection_monitor)
5414
5415
def show_nw_connection_monitor_v2_test_configuration(client,
                                                     watcher_rg,
                                                     watcher_name,
                                                     connection_monitor_name,
                                                     location,
                                                     name):
    """Return the first test configuration of a connection monitor named ``name``.

    :raises CLIError: when no test configuration on the monitor has that name.
    """
    configs = client.get(watcher_rg, watcher_name, connection_monitor_name).test_configurations

    for config in configs:
        if config.name == name:
            break
    else:
        raise CLIError('unknown test configuration: {}'.format(name))
    return config
5429
5430
def list_nw_connection_monitor_v2_test_configuration(client,
                                                     watcher_rg,
                                                     watcher_name,
                                                     connection_monitor_name,
                                                     location):
    """Return the ``test_configurations`` collection of the given connection monitor as-is."""
    monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)
    configs = monitor.test_configurations
    return configs
5438
5439
def add_nw_connection_monitor_v2_test_group(cmd,
                                            client,
                                            connection_monitor_name,
                                            watcher_rg,
                                            watcher_name,
                                            location,
                                            name,
                                            endpoint_source_name,
                                            endpoint_dest_name,
                                            test_config_name,
                                            disable=False,
                                            endpoint_source_resource_id=None,
                                            endpoint_source_address=None,
                                            endpoint_dest_resource_id=None,
                                            endpoint_dest_address=None,
                                            test_config_frequency=None,
                                            test_config_protocol=None,
                                            test_config_preferred_ip_version=None,
                                            test_config_threshold_failed_percent=None,
                                            test_config_threshold_round_trip_time=None,
                                            test_config_tcp_disable_trace_route=None,
                                            test_config_tcp_port=None,
                                            test_config_icmp_disable_trace_route=None,
                                            test_config_http_port=None,
                                            test_config_http_method=None,
                                            test_config_http_path=None,
                                            test_config_http_valid_status_codes=None,
                                            test_config_http_prefer_https=None):
    """Add a test group to a connection monitor, optionally creating its endpoints and test config.

    A source (or destination) endpoint is created and appended to the monitor only when an
    address or resource ID is supplied for it. Otherwise the name is just referenced and
    presumably denotes an endpoint already on the monitor. Likewise, a test configuration named
    ``test_config_name`` is created only when at least one test-config setting other than
    ``test_config_frequency`` is given. The new group always references
    ``endpoint_source_name``, ``endpoint_dest_name`` and ``test_config_name``.

    :return: the poller returned by ``client.begin_create_or_update`` for the updated monitor.
    """
    new_test_configuration_creation_requirements = [
        test_config_protocol, test_config_preferred_ip_version,
        test_config_threshold_failed_percent, test_config_threshold_round_trip_time,
        test_config_tcp_disable_trace_route, test_config_tcp_port,
        test_config_icmp_disable_trace_route,
        test_config_http_port, test_config_http_method,
        test_config_http_path, test_config_http_valid_status_codes, test_config_http_prefer_https
    ]

    connection_monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)

    new_test_group = _create_nw_connection_monitor_v2_test_group(cmd,
                                                                 name,
                                                                 disable,
                                                                 [], [], [])

    # deal with endpoint
    if any([endpoint_source_address, endpoint_source_resource_id]):
        src_endpoint = _create_nw_connection_monitor_v2_endpoint(cmd,
                                                                 endpoint_source_name,
                                                                 endpoint_source_resource_id,
                                                                 endpoint_source_address)
        connection_monitor.endpoints.append(src_endpoint)
    if any([endpoint_dest_address, endpoint_dest_resource_id]):
        dst_endpoint = _create_nw_connection_monitor_v2_endpoint(cmd,
                                                                 endpoint_dest_name,
                                                                 endpoint_dest_resource_id,
                                                                 endpoint_dest_address)
        connection_monitor.endpoints.append(dst_endpoint)

    new_test_group.sources.append(endpoint_source_name)
    new_test_group.destinations.append(endpoint_dest_name)

    # deal with test configuration
    if any(new_test_configuration_creation_requirements):
        # The helper is called positionally. Its parameter order (see the call in
        # add_nw_connection_monitor_v2_test_configuration) has tcp_port_behavior immediately after
        # tcp_port. This command exposes no TCP port behavior option, so pass None in that slot.
        # Omitting it shifted every later value (trace-route flags, HTTP settings) into the
        # wrong parameter.
        test_config = _create_nw_connection_monitor_v2_test_configuration(cmd,
                                                                          test_config_name,
                                                                          test_config_frequency,
                                                                          test_config_protocol,
                                                                          test_config_threshold_failed_percent,
                                                                          test_config_threshold_round_trip_time,
                                                                          test_config_preferred_ip_version,
                                                                          test_config_tcp_port,
                                                                          None,  # tcp_port_behavior
                                                                          test_config_tcp_disable_trace_route,
                                                                          test_config_icmp_disable_trace_route,
                                                                          test_config_http_port,
                                                                          test_config_http_method,
                                                                          test_config_http_path,
                                                                          test_config_http_valid_status_codes,
                                                                          test_config_http_prefer_https)
        connection_monitor.test_configurations.append(test_config)
    new_test_group.test_configurations.append(test_config_name)

    connection_monitor.test_groups.append(new_test_group)

    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, connection_monitor)
5524
5525
def remove_nw_connection_monitor_v2_test_group(client,
                                               watcher_rg,
                                               watcher_name,
                                               connection_monitor_name,
                                               location,
                                               name):
    """Remove a test group and garbage-collect what only it referenced.

    Every test group named ``name`` is dropped. The last such group found is then used to decide
    which endpoints and test configurations to delete: those no longer referenced by any
    remaining test group.

    :raises CLIError: when the monitor has no test group with that name.
    :return: the poller returned by ``client.begin_create_or_update`` for the updated monitor.
    """
    monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)

    kept_groups = []
    dropped_group = None
    for group in monitor.test_groups:
        if group.name != name:
            kept_groups.append(group)
            continue
        dropped_group = group

    if dropped_group is None:
        raise CLIError('test group: "{}" not exist'.format(name))
    monitor.test_groups = kept_groups

    # Endpoints referenced by the dropped group and by no surviving group become orphans.
    orphan_endpoints = {
        ep_name for ep_name in dropped_group.sources + dropped_group.destinations
        if not any(ep_name in g.sources or ep_name in g.destinations for g in kept_groups)
    }
    monitor.endpoints = [ep for ep in monitor.endpoints if ep.name not in orphan_endpoints]

    # Same rule for test configurations.
    orphan_configs = {
        cfg_name for cfg_name in dropped_group.test_configurations
        if not any(cfg_name in g.test_configurations for g in kept_groups)
    }
    monitor.test_configurations = [cfg for cfg in monitor.test_configurations
                                   if cfg.name not in orphan_configs]

    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, monitor)
5563
5564
def show_nw_connection_monitor_v2_test_group(client,
                                             watcher_rg,
                                             watcher_name,
                                             connection_monitor_name,
                                             location,
                                             name):
    """Return the first test group of a connection monitor named ``name``.

    :raises CLIError: when no test group on the monitor has that name.
    """
    groups = client.get(watcher_rg, watcher_name, connection_monitor_name).test_groups

    matching = (group for group in groups if group.name == name)
    for group in matching:
        return group

    raise CLIError('unknown test group: {}'.format(name))
5578
5579
def list_nw_connection_monitor_v2_test_group(client,
                                             watcher_rg,
                                             watcher_name,
                                             connection_monitor_name,
                                             location):
    """Return the ``test_groups`` collection of the given connection monitor as-is."""
    return client.get(watcher_rg, watcher_name, connection_monitor_name).test_groups
5587
5588
def add_nw_connection_monitor_v2_output(cmd,
                                        client,
                                        watcher_rg,
                                        watcher_name,
                                        connection_monitor_name,
                                        location,
                                        out_type,
                                        workspace_id=None):
    """Append a new output destination to a connection monitor.

    The output is built by ``_create_nw_connection_monitor_v2_output``. If the monitor has no
    ``outputs`` list yet (None), a new one is created holding just this output.

    :return: the poller returned by ``client.begin_create_or_update`` for the updated monitor.
    """
    new_output = _create_nw_connection_monitor_v2_output(cmd, out_type, workspace_id)

    monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)
    if monitor.outputs is None:
        monitor.outputs = [new_output]
    else:
        monitor.outputs.append(new_output)

    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, monitor)
5607
5608
def remove_nw_connection_monitor_v2_output(client,
                                           watcher_rg,
                                           watcher_name,
                                           connection_monitor_name,
                                           location):
    """Clear every output destination of a connection monitor.

    :return: the poller returned by ``client.begin_create_or_update`` for the updated monitor.
    """
    monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)
    # Replace rather than mutate, so the monitor ends up with a fresh empty list even if it had None.
    monitor.outputs = []
    return client.begin_create_or_update(watcher_rg, watcher_name, connection_monitor_name, monitor)
5618
5619
def list_nw_connection_monitor_v2_output(client,
                                         watcher_rg,
                                         watcher_name,
                                         connection_monitor_name,
                                         location):
    """Return the ``outputs`` of the given connection monitor as-is (may be None)."""
    monitor = client.get(watcher_rg, watcher_name, connection_monitor_name)
    return monitor.outputs
5627
5628
def show_topology_watcher(cmd, client, resource_group_name, network_watcher_name, target_resource_group_name=None,
                          target_vnet=None, target_subnet=None):  # pylint: disable=unused-argument
    """Fetch the network topology seen by a network watcher.

    The scope is described by a ``TopologyParameters`` model: a target resource group,
    virtual network and/or subnet. Any of them may be None.
    """
    TopologyParameters = cmd.get_models('TopologyParameters')
    topology_scope = TopologyParameters(target_resource_group_name=target_resource_group_name,
                                        target_virtual_network=target_vnet,
                                        target_subnet=target_subnet)
    return client.get_topology(resource_group_name=resource_group_name,
                               network_watcher_name=network_watcher_name,
                               parameters=topology_scope)
5640
5641
def check_nw_connectivity(cmd, client, watcher_rg, watcher_name, source_resource, source_port=None,
                          dest_resource=None, dest_port=None, dest_address=None,
                          resource_group_name=None, protocol=None, method=None, headers=None, valid_status_codes=None):
    """Start a Network Watcher connectivity check from a source resource to a destination.

    An HTTP protocol configuration is attached only when at least one of ``method``,
    ``headers`` or ``valid_status_codes`` is truthy.

    :return: the poller returned by ``client.begin_check_connectivity``.
    """
    (ConnectivitySource, ConnectivityDestination, ConnectivityParameters,
     ProtocolConfiguration, HTTPConfiguration) = cmd.get_models(
         'ConnectivitySource', 'ConnectivityDestination', 'ConnectivityParameters', 'ProtocolConfiguration',
         'HTTPConfiguration')

    source = ConnectivitySource(resource_id=source_resource, port=source_port)
    destination = ConnectivityDestination(resource_id=dest_resource, address=dest_address, port=dest_port)
    params = ConnectivityParameters(source=source, destination=destination, protocol=protocol)

    if method or headers or valid_status_codes:
        http_config = HTTPConfiguration(method=method, headers=headers, valid_status_codes=valid_status_codes)
        params.protocol_configuration = ProtocolConfiguration(http_configuration=http_config)

    return client.begin_check_connectivity(watcher_rg, watcher_name, params)
5661
5662
def check_nw_ip_flow(cmd, client, vm, watcher_rg, watcher_name, direction, protocol, local, remote,
                     resource_group_name=None, nic=None, location=None):
    """Verify whether a packet flow to/from a VM is allowed, using Network Watcher IP flow verify.

    :param vm: VM name (requires ``resource_group_name``) or full resource ID.
    :param local: local endpoint as ``<ip>:<port>``.
    :param remote: remote endpoint as ``<ip>:<port>``.
    :param nic: optional NIC name (requires ``resource_group_name``) or full resource ID.
    :raises CLIError: on a malformed ``local``/``remote`` value, or when a name is given without
        ``resource_group_name``.
    :return: the poller returned by ``client.begin_verify_ip_flow``.
    """
    VerificationIPFlowParameters = cmd.get_models('VerificationIPFlowParameters')

    try:
        local_ip_address, local_port = local.split(':')
        remote_ip_address, remote_port = remote.split(':')
    except (AttributeError, TypeError, ValueError) as ex:
        # ValueError: not exactly one ':' separator; AttributeError/TypeError: value is not a str.
        # Catch only parse failures so KeyboardInterrupt and genuine bugs are not masked.
        raise CLIError("usage error: the format of the '--local' and '--remote' should be like x.x.x.x:port") \
            from ex

    if not is_valid_resource_id(vm):
        if not resource_group_name:
            raise CLIError("usage error: --vm NAME --resource-group NAME | --vm ID")

        vm = resource_id(
            subscription=get_subscription_id(cmd.cli_ctx), resource_group=resource_group_name,
            namespace='Microsoft.Compute', type='virtualMachines', name=vm)

    if nic and not is_valid_resource_id(nic):
        if not resource_group_name:
            raise CLIError("usage error: --nic NAME --resource-group NAME | --nic ID")

        nic = resource_id(
            subscription=get_subscription_id(cmd.cli_ctx), resource_group=resource_group_name,
            namespace='Microsoft.Network', type='networkInterfaces', name=nic)

    return client.begin_verify_ip_flow(
        watcher_rg, watcher_name,
        VerificationIPFlowParameters(
            target_resource_id=vm, direction=direction, protocol=protocol, local_port=local_port,
            remote_port=remote_port, local_ip_address=local_ip_address,
            remote_ip_address=remote_ip_address, target_nic_resource_id=nic))
5695
5696
def show_nw_next_hop(cmd, client, resource_group_name, vm, watcher_rg, watcher_name,
                     source_ip, dest_ip, nic=None, location=None):
    """Get the next routing hop for traffic from a VM (optionally a specific NIC) to ``dest_ip``.

    ``vm`` and ``nic`` may be names or full resource IDs. Names are expanded into resource IDs
    under ``resource_group_name`` in the current subscription.

    :return: the poller returned by ``client.begin_get_next_hop``.
    """
    NextHopParameters = cmd.get_models('NextHopParameters')

    def _as_resource_id(value, namespace, resource_type):
        # Already an ARM ID: keep it; otherwise treat it as a name in resource_group_name.
        if is_valid_resource_id(value):
            return value
        return resource_id(
            subscription=get_subscription_id(cmd.cli_ctx), resource_group=resource_group_name,
            namespace=namespace, type=resource_type, name=value)

    vm = _as_resource_id(vm, 'Microsoft.Compute', 'virtualMachines')
    if nic:
        nic = _as_resource_id(nic, 'Microsoft.Network', 'networkInterfaces')

    next_hop_params = NextHopParameters(target_resource_id=vm,
                                        source_ip_address=source_ip,
                                        destination_ip_address=dest_ip,
                                        target_nic_resource_id=nic)
    return client.begin_get_next_hop(watcher_rg, watcher_name, next_hop_params)
5716
5717
def show_nw_security_view(cmd, client, resource_group_name, vm, watcher_rg, watcher_name, location=None):
    """Get the effective security rules applied to a VM via Network Watcher.

    ``vm`` may be a name (expanded under ``resource_group_name`` in the current subscription) or
    a full resource ID.

    :return: the poller returned by ``client.begin_get_vm_security_rules``.
    """
    if is_valid_resource_id(vm):
        vm_id = vm
    else:
        vm_id = resource_id(
            subscription=get_subscription_id(cmd.cli_ctx), resource_group=resource_group_name,
            namespace='Microsoft.Compute', type='virtualMachines', name=vm)

    SecurityGroupViewParameters = cmd.get_models('SecurityGroupViewParameters')
    view_params = SecurityGroupViewParameters(target_resource_id=vm_id)
    return client.begin_get_vm_security_rules(watcher_rg, watcher_name, view_params)
5726
5727
def create_nw_packet_capture(cmd, client, resource_group_name, capture_name, vm,
                             watcher_rg, watcher_name, location=None,
                             storage_account=None, storage_path=None, file_path=None,
                             capture_size=None, capture_limit=None, time_limit=None, filters=None):
    """Start a Network Watcher packet capture session named ``capture_name`` on target ``vm``.

    The capture is written to the given storage account/path and/or local ``file_path``.
    ``capture_size`` caps the bytes captured per packet, ``capture_limit`` caps the bytes per
    session, and ``time_limit`` is in seconds.

    :return: the poller returned by ``client.begin_create``.
    """
    PacketCapture, PacketCaptureStorageLocation = cmd.get_models('PacketCapture', 'PacketCaptureStorageLocation')

    capture_params = PacketCapture(
        target=vm,
        storage_location=PacketCaptureStorageLocation(storage_id=storage_account,
                                                      storage_path=storage_path,
                                                      file_path=file_path),
        bytes_to_capture_per_packet=capture_size,
        total_bytes_per_session=capture_limit,
        time_limit_in_seconds=time_limit,
        filters=filters,
    )
    return client.begin_create(watcher_rg, watcher_name, capture_name, capture_params)
5741
5742
def set_nsg_flow_logging(cmd, client, watcher_rg, watcher_name, nsg, storage_account=None,
                         resource_group_name=None, enabled=None, retention=0, log_format=None, log_version=None,
                         traffic_analytics_workspace=None, traffic_analytics_interval=None,
                         traffic_analytics_enabled=None):
    """Update the flow-log configuration of a network security group (legacy NSG-based API).

    Reads the NSG's current flow-log status, overlays the supplied settings and submits the
    result with ``client.begin_set_flow_log_configuration``.

    :param nsg: target NSG, passed as ``target_resource_id`` of ``FlowLogStatusParameters``.
    :param enabled: new enabled state; the current state is kept when None.
    :param storage_account: storage ID; the current one is kept when falsy.
    :param retention: retention in days. The policy is rewritten whenever this is not None,
        and it counts as enabled only when ``int(retention) > 0``.
    :param log_format: format type; applied (with log_version) on API >= 2018-10-01 only.
    :param log_version: format version; see ``log_format``.
    :param traffic_analytics_workspace: workspace resource ID. An empty string removes the
        existing traffic analytics configuration.
    :param traffic_analytics_interval: interval stored on the traffic analytics config.
    :param traffic_analytics_enabled: enabled state for traffic analytics. Defaults to True
        when the analytics configuration is created from scratch.
    :raises CLIError: when traffic analytics is not yet configured and no workspace is given.
    :return: the poller returned by ``client.begin_set_flow_log_configuration``.
    """
    from azure.cli.core.commands import LongRunningOperation
    flowlog_status_parameters = cmd.get_models('FlowLogStatusParameters')(target_resource_id=nsg)
    # Block until the current flow-log status is available; it is used as the base config.
    config = LongRunningOperation(cmd.cli_ctx)(client.begin_get_flow_log_status(watcher_rg,
                                                                                watcher_name,
                                                                                flowlog_status_parameters))

    # Treat an analytics config without a workspace_id as "not configured". Any missing link in
    # the attribute chain (AttributeError) is treated the same way.
    try:
        if not config.flow_analytics_configuration.network_watcher_flow_analytics_configuration.workspace_id:
            config.flow_analytics_configuration = None
    except AttributeError:
        config.flow_analytics_configuration = None

    with cmd.update_context(config) as c:
        c.set_param('enabled', enabled if enabled is not None else config.enabled)
        c.set_param('storage_id', storage_account or config.storage_id)
    # NOTE(review): the signature default is retention=0, so unless the caller passes None
    # explicitly, this rewrites any existing retention policy to 0 days / disabled.
    # Confirm this is intended.
    if retention is not None:
        config.retention_policy = {
            'days': retention,
            'enabled': int(retention) > 0
        }
    # Both keys are always written, so an omitted one is sent as None.
    if cmd.supported_api_version(min_api='2018-10-01') and (log_format or log_version):
        config.format = {
            'type': log_format,
            'version': log_version
        }

    # Traffic analytics is touched only when a workspace or an enabled flag was supplied.
    if cmd.supported_api_version(min_api='2018-10-01') and \
            any([traffic_analytics_workspace is not None, traffic_analytics_enabled is not None]):
        workspace = None

        # Resolve the workspace resource; its customerId and location populate the analytics config.
        if traffic_analytics_workspace:
            from azure.cli.core.commands.arm import get_arm_resource_by_id
            workspace = get_arm_resource_by_id(cmd.cli_ctx, traffic_analytics_workspace)

        if not config.flow_analytics_configuration:
            # must create whole object
            if not workspace:
                raise CLIError('usage error (analytics not already configured): --workspace NAME_OR_ID '
                               '[--enabled {true|false}]')
            if traffic_analytics_enabled is None:
                traffic_analytics_enabled = True
            config.flow_analytics_configuration = {
                'network_watcher_flow_analytics_configuration': {
                    'enabled': traffic_analytics_enabled,
                    'workspace_id': workspace.properties['customerId'],
                    'workspace_region': workspace.location,
                    'workspace_resource_id': traffic_analytics_workspace,
                    'traffic_analytics_interval': traffic_analytics_interval
                }
            }
        else:
            # pylint: disable=line-too-long
            with cmd.update_context(config.flow_analytics_configuration.network_watcher_flow_analytics_configuration) as c:
                # update object
                # presumably set_param ignores None so an omitted flag keeps its value — TODO confirm in update_context
                c.set_param('enabled', traffic_analytics_enabled)
                # An explicit empty workspace string removes traffic analytics entirely.
                if traffic_analytics_workspace == "":
                    config.flow_analytics_configuration = None
                elif workspace:
                    c.set_param('workspace_id', workspace.properties['customerId'])
                    c.set_param('workspace_region', workspace.location)
                    c.set_param('workspace_resource_id', traffic_analytics_workspace)
                    c.set_param('traffic_analytics_interval', traffic_analytics_interval)

    return client.begin_set_flow_log_configuration(watcher_rg, watcher_name, config)
5811
5812
def show_nsg_flow_logging(cmd, client, watcher_rg, watcher_name, location=None, resource_group_name=None, nsg=None,
                          flow_log_name=None):
    """Show a flow log configuration, using either the legacy or the current lookup.

    * Legacy (deprecated): when ``nsg`` is given (paired with ``resource_group_name`` on the
      command line), the flow-log status of that network security group is queried through
      ``client.begin_get_flow_log_status``.
    * Current: otherwise the flow log resource ``flow_log_name`` (paired with ``location``) is
      fetched through a dedicated flow-logs client.
    """
    if nsg is None:
        from ._client_factory import cf_flow_logs
        flow_logs_client = cf_flow_logs(cmd.cli_ctx, None)
        return flow_logs_client.get(watcher_rg, watcher_name, flow_log_name)

    # Deprecated NSG-based lookup, kept for backward compatibility.
    FlowLogStatusParameters = cmd.get_models('FlowLogStatusParameters')
    status_params = FlowLogStatusParameters(target_resource_id=nsg)
    return client.begin_get_flow_log_status(watcher_rg, watcher_name, status_params)
5826
5827
def create_nw_flow_log(cmd,
                       client,
                       location,
                       watcher_rg,
                       watcher_name,
                       flow_log_name,
                       nsg,
                       storage_account=None,
                       resource_group_name=None,
                       enabled=None,
                       retention=0,
                       log_format=None,
                       log_version=None,
                       traffic_analytics_workspace=None,
                       traffic_analytics_interval=60,
                       traffic_analytics_enabled=None,
                       tags=None):
    """Create (or replace) a flow log on a Network Watcher.

    Builds a ``FlowLog`` targeting ``nsg``. Three optional sections are attached:
    a retention policy (only when ``retention`` > 0), a log format (when
    ``log_format`` or ``log_version`` is given), and a traffic-analytics
    configuration (when ``traffic_analytics_workspace`` is given). It then
    starts the create-or-update operation.

    Raises:
        CLIError: if ``traffic_analytics_workspace`` cannot be resolved to a resource.
    """
    flow_log = cmd.get_models('FlowLog')(location=location,
                                         target_resource_id=nsg,
                                         storage_id=storage_account,
                                         enabled=enabled,
                                         tags=tags)

    if retention > 0:
        retention_model = cmd.get_models('RetentionPolicyParameters')
        flow_log.retention_policy = retention_model(days=retention, enabled=(retention > 0))

    wants_format = log_format is not None or log_version is not None
    if wants_format:
        format_model = cmd.get_models('FlowLogFormatParameters')
        flow_log.format = format_model(type=log_format, version=log_version)

    if traffic_analytics_workspace is not None:
        properties_model, config_model = cmd.get_models('TrafficAnalyticsProperties',
                                                        'TrafficAnalyticsConfigurationProperties')

        from azure.cli.core.commands.arm import get_arm_resource_by_id
        workspace = get_arm_resource_by_id(cmd.cli_ctx, traffic_analytics_workspace)
        if not workspace:
            raise CLIError('Name or ID of workspace is invalid')

        analytics_config = config_model(
            enabled=traffic_analytics_enabled,
            workspace_id=workspace.properties['customerId'],
            workspace_region=workspace.location,
            workspace_resource_id=workspace.id,
            traffic_analytics_interval=traffic_analytics_interval
        )
        flow_log.flow_analytics_configuration = properties_model(
            network_watcher_flow_analytics_configuration=analytics_config
        )

    return client.begin_create_or_update(watcher_rg, watcher_name, flow_log_name, flow_log)
5885
5886
def update_nw_flow_log_getter(client, watcher_rg, watcher_name, flow_log_name):
    """Fetch flow log ``flow_log_name`` of the given watcher so it can be updated."""
    lookup = (watcher_rg, watcher_name, flow_log_name)
    return client.get(*lookup)
5889
5890
def update_nw_flow_log_setter(client, watcher_rg, watcher_name, flow_log_name, parameters):
    """Write the modified flow log ``parameters`` back via create-or-update."""
    target = (watcher_rg, watcher_name, flow_log_name)
    return client.begin_create_or_update(*target, parameters)
5893
5894
def update_nw_flow_log(cmd,
                       instance,
                       location,
                       resource_group_name=None,    # dummy parameter to let it appear in command
                       enabled=None,
                       nsg=None,
                       storage_account=None,
                       retention=0,
                       log_format=None,
                       log_version=None,
                       traffic_analytics_workspace=None,
                       traffic_analytics_interval=60,
                       traffic_analytics_enabled=None,
                       tags=None):
    """Apply command-line changes to an existing flow log and return it.

    Top-level fields, the retention policy, the log format and the
    traffic-analytics configuration are updated through ``cmd.update_context``.
    Nested containers that are missing on ``instance`` (None) are created from
    the SDK models before being written, instead of crashing on attribute
    access. ``location`` is accepted for the command signature but not read.

    Raises:
        CLIError: if ``traffic_analytics_workspace`` cannot be resolved to a resource.
    """
    with cmd.update_context(instance) as c:
        c.set_param('enabled', enabled)
        c.set_param('tags', tags)
        c.set_param('storage_id', storage_account)
        c.set_param('target_resource_id', nsg)

    # NOTE(review): retention defaults to 0 (not None), so this section is always
    # written. If set_param only skips None values, an update that omits
    # --retention resets days to 0. Confirm whether the command layer overrides
    # this default.
    if instance.retention_policy is None:
        instance.retention_policy = cmd.get_models('RetentionPolicyParameters')()
    with cmd.update_context(instance.retention_policy) as c:
        c.set_param('days', retention)
        c.set_param('enabled', retention > 0)

    # Only materialize a format object when the caller actually asked to change it.
    if instance.format is None and (log_format is not None or log_version is not None):
        instance.format = cmd.get_models('FlowLogFormatParameters')()
    if instance.format is not None:
        with cmd.update_context(instance.format) as c:
            c.set_param('type', log_format)
            c.set_param('version', log_version)

    if traffic_analytics_workspace is not None:
        from azure.cli.core.commands.arm import get_arm_resource_by_id
        workspace = get_arm_resource_by_id(cmd.cli_ctx, traffic_analytics_workspace)
        if not workspace:
            raise CLIError('Name or ID of workspace is invalid')

        # A flow log created without traffic analytics may have no outer container.
        analytics = instance.flow_analytics_configuration
        if analytics is None:
            TrafficAnalyticsProperties = cmd.get_models('TrafficAnalyticsProperties')
            analytics = TrafficAnalyticsProperties()
            instance.flow_analytics_configuration = analytics

        if analytics.network_watcher_flow_analytics_configuration is None:
            analytics_conf = cmd.get_models('TrafficAnalyticsConfigurationProperties')
            analytics.network_watcher_flow_analytics_configuration = analytics_conf()

        with cmd.update_context(analytics.network_watcher_flow_analytics_configuration) as c:
            c.set_param('enabled', traffic_analytics_enabled)
            c.set_param('workspace_id', workspace.properties['customerId'])
            c.set_param('workspace_region', workspace.location)
            c.set_param('workspace_resource_id', workspace.id)
            c.set_param('traffic_analytics_interval', traffic_analytics_interval)

    return instance
5942
5943
def list_nw_flow_log(client, watcher_rg, watcher_name, location):
    """List the flow logs of a Network Watcher; ``location`` is not read."""
    flow_logs = client.list(watcher_rg, watcher_name)
    return flow_logs
5946
5947
def delete_nw_flow_log(client, watcher_rg, watcher_name, location, flow_log_name):
    """Start deletion of flow log ``flow_log_name``; ``location`` is not read."""
    target = (watcher_rg, watcher_name, flow_log_name)
    return client.begin_delete(*target)
5950
5951
def start_nw_troubleshooting(cmd, client, watcher_name, watcher_rg, resource, storage_account,
                             storage_path, resource_type=None, resource_group_name=None,
                             no_wait=False):
    """Start a Network Watcher troubleshooting run against ``resource``.

    ``storage_account`` and ``storage_path`` are passed to the service in the
    troubleshooting parameters. ``no_wait`` is forwarded to ``sdk_no_wait``.
    """
    params = cmd.get_models('TroubleshootingParameters')(
        target_resource_id=resource,
        storage_id=storage_account,
        storage_path=storage_path,
    )
    return sdk_no_wait(no_wait, client.begin_get_troubleshooting, watcher_rg, watcher_name, params)
5959
5960
def show_nw_troubleshooting_result(cmd, client, watcher_name, watcher_rg, resource, resource_type=None,
                                   resource_group_name=None):
    """Query the result of the last troubleshooting run for ``resource``."""
    query_model = cmd.get_models('QueryTroubleshootingParameters')
    query = query_model(target_resource_id=resource)
    return client.begin_get_troubleshooting_result(watcher_rg, watcher_name, query)
5965
5966
def run_network_configuration_diagnostic(cmd, client, watcher_rg, watcher_name, resource,
                                         direction=None, protocol=None, source=None, destination=None,
                                         destination_port=None, queries=None,
                                         resource_group_name=None, resource_type=None, parent=None):
    """Run a network configuration diagnostic against ``resource``.

    If ``queries`` is empty or None, a single diagnostic profile is built from
    ``direction``/``protocol``/``source``/``destination``/``destination_port``.
    Otherwise ``queries`` is used as the profile list unchanged.
    """
    params_model, profile_model = cmd.get_models(
        'NetworkConfigurationDiagnosticParameters', 'NetworkConfigurationDiagnosticProfile')

    profiles = queries
    if not profiles:
        single_profile = profile_model(direction=direction,
                                       protocol=protocol,
                                       source=source,
                                       destination=destination,
                                       destination_port=destination_port)
        profiles = [single_profile]

    params = params_model(target_resource_id=resource, profiles=profiles)
    return client.begin_get_network_configuration_diagnostic(watcher_rg, watcher_name, params)
5984# endregion
5985
5986
5987# region CustomIpPrefix
def create_custom_ip_prefix(cmd, client, resource_group_name, custom_ip_prefix_name, location=None,
                            cidr=None, tags=None, zone=None, signed_message=None, authorization_message=None,
                            custom_ip_prefix_parent=None, no_wait=False):
    """Create a custom IP prefix.

    ``custom_ip_prefix_parent`` is either the name of a custom IP prefix in
    ``resource_group_name`` or a full resource ID. When given, the parent is
    fetched and attached to the new prefix.

    Raises:
        ResourceNotFoundError: re-raised with a friendlier message if the parent
            lookup raises ``ResourceNotFoundError``.
    """
    CustomIpPrefix = cmd.get_models('CustomIpPrefix')
    prefix = CustomIpPrefix(
        location=location,
        cidr=cidr,
        zones=zone,
        tags=tags,
        signed_message=signed_message,
        authorization_message=authorization_message
    )

    if custom_ip_prefix_parent:
        # Look up the *parent* prefix, not the prefix being created.
        parent_rg, parent_name = resource_group_name, custom_ip_prefix_parent
        if is_valid_resource_id(custom_ip_prefix_parent):
            parsed = parse_resource_id(custom_ip_prefix_parent)
            parent_rg, parent_name = parsed['resource_group'], parsed['name']
        try:
            prefix.custom_ip_prefix_parent = client.get(parent_rg, parent_name)
        # NOTE(review): SDK clients typically raise azure.core.exceptions.ResourceNotFoundError,
        # which is a different class from the azclierror one caught here. Confirm this handler is reachable.
        except ResourceNotFoundError as err:
            raise ResourceNotFoundError(
                "Custom ip prefix parent {} doesn't exist".format(custom_ip_prefix_parent)) from err

    return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, custom_ip_prefix_name, prefix)
6009
6010
def update_custom_ip_prefix(instance,
                            signed_message=None,
                            authorization_message=None,
                            tags=None,
                            commissioned_state=None):
    """Apply optional changes to a custom IP prefix; ``None`` leaves a field as-is.

    ``commissioned_state`` is an action word that becomes the in-progress state:
    its first letter is upper-cased and ``ing`` is appended
    (``commission`` -> ``Commissioning``).
    """
    plain_updates = (
        ('tags', tags),
        ('signed_message', signed_message),
        ('authorization_message', authorization_message),
    )
    for attr, value in plain_updates:
        if value is not None:
            setattr(instance, attr, value)

    if commissioned_state is not None:
        first, rest = commissioned_state[0], commissioned_state[1:]
        instance.commissioned_state = first.upper() + rest + 'ing'
    return instance
6025# endregion
6026
6027
6028# region PublicIPAddresses
def create_public_ip(cmd, resource_group_name, public_ip_address_name, location=None, tags=None,
                     allocation_method=None, dns_name=None,
                     idle_timeout=4, reverse_fqdn=None, version=None, sku=None, tier=None, zone=None, ip_tags=None,
                     public_ip_prefix=None, edge_zone=None, ip_address=None):
    """Create a public IP address.

    When no allocation method is given, a 'standard' SKU (case-insensitive)
    defaults to static allocation and anything else to dynamic. A ``tier``
    without a ``sku`` implies the 'Basic' SKU. Properties introduced in newer
    API versions are only sent when the active profile supports them.
    """
    IPAllocationMethod, PublicIPAddress, PublicIPAddressDnsSettings, SubResource = cmd.get_models(
        'IPAllocationMethod', 'PublicIPAddress', 'PublicIPAddressDnsSettings', 'SubResource')
    client = network_client_factory(cmd.cli_ctx).public_ip_addresses

    if not allocation_method:
        is_standard = bool(sku) and sku.lower() == 'standard'
        default_method = IPAllocationMethod.static if is_standard else IPAllocationMethod.dynamic
        allocation_method = default_method.value

    public_ip_args = dict(
        location=location,
        tags=tags,
        public_ip_allocation_method=allocation_method,
        idle_timeout_in_minutes=idle_timeout,
        ip_address=ip_address,
        dns_settings=None,
    )

    # (minimum API version, model property, value) for version-gated properties.
    versioned_args = (
        ('2016-09-01', 'public_ip_address_version', version),
        ('2017-06-01', 'zones', zone),
        ('2017-11-01', 'ip_tags', ip_tags),
    )
    for min_api, prop, value in versioned_args:
        if cmd.supported_api_version(min_api=min_api):
            public_ip_args[prop] = value
    if cmd.supported_api_version(min_api='2018-07-01') and public_ip_prefix:
        public_ip_args['public_ip_prefix'] = SubResource(id=public_ip_prefix)

    if sku:
        public_ip_args['sku'] = {'name': sku}
    if tier:
        public_ip_args.setdefault('sku', {'name': 'Basic'})['tier'] = tier

    public_ip = PublicIPAddress(**public_ip_args)

    if dns_name or reverse_fqdn:
        public_ip.dns_settings = PublicIPAddressDnsSettings(domain_name_label=dns_name,
                                                            reverse_fqdn=reverse_fqdn)
    if edge_zone:
        public_ip.extended_location = _edge_zone_model(cmd, edge_zone)
    return client.begin_create_or_update(resource_group_name, public_ip_address_name, public_ip)
6074
6075
def update_public_ip(cmd, instance, dns_name=None, allocation_method=None, version=None,
                     idle_timeout=None, reverse_fqdn=None, tags=None, sku=None, ip_tags=None,
                     public_ip_prefix=None):
    """Apply optional changes to an existing public IP address and return it.

    DNS settings are created if missing, otherwise only the supplied fields are
    replaced. ``ip_tags`` and ``public_ip_prefix`` are applied only when truthy.
    All other options are applied when not ``None``.
    """
    wants_dns_change = dns_name is not None or reverse_fqdn is not None
    if wants_dns_change and not instance.dns_settings:
        PublicIPAddressDnsSettings = cmd.get_models('PublicIPAddressDnsSettings')
        instance.dns_settings = PublicIPAddressDnsSettings(domain_name_label=dns_name, fqdn=None,
                                                           reverse_fqdn=reverse_fqdn)
    elif wants_dns_change:
        settings = instance.dns_settings
        if dns_name is not None:
            settings.domain_name_label = dns_name
        if reverse_fqdn is not None:
            settings.reverse_fqdn = reverse_fqdn

    simple_updates = (
        ('public_ip_allocation_method', allocation_method),
        ('public_ip_address_version', version),
        ('idle_timeout_in_minutes', idle_timeout),
        ('tags', tags),
    )
    for attr, value in simple_updates:
        if value is not None:
            setattr(instance, attr, value)

    if sku is not None:
        instance.sku.name = sku
    if ip_tags:
        instance.ip_tags = ip_tags
    if public_ip_prefix:
        instance.public_ip_prefix = cmd.get_models('SubResource')(id=public_ip_prefix)
    return instance
6105
6106
def create_public_ip_prefix(cmd, client, resource_group_name, public_ip_prefix_name, prefix_length,
                            version=None, location=None, tags=None, zone=None, edge_zone=None,
                            custom_ip_prefix_name=None):
    """Create a Standard-SKU public IP prefix.

    On API 2019-08-01+ the IP version defaults to 'ipv4'. On API 2020-06-01+ an
    optional custom IP prefix (looked up by name in ``resource_group_name``) is
    attached.

    Raises:
        ResourceNotFoundError: re-raised with a friendlier message if the custom
            IP prefix lookup raises ``ResourceNotFoundError``.
    """
    PublicIPPrefix, PublicIPPrefixSku = cmd.get_models('PublicIPPrefix', 'PublicIPPrefixSku')
    prefix = PublicIPPrefix(location=location,
                            prefix_length=prefix_length,
                            sku=PublicIPPrefixSku(name='Standard'),
                            tags=tags,
                            zones=zone)

    if cmd.supported_api_version(min_api='2019-08-01'):
        prefix.public_ip_address_version = 'ipv4' if version is None else version

    if cmd.supported_api_version(min_api='2020-06-01') and custom_ip_prefix_name:
        custom_prefixes = network_client_factory(cmd.cli_ctx).custom_ip_prefixes
        try:
            prefix.custom_ip_prefix = custom_prefixes.get(resource_group_name, custom_ip_prefix_name)
        except ResourceNotFoundError:
            raise ResourceNotFoundError('Custom ip prefix {} doesn\'t exist.'.format(custom_ip_prefix_name))

    if edge_zone:
        prefix.extended_location = _edge_zone_model(cmd, edge_zone)
    return client.begin_create_or_update(resource_group_name, public_ip_prefix_name, prefix)
6132
6133
def update_public_ip_prefix(instance, tags=None):
    """Replace the prefix's tags when ``tags`` is given; return the instance."""
    if tags is None:
        return instance
    instance.tags = tags
    return instance
6138# endregion
6139
6140
6141# region RouteFilters
def create_route_filter(cmd, client, resource_group_name, route_filter_name, location=None, tags=None):
    """Create (or replace) an empty route filter."""
    route_filter = cmd.get_models('RouteFilter')(location=location, tags=tags)
    return client.begin_create_or_update(resource_group_name, route_filter_name, route_filter)
6146
6147
def list_route_filters(client, resource_group_name=None):
    """List route filters in one resource group, or in the whole subscription."""
    return client.list_by_resource_group(resource_group_name) if resource_group_name else client.list()
6153
6154
def create_route_filter_rule(cmd, client, resource_group_name, route_filter_name, rule_name, access, communities,
                             location=None):
    """Create (or replace) a rule in a route filter."""
    rule = cmd.get_models('RouteFilterRule')(access=access, communities=communities, location=location)
    return client.begin_create_or_update(resource_group_name, route_filter_name, rule_name, rule)
6161
6162# endregion
6163
6164
6165# region RouteTables
def create_route_table(cmd, resource_group_name, route_table_name, location=None, tags=None,
                       disable_bgp_route_propagation=None):
    """Create (or replace) a route table.

    ``disable_bgp_route_propagation`` is only sent on API 2017-10-01 and later.
    """
    RouteTable = cmd.get_models('RouteTable')
    route_tables = network_client_factory(cmd.cli_ctx).route_tables
    table = RouteTable(location=location, tags=tags)
    if cmd.supported_api_version(min_api='2017-10-01'):
        table.disable_bgp_route_propagation = disable_bgp_route_propagation
    return route_tables.begin_create_or_update(resource_group_name, route_table_name, table)
6174
6175
def update_route_table(instance, tags=None, disable_bgp_route_propagation=None):
    """Apply optional changes to a route table.

    Passing ``tags=''`` clears the tags; ``None`` leaves a field unchanged.
    """
    if tags is not None:
        instance.tags = None if tags == '' else tags
    if disable_bgp_route_propagation is not None:
        instance.disable_bgp_route_propagation = disable_bgp_route_propagation
    return instance
6184
6185
def create_route(cmd, resource_group_name, route_table_name, route_name, next_hop_type, address_prefix,
                 next_hop_ip_address=None):
    """Create (or replace) a route named ``route_name`` in a route table."""
    route = cmd.get_models('Route')(name=route_name,
                                    address_prefix=address_prefix,
                                    next_hop_type=next_hop_type,
                                    next_hop_ip_address=next_hop_ip_address)
    routes = network_client_factory(cmd.cli_ctx).routes
    return routes.begin_create_or_update(resource_group_name, route_table_name, route_name, route)
6193
6194
def update_route(instance, address_prefix=None, next_hop_type=None, next_hop_ip_address=None):
    """Apply the supplied (non-``None``) route fields to ``instance`` and return it."""
    changes = (
        ('address_prefix', address_prefix),
        ('next_hop_type', next_hop_type),
        ('next_hop_ip_address', next_hop_ip_address),
    )
    for attr, value in changes:
        if value is not None:
            setattr(instance, attr, value)
    return instance
6205# endregion
6206
6207
6208# region ServiceEndpoints
def create_service_endpoint_policy(cmd, resource_group_name, service_endpoint_policy_name, location=None, tags=None):
    """Create (or replace) an empty service endpoint policy."""
    policies = network_client_factory(cmd.cli_ctx).service_endpoint_policies
    policy = cmd.get_models('ServiceEndpointPolicy')(tags=tags, location=location)
    return policies.begin_create_or_update(resource_group_name, service_endpoint_policy_name, policy)
6214
6215
def list_service_endpoint_policies(cmd, resource_group_name=None):
    """List service endpoint policies in one resource group, or in the whole subscription."""
    policies = network_client_factory(cmd.cli_ctx).service_endpoint_policies
    return policies.list_by_resource_group(resource_group_name) if resource_group_name else policies.list()
6221
6222
def update_service_endpoint_policy(instance, tags=None):
    """Replace the policy's tags when ``tags`` is given; return the instance."""
    if tags is None:
        return instance
    instance.tags = tags
    return instance
6228
6229
def create_service_endpoint_policy_definition(cmd, resource_group_name, service_endpoint_policy_name,
                                              service_endpoint_policy_definition_name, service, service_resources,
                                              description=None):
    """Create (or replace) a definition inside a service endpoint policy."""
    definitions = network_client_factory(cmd.cli_ctx).service_endpoint_policy_definitions
    definition_model = cmd.get_models('ServiceEndpointPolicyDefinition')
    definition = definition_model(description=description,
                                  service=service,
                                  service_resources=service_resources)
    return definitions.begin_create_or_update(resource_group_name,
                                              service_endpoint_policy_name,
                                              service_endpoint_policy_definition_name,
                                              definition)
6239
6240
def update_service_endpoint_policy_definition(instance, service=None, service_resources=None, description=None):
    """Apply the supplied (non-``None``) definition fields to ``instance`` and return it."""
    for attr, value in (('service', service),
                        ('service_resources', service_resources),
                        ('description', description)):
        if value is not None:
            setattr(instance, attr, value)
    return instance
6252# endregion
6253
6254
6255# region TrafficManagers
def list_traffic_manager_profiles(cmd, resource_group_name=None):
    """List Traffic Manager profiles in one resource group, or in the whole subscription."""
    from azure.mgmt.trafficmanager import TrafficManagerManagementClient
    profiles = get_mgmt_service_client(cmd.cli_ctx, TrafficManagerManagementClient).profiles
    if not resource_group_name:
        return profiles.list_by_subscription()
    return profiles.list_by_resource_group(resource_group_name)
6263
6264
def create_traffic_manager_profile(cmd, traffic_manager_profile_name, resource_group_name,
                                   routing_method, unique_dns_name, monitor_path=None,
                                   monitor_port=80, monitor_protocol=MonitorProtocol.http.value,
                                   profile_status=ProfileStatus.enabled.value,
                                   ttl=30, tags=None, interval=None, timeout=None, max_failures=None,
                                   monitor_custom_headers=None, status_code_ranges=None, max_return=None):
    """Create (or replace) a Traffic Manager profile in the 'global' location.

    Args:
        monitor_path: health-probe path. When omitted and ``monitor_protocol``
            is 'HTTP' or 'HTTPS', it defaults to '/'. For other protocols
            (e.g. TCP) it is left unset.
        unique_dns_name: relative DNS name of the profile.

    Returns:
        The result of ``profiles.create_or_update``.
    """
    from azure.mgmt.trafficmanager import TrafficManagerManagementClient
    from azure.mgmt.trafficmanager.models import Profile, DnsConfig, MonitorConfig
    client = get_mgmt_service_client(cmd.cli_ctx, TrafficManagerManagementClient).profiles
    # Traffic Manager requires a probe path for both HTTP and HTTPS monitoring.
    if monitor_path is None and monitor_protocol in ('HTTP', 'HTTPS'):
        monitor_path = '/'
    profile = Profile(location='global', tags=tags, profile_status=profile_status,
                      traffic_routing_method=routing_method,
                      dns_config=DnsConfig(relative_name=unique_dns_name, ttl=ttl),
                      monitor_config=MonitorConfig(protocol=monitor_protocol,
                                                   port=monitor_port,
                                                   path=monitor_path,
                                                   interval_in_seconds=interval,
                                                   timeout_in_seconds=timeout,
                                                   tolerated_number_of_failures=max_failures,
                                                   custom_headers=monitor_custom_headers,
                                                   expected_status_code_ranges=status_code_ranges),
                      max_return=max_return)
    return client.create_or_update(resource_group_name, traffic_manager_profile_name, profile)
6289
6290
def update_traffic_manager_profile(instance, profile_status=None, routing_method=None, tags=None,
                                   monitor_protocol=None, monitor_port=None, monitor_path=None,
                                   ttl=None, timeout=None, interval=None, max_failures=None,
                                   monitor_custom_headers=None, status_code_ranges=None, max_return=None):
    """Apply optional changes to a Traffic Manager profile and return it.

    ``None`` leaves a field unchanged. ``monitor_path=''`` clears the probe path.
    Every endpoint of the profile also gets its ``_validation`` overridden so
    that ``name``/``type`` are not treated as read-only (workaround, see TODO).
    """
    if tags is not None:
        instance.tags = tags
    if profile_status is not None:
        instance.profile_status = profile_status
    if routing_method is not None:
        instance.traffic_routing_method = routing_method
    if ttl is not None:
        instance.dns_config.ttl = ttl

    if monitor_protocol is not None:
        instance.monitor_config.protocol = monitor_protocol
    if monitor_port is not None:
        instance.monitor_config.port = monitor_port
    if monitor_path == '':
        instance.monitor_config.path = None
    elif monitor_path is not None:
        instance.monitor_config.path = monitor_path
    if interval is not None:
        instance.monitor_config.interval_in_seconds = interval
    if timeout is not None:
        instance.monitor_config.timeout_in_seconds = timeout
    if max_failures is not None:
        instance.monitor_config.tolerated_number_of_failures = max_failures
    if monitor_custom_headers is not None:
        instance.monitor_config.custom_headers = monitor_custom_headers
    if status_code_ranges is not None:
        instance.monitor_config.expected_status_code_ranges = status_code_ranges
    if max_return is not None:
        instance.max_return = max_return

    # TODO: Remove workaround after https://github.com/Azure/azure-rest-api-specs/issues/1940 fixed
    # A profile without endpoints may carry endpoints=None; treat that as an empty list.
    for endpoint in instance.endpoints or []:
        endpoint._validation = {  # pylint: disable=protected-access
            'name': {'readonly': False},
            'type': {'readonly': False},
        }
    return instance
6332
6333
def create_traffic_manager_endpoint(cmd, resource_group_name, profile_name, endpoint_type, endpoint_name,
                                    target_resource_id=None, target=None,
                                    endpoint_status=None, weight=None, priority=None,
                                    endpoint_location=None, endpoint_monitor_status=None,
                                    min_child_endpoints=None, geo_mapping=None,
                                    monitor_custom_headers=None, subnets=None):
    """Create (or replace) an endpoint of type ``endpoint_type`` in a Traffic Manager profile.

    ``monitor_custom_headers`` is sent as the endpoint's ``custom_headers``.
    """
    from azure.mgmt.trafficmanager import TrafficManagerManagementClient
    from azure.mgmt.trafficmanager.models import Endpoint
    endpoints_client = get_mgmt_service_client(cmd.cli_ctx, TrafficManagerManagementClient).endpoints

    endpoint_kwargs = {
        'target_resource_id': target_resource_id,
        'target': target,
        'endpoint_status': endpoint_status,
        'weight': weight,
        'priority': priority,
        'endpoint_location': endpoint_location,
        'endpoint_monitor_status': endpoint_monitor_status,
        'min_child_endpoints': min_child_endpoints,
        'geo_mapping': geo_mapping,
        'subnets': subnets,
        'custom_headers': monitor_custom_headers,
    }
    new_endpoint = Endpoint(**endpoint_kwargs)
    return endpoints_client.create_or_update(resource_group_name, profile_name, endpoint_type, endpoint_name,
                                             new_endpoint)
6355
6356
def update_traffic_manager_endpoint(instance, endpoint_type=None, endpoint_location=None,
                                    endpoint_status=None, endpoint_monitor_status=None,
                                    priority=None, target=None, target_resource_id=None,
                                    weight=None, min_child_endpoints=None, geo_mapping=None,
                                    subnets=None, monitor_custom_headers=None):
    """Apply optional changes to a Traffic Manager endpoint and return it.

    Fields are updated when not ``None``, except the custom headers, which are
    updated only when truthy. ``endpoint_type`` is accepted but not read.
    """
    field_updates = (
        ('endpoint_location', endpoint_location),
        ('endpoint_status', endpoint_status),
        ('endpoint_monitor_status', endpoint_monitor_status),
        ('priority', priority),
        ('target', target),
        ('target_resource_id', target_resource_id),
        ('weight', weight),
        ('min_child_endpoints', min_child_endpoints),
        ('geo_mapping', geo_mapping),
        ('subnets', subnets),
    )
    for attr, value in field_updates:
        if value is not None:
            setattr(instance, attr, value)
    if monitor_custom_headers:
        instance.custom_headers = monitor_custom_headers
    return instance
6386
6387
def list_traffic_manager_endpoints(cmd, resource_group_name, profile_name, endpoint_type=None):
    """List the endpoints of a Traffic Manager profile.

    When ``endpoint_type`` is given, only endpoints whose ``type`` string ends
    with it are returned.
    """
    from azure.mgmt.trafficmanager import TrafficManagerManagementClient
    client = get_mgmt_service_client(cmd.cli_ctx, TrafficManagerManagementClient).profiles
    profile = client.get(resource_group_name, profile_name)
    # A profile without endpoints may come back with endpoints=None; treat it as empty.
    endpoints = profile.endpoints or []
    return [e for e in endpoints if not endpoint_type or e.type.endswith(endpoint_type)]
6393
6394
6395# endregion
6396
6397
6398# region VirtualNetworks
6399# pylint: disable=too-many-locals
def create_vnet(cmd, resource_group_name, vnet_name, vnet_prefixes='10.0.0.0/16',
                subnet_name=None, subnet_prefix=None, dns_servers=None,
                location=None, tags=None, vm_protection=None, ddos_protection=None,
                ddos_protection_plan=None, network_security_group=None, edge_zone=None, flowtimeout=None):
    """Create a virtual network, optionally with one initial subnet.

    ``vnet_prefixes`` may be a single CIDR string or a list of them. On API
    2018-08-01+, ``subnet_prefix`` is treated as a list: a single entry goes to
    ``address_prefix`` and several go to ``address_prefixes``. On older APIs it
    is passed through as ``address_prefix``. Version-gated properties are only
    set when the active API supports them. The PUT goes through ``cached_put``.
    """
    AddressSpace, DhcpOptions, Subnet, VirtualNetwork, SubResource, NetworkSecurityGroup = \
        cmd.get_models('AddressSpace', 'DhcpOptions', 'Subnet', 'VirtualNetwork',
                       'SubResource', 'NetworkSecurityGroup')
    client = network_client_factory(cmd.cli_ctx).virtual_networks

    address_prefixes = vnet_prefixes if isinstance(vnet_prefixes, list) else [vnet_prefixes]
    vnet = VirtualNetwork(location=location,
                          tags=tags or {},
                          dhcp_options=DhcpOptions(dns_servers=dns_servers),
                          address_space=AddressSpace(address_prefixes=address_prefixes))

    if subnet_name:
        if not cmd.supported_api_version(min_api='2018-08-01'):
            vnet.subnets = [Subnet(name=subnet_name, address_prefix=subnet_prefix)]
        else:
            prefix_count = len(subnet_prefix)
            subnet_nsg = NetworkSecurityGroup(id=network_security_group) if network_security_group else None
            vnet.subnets = [Subnet(name=subnet_name,
                                   address_prefix=subnet_prefix[0] if prefix_count == 1 else None,
                                   address_prefixes=subnet_prefix if prefix_count > 1 else None,
                                   network_security_group=subnet_nsg)]

    if cmd.supported_api_version(min_api='2017-09-01'):
        vnet.enable_ddos_protection = ddos_protection
        vnet.enable_vm_protection = vm_protection
    if cmd.supported_api_version(min_api='2018-02-01'):
        vnet.ddos_protection_plan = SubResource(id=ddos_protection_plan) if ddos_protection_plan else None
    if edge_zone:
        vnet.extended_location = _edge_zone_model(cmd, edge_zone)
    if flowtimeout is not None:
        vnet.flow_timeout_in_minutes = flowtimeout
    return cached_put(cmd, client.begin_create_or_update, vnet, resource_group_name, vnet_name)
6433
6434
def update_vnet(cmd, instance, vnet_prefixes=None, dns_servers=None, ddos_protection=None, vm_protection=None,
                ddos_protection_plan=None, flowtimeout=None):
    """Apply optional changes to a virtual network and return it.

    ``dns_servers=['']`` clears the DNS servers and ``ddos_protection_plan=''``
    detaches the DDoS plan. Missing address-space / DHCP-options containers are
    created when values need to be written into them.
    """
    # server side validation reports pretty good error message on invalid CIDR,
    # so we don't validate at client side
    AddressSpace, DhcpOptions, SubResource = cmd.get_models('AddressSpace', 'DhcpOptions', 'SubResource')
    if vnet_prefixes and instance.address_space:
        instance.address_space.address_prefixes = vnet_prefixes
    elif vnet_prefixes:
        instance.address_space = AddressSpace(address_prefixes=vnet_prefixes)

    if dns_servers == ['']:
        # Clearing: nothing to do when the VNet has no DHCP options at all
        # (previously this raised AttributeError on None).
        if instance.dhcp_options:
            instance.dhcp_options.dns_servers = None
    elif dns_servers and instance.dhcp_options:
        instance.dhcp_options.dns_servers = dns_servers
    elif dns_servers:
        instance.dhcp_options = DhcpOptions(dns_servers=dns_servers)

    if ddos_protection is not None:
        instance.enable_ddos_protection = ddos_protection
    if vm_protection is not None:
        instance.enable_vm_protection = vm_protection
    if ddos_protection_plan == '':
        instance.ddos_protection_plan = None
    elif ddos_protection_plan is not None:
        instance.ddos_protection_plan = SubResource(id=ddos_protection_plan)
    if flowtimeout is not None:
        instance.flow_timeout_in_minutes = flowtimeout
    return instance
6463
6464
6465def _set_route_table(ncf, resource_group_name, route_table, subnet):
6466    if route_table:
6467        is_id = is_valid_resource_id(route_table)
6468        rt = None
6469        if is_id:
6470            res_id = parse_resource_id(route_table)
6471            rt = ncf.route_tables.get(res_id['resource_group'], res_id['name'])
6472        else:
6473            rt = ncf.route_tables.get(resource_group_name, route_table)
6474        subnet.route_table = rt
6475    elif route_table == '':
6476        subnet.route_table = None
6477
6478
def create_subnet(cmd, resource_group_name, virtual_network_name, subnet_name,
                  address_prefix, network_security_group=None,
                  route_table=None, service_endpoints=None, service_endpoint_policy=None,
                  delegations=None, nat_gateway=None,
                  disable_private_endpoint_network_policies=None,
                  disable_private_link_service_network_policies=None):
    """Create (or replace, by name) a subnet inside an existing virtual network.

    The subnet model is built locally and upserted into the vnet's ``subnets`` collection
    by name. The whole vnet is then PUT back, and the subnet is read from the result.

    On API 2018-08-01+ ``address_prefix`` is treated as a list. A single entry populates
    ``address_prefix``; several entries populate ``address_prefixes``.
    The ``disable_*_network_policies`` flags are tri-state: True -> "Disabled",
    False -> "Enabled", None -> left unset.
    """
    NetworkSecurityGroup, ServiceEndpoint, Subnet, SubResource = cmd.get_models(
        'NetworkSecurityGroup', 'ServiceEndpointPropertiesFormat', 'Subnet', 'SubResource')
    network_client = network_client_factory(cmd.cli_ctx)

    def _policy_state(disable_flag):
        # Identity checks on purpose: only the literal booleans select a state.
        if disable_flag is True:
            return "Disabled"
        if disable_flag is False:
            return "Enabled"
        return None

    if not cmd.supported_api_version(min_api='2018-08-01'):
        subnet = Subnet(name=subnet_name, address_prefix=address_prefix)
    else:
        prefix_count = len(address_prefix)
        subnet = Subnet(
            name=subnet_name,
            address_prefixes=address_prefix if prefix_count > 1 else None,
            address_prefix=address_prefix[0] if prefix_count == 1 else None
        )
        if cmd.supported_api_version(min_api='2019-02-01') and nat_gateway:
            subnet.nat_gateway = SubResource(id=nat_gateway)

    if network_security_group:
        subnet.network_security_group = NetworkSecurityGroup(id=network_security_group)
    _set_route_table(network_client, resource_group_name, route_table, subnet)
    if service_endpoints:
        subnet.service_endpoints = [ServiceEndpoint(service=name) for name in service_endpoints]
    if service_endpoint_policy:
        subnet.service_endpoint_policies = [SubResource(id=policy_id) for policy_id in service_endpoint_policy]
    if delegations:
        subnet.delegations = delegations

    endpoint_state = _policy_state(disable_private_endpoint_network_policies)
    if endpoint_state is not None:
        subnet.private_endpoint_network_policies = endpoint_state
    link_service_state = _policy_state(disable_private_link_service_network_policies)
    if link_service_state is not None:
        subnet.private_link_service_network_policies = link_service_state

    vnet = cached_get(cmd, network_client.virtual_networks.get, resource_group_name, virtual_network_name)
    upsert_to_collection(vnet, 'subnets', subnet, 'name')
    poller = cached_put(
        cmd, network_client.virtual_networks.begin_create_or_update, vnet, resource_group_name, virtual_network_name)
    return get_property(poller.result().subnets, subnet_name)
6529
6530
def update_subnet(cmd, instance, resource_group_name, address_prefix=None, network_security_group=None,
                  route_table=None, service_endpoints=None, delegations=None, nat_gateway=None,
                  service_endpoint_policy=None, disable_private_endpoint_network_policies=None,
                  disable_private_link_service_network_policies=None):
    """Apply optional changes to an existing subnet model and return it.

    Clearing sentinels:
    * ``''`` clears ``nat_gateway``, ``network_security_group`` and ``route_table``.
    * ``['']`` clears ``service_endpoints``.
    * ``''`` or ``['']`` clears ``service_endpoint_policy``.

    The ``disable_*_network_policies`` flags map truthy -> "Disabled" and
    non-None falsy -> "Enabled"; None leaves them unchanged.

    :return: the mutated ``instance``.
    """
    NetworkSecurityGroup, ServiceEndpoint, SubResource = cmd.get_models(
        'NetworkSecurityGroup', 'ServiceEndpointPropertiesFormat', 'SubResource')

    if address_prefix:
        if cmd.supported_api_version(min_api='2018-08-01'):
            # A single prefix goes into address_prefix; several go into address_prefixes.
            instance.address_prefixes = address_prefix if len(address_prefix) > 1 else None
            instance.address_prefix = address_prefix[0] if len(address_prefix) == 1 else None
        else:
            instance.address_prefix = address_prefix

    if cmd.supported_api_version(min_api='2019-02-01') and nat_gateway:
        instance.nat_gateway = SubResource(id=nat_gateway)
    elif nat_gateway == '':
        instance.nat_gateway = None

    if network_security_group:
        instance.network_security_group = NetworkSecurityGroup(id=network_security_group)
    elif network_security_group == '':  # clear it
        instance.network_security_group = None

    _set_route_table(network_client_factory(cmd.cli_ctx), resource_group_name, route_table, instance)

    if service_endpoints == ['']:
        instance.service_endpoints = None
    elif service_endpoints:
        instance.service_endpoints = []
        for service in service_endpoints:
            instance.service_endpoints.append(ServiceEndpoint(service=service))

    # service_endpoint_policy is a list (it is iterated below), so a clear request arrives
    # as [''] like service_endpoints does; '' is still accepted for backward compatibility.
    # Previously [''] fell through and produced SubResource(id='').
    if service_endpoint_policy == '' or service_endpoint_policy == ['']:
        instance.service_endpoint_policies = None
    elif service_endpoint_policy:
        instance.service_endpoint_policies = []
        for policy in service_endpoint_policy:
            instance.service_endpoint_policies.append(SubResource(id=policy))

    if delegations:
        instance.delegations = delegations

    if disable_private_endpoint_network_policies:
        instance.private_endpoint_network_policies = "Disabled"
    elif disable_private_endpoint_network_policies is not None:
        instance.private_endpoint_network_policies = "Enabled"

    if disable_private_link_service_network_policies:
        instance.private_link_service_network_policies = "Disabled"
    elif disable_private_link_service_network_policies is not None:
        instance.private_link_service_network_policies = "Enabled"

    return instance
6585
6586
def list_avail_subnet_delegations(cmd, resource_group_name=None, location=None):
    """List the subnet delegations available in ``location``.

    When ``resource_group_name`` is given, the resource-group-scoped listing is used instead.
    """
    network_client = network_client_factory(cmd.cli_ctx)
    if not resource_group_name:
        return network_client.available_delegations.list(location)
    return network_client.available_resource_group_delegations.list(location, resource_group_name)
6592
6593
def create_vnet_peering(cmd, resource_group_name, virtual_network_name, virtual_network_peering_name,
                        remote_virtual_network, allow_virtual_network_access=False,
                        allow_forwarded_traffic=False, allow_gateway_transit=False,
                        use_remote_gateways=False):
    """Create a peering from ``virtual_network_name`` to ``remote_virtual_network``.

    ``remote_virtual_network`` may be a full resource ID or a bare vnet name. A bare name is
    resolved in the current subscription and ``resource_group_name``.
    :return: the poller from ``virtual_network_peerings.begin_create_or_update``.
    """
    subscription_id = get_subscription_id(cmd.cli_ctx)

    def _vnet_id(vnet_name):
        # ARM ID of a vnet in the current subscription and the given resource group.
        return resource_id(
            subscription=subscription_id,
            resource_group=resource_group_name,
            namespace='Microsoft.Network',
            type='virtualNetworks',
            name=vnet_name)

    if is_valid_resource_id(remote_virtual_network):
        remote_vnet_id = remote_virtual_network
    else:
        remote_vnet_id = _vnet_id(remote_virtual_network)

    SubResource, VirtualNetworkPeering = cmd.get_models('SubResource', 'VirtualNetworkPeering')
    peering = VirtualNetworkPeering(
        id=_vnet_id(virtual_network_name),
        name=virtual_network_peering_name,
        remote_virtual_network=SubResource(id=remote_vnet_id),
        allow_virtual_network_access=allow_virtual_network_access,
        allow_gateway_transit=allow_gateway_transit,
        allow_forwarded_traffic=allow_forwarded_traffic,
        use_remote_gateways=use_remote_gateways)

    # The remote vnet's subscription is passed to the client factory as an auxiliary
    # subscription, presumably so cross-subscription peering can be authorized.
    remote_subscription = parse_resource_id(remote_vnet_id)['subscription']
    peerings = network_client_factory(cmd.cli_ctx, aux_subscriptions=[remote_subscription]).virtual_network_peerings
    return peerings.begin_create_or_update(
        resource_group_name, virtual_network_name, virtual_network_peering_name, peering)
6624
6625
def update_vnet_peering(cmd, resource_group_name, virtual_network_name, virtual_network_peering_name, **kwargs):
    """PUT an already-modified peering model back to Azure.

    ``kwargs['parameters']`` holds the updated VirtualNetworkPeering (presumably supplied
    by the generic-update machinery — confirm against the command registration). The
    remote vnet's subscription is passed to the client factory as an auxiliary subscription.
    """
    updated_peering = kwargs['parameters']
    remote_subscription = parse_resource_id(updated_peering.remote_virtual_network.id)['subscription']
    network_client = network_client_factory(cmd.cli_ctx, aux_subscriptions=[remote_subscription])
    return network_client.virtual_network_peerings.begin_create_or_update(
        resource_group_name, virtual_network_name, virtual_network_peering_name, updated_peering)
6632
6633
def list_available_ips(cmd, resource_group_name, virtual_network_name):
    """Return the IP addresses the service reports as available in a virtual network.

    The availability check is seeded with the address part of the vnet's first address
    prefix (e.g. '10.0.0.0' for '10.0.0.0/16').
    """
    vnets = network_client_factory(cmd.cli_ctx).virtual_networks
    vnet = vnets.get(resource_group_name=resource_group_name,
                     virtual_network_name=virtual_network_name)
    seed_ip = vnet.address_space.address_prefixes[0].partition('/')[0]
    availability = vnets.check_ip_address_availability(resource_group_name=resource_group_name,
                                                       virtual_network_name=virtual_network_name,
                                                       ip_address=seed_ip)
    return availability.available_ip_addresses
6643
6644# endregion
6645
6646
6647# region VirtualNetworkGateways
def create_vnet_gateway_root_cert(cmd, resource_group_name, gateway_name, public_cert_data, cert_name):
    """Add a VPN client root certificate to a virtual network gateway, upserted by name.

    :raises CLIError: if the gateway has no VPN client configuration yet.
    :return: the poller from ``begin_create_or_update`` for the updated gateway.
    """
    VpnClientRootCertificate = cmd.get_models('VpnClientRootCertificate')
    gateways_client = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = gateways_client.get(resource_group_name, gateway_name)
    vpn_config = gateway.vpn_client_configuration
    if not vpn_config:
        raise CLIError("Must add address prefixes to gateway '{}' prior to adding a root cert."
                       .format(gateway_name))

    if vpn_config.vpn_client_root_certificates is None:
        vpn_config.vpn_client_root_certificates = []

    new_cert = VpnClientRootCertificate(name=cert_name, public_cert_data=public_cert_data)
    upsert_to_collection(vpn_config, 'vpn_client_root_certificates', new_cert, 'name')
    return gateways_client.begin_create_or_update(resource_group_name, gateway_name, gateway)
6663
6664
def delete_vnet_gateway_root_cert(cmd, resource_group_name, gateway_name, cert_name):
    """Remove a VPN client root certificate, by name, from a virtual network gateway.

    :raises CLIError: if the gateway has no VPN client configuration, no root
        certificate list, or no certificate named ``cert_name``.
    :return: the poller from ``begin_create_or_update`` for the updated gateway.
    """
    ncf = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = ncf.get(resource_group_name, gateway_name)
    config = gateway.vpn_client_configuration

    # A missing configuration or a None certificate list both mean "not found". Previously a
    # None list escaped as TypeError, because the generator calls iter(None) when created.
    certs = (config.vpn_client_root_certificates if config is not None else None) or []
    cert = next((c for c in certs if c.name == cert_name), None)
    if cert is None:
        raise CLIError('Certificate "{}" not found in gateway "{}"'.format(cert_name, gateway_name))
    certs.remove(cert)  # certs is the gateway's own list whenever a match was found

    return ncf.begin_create_or_update(resource_group_name, gateway_name, gateway)
6677
6678
def create_vnet_gateway_revoked_cert(cmd, resource_group_name, gateway_name, thumbprint, cert_name):
    """Add a revoked VPN client certificate (by thumbprint) to a gateway, upserted by name.

    :return: the poller from ``begin_create_or_update`` for the updated gateway.
    """
    VpnClientRevokedCertificate = cmd.get_models('VpnClientRevokedCertificate')
    vpn_config, gateway, gateways_client = _prep_cert_create(cmd, gateway_name, resource_group_name)

    revoked_cert = VpnClientRevokedCertificate(name=cert_name, thumbprint=thumbprint)
    upsert_to_collection(vpn_config, 'vpn_client_revoked_certificates', revoked_cert, 'name')
    return gateways_client.begin_create_or_update(resource_group_name, gateway_name, gateway)
6686
6687
def delete_vnet_gateway_revoked_cert(cmd, resource_group_name, gateway_name, cert_name):
    """Remove a revoked VPN client certificate, by name, from a virtual network gateway.

    :raises CLIError: if the gateway has no VPN client configuration, no revoked
        certificate list, or no certificate named ``cert_name``.
    :return: the poller from ``begin_create_or_update`` for the updated gateway.
    """
    ncf = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = ncf.get(resource_group_name, gateway_name)
    config = gateway.vpn_client_configuration

    # A missing configuration or a None certificate list both mean "not found". Previously a
    # None list escaped as TypeError, because the generator calls iter(None) when created.
    certs = (config.vpn_client_revoked_certificates if config is not None else None) or []
    cert = next((c for c in certs if c.name == cert_name), None)
    if cert is None:
        raise CLIError('Certificate "{}" not found in gateway "{}"'.format(cert_name, gateway_name))
    certs.remove(cert)  # certs is the gateway's own list whenever a match was found

    return ncf.begin_create_or_update(resource_group_name, gateway_name, gateway)
6700
6701
def _prep_cert_create(cmd, gateway_name, resource_group_name):
    """Load a gateway and get its VPN client configuration ready for certificate edits.

    Creates an empty VpnClientConfiguration if the gateway has none, and initializes both
    certificate lists to ``[]`` when they are None.

    :raises CLIError: if the configuration has no client address pool prefixes.
    :return: tuple ``(config, gateway, virtual_network_gateways client)``.
    """
    VpnClientConfiguration = cmd.get_models('VpnClientConfiguration')
    gateways_client = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = gateways_client.get(resource_group_name, gateway_name)
    if not gateway.vpn_client_configuration:
        gateway.vpn_client_configuration = VpnClientConfiguration()
    vpn_config = gateway.vpn_client_configuration

    address_pool = vpn_config.vpn_client_address_pool
    if not (address_pool and address_pool.address_prefixes):
        raise CLIError('Address prefixes must be set on VPN gateways before adding'
                       ' certificates.  Please use "update" with --address-prefixes first.')

    for list_attr in ('vpn_client_revoked_certificates', 'vpn_client_root_certificates'):
        if getattr(vpn_config, list_attr) is None:
            setattr(vpn_config, list_attr, [])

    return vpn_config, gateway, gateways_client
6720
6721
def create_vnet_gateway(cmd, resource_group_name, virtual_network_gateway_name, public_ip_address,
                        virtual_network, location=None, tags=None,
                        no_wait=False, gateway_type=None, sku=None, vpn_type=None, vpn_gateway_generation=None,
                        asn=None, bgp_peering_address=None, peer_weight=None,
                        address_prefixes=None, radius_server=None, radius_secret=None, client_protocol=None,
                        gateway_default_site=None, custom_routes=None, aad_tenant=None, aad_audience=None,
                        aad_issuer=None, root_cert_data=None, root_cert_name=None, vpn_auth_type=None, edge_zone=None,
                        nat_rule=None):
    """Create a virtual network gateway.

    One IP configuration is created per entry of ``public_ip_address``, each attached to the
    ``GatewaySubnet`` of ``virtual_network``. Exactly two public IPs mark the gateway
    ``active``. BGP is enabled when any of ``asn``, ``bgp_peering_address`` or
    ``peer_weight`` is truthy. A VPN client configuration is attached when
    ``address_prefixes`` or ``client_protocol`` is truthy.

    :param nat_rule: list of dicts read through the keys 'type', 'mode', 'name',
        'internal_mappings', 'external_mappings' and 'ip_config_id'.
    :return: result of ``sdk_no_wait`` around ``begin_create_or_update``.
    """
    (VirtualNetworkGateway, BgpSettings, SubResource, VirtualNetworkGatewayIPConfiguration, VirtualNetworkGatewaySku,
     VpnClientConfiguration, AddressSpace, VpnClientRootCertificate, VirtualNetworkGatewayNatRule,
     VpnNatRuleMapping) = cmd.get_models(
         'VirtualNetworkGateway', 'BgpSettings', 'SubResource', 'VirtualNetworkGatewayIPConfiguration',
         'VirtualNetworkGatewaySku', 'VpnClientConfiguration', 'AddressSpace', 'VpnClientRootCertificate',
         'VirtualNetworkGatewayNatRule', 'VpnNatRuleMapping')

    def _nat_mappings(address_spaces):
        # Wrap each address space in a VpnNatRuleMapping; a missing or empty list yields None.
        if not address_spaces:
            return None
        return [VpnNatRuleMapping(address_space=space) for space in address_spaces]

    gateways_client = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway_subnet_id = virtual_network + '/subnets/GatewaySubnet'
    vnet_gateway = VirtualNetworkGateway(
        gateway_type=gateway_type, vpn_type=vpn_type, vpn_gateway_generation=vpn_gateway_generation, location=location,
        tags=tags, sku=VirtualNetworkGatewaySku(name=sku, tier=sku),
        active=len(public_ip_address) == 2,  # two public IPs -> active-active
        ip_configurations=[],
        gateway_default_site=SubResource(id=gateway_default_site) if gateway_default_site else None)
    vnet_gateway.ip_configurations.extend(
        VirtualNetworkGatewayIPConfiguration(
            subnet=SubResource(id=gateway_subnet_id),
            public_ip_address=SubResource(id=public_ip),
            private_ip_allocation_method='Dynamic',
            name='vnetGatewayConfig{}'.format(index))
        for index, public_ip in enumerate(public_ip_address))

    if any((asn, bgp_peering_address, peer_weight)):
        vnet_gateway.enable_bgp = True
        vnet_gateway.bgp_settings = BgpSettings(asn=asn, bgp_peering_address=bgp_peering_address,
                                                peer_weight=peer_weight)

    if any((address_prefixes, client_protocol)):
        vpn_config = VpnClientConfiguration()
        vnet_gateway.vpn_client_configuration = vpn_config
        vpn_config.vpn_client_address_pool = AddressSpace()
        vpn_config.vpn_client_address_pool.address_prefixes = address_prefixes
        vpn_config.vpn_client_protocols = client_protocol
        if any((radius_secret, radius_server)) and cmd.supported_api_version(min_api='2017-06-01'):
            vpn_config.radius_server_address = radius_server
            vpn_config.radius_server_secret = radius_secret

        # multi authentication
        if cmd.supported_api_version(min_api='2020-11-01'):
            vpn_config.vpn_authentication_types = vpn_auth_type
            vpn_config.aad_tenant = aad_tenant
            vpn_config.aad_issuer = aad_issuer
            vpn_config.aad_audience = aad_audience
            if root_cert_data:
                vpn_config.vpn_client_root_certificates = [
                    VpnClientRootCertificate(name=root_cert_name, public_cert_data=root_cert_data)]
            else:
                vpn_config.vpn_client_root_certificates = None

    if custom_routes and cmd.supported_api_version(min_api='2019-02-01'):
        vnet_gateway.custom_routes = AddressSpace()
        vnet_gateway.custom_routes.address_prefixes = custom_routes

    if edge_zone:
        vnet_gateway.extended_location = _edge_zone_model(cmd, edge_zone)
    if nat_rule:
        vnet_gateway.nat_rules = [
            VirtualNetworkGatewayNatRule(
                type_properties_type=rule.get('type'),
                mode=rule.get('mode'),
                name=rule.get('name'),
                internal_mappings=_nat_mappings(rule.get('internal_mappings')),
                external_mappings=_nat_mappings(rule.get('external_mappings')),
                ip_configuration_id=rule.get('ip_config_id'))
            for rule in nat_rule]

    return sdk_no_wait(no_wait, gateways_client.begin_create_or_update,
                       resource_group_name, virtual_network_gateway_name, vnet_gateway)
6791
6792
def update_vnet_gateway(cmd, instance, sku=None, vpn_type=None, tags=None,
                        public_ip_address=None, gateway_type=None, enable_bgp=None,
                        asn=None, bgp_peering_address=None, peer_weight=None, virtual_network=None,
                        address_prefixes=None, radius_server=None, radius_secret=None, client_protocol=None,
                        gateway_default_site=None, custom_routes=None, aad_tenant=None, aad_audience=None,
                        aad_issuer=None, root_cert_data=None, root_cert_name=None, vpn_auth_type=None):
    """Apply optional changes to an existing virtual network gateway model and return it.

    A VPN client configuration is created when any option that writes into it is supplied.
    The AAD, auth-type and root-cert options only apply on API 2020-11-01+.
    Supplying ``public_ip_address`` rebuilds all IP configurations; exactly two public IPs
    make the gateway active-active. ``enable_bgp`` is a string compared case-insensitively
    to 'true'.

    :return: the mutated ``instance``.
    """
    (AddressSpace, SubResource, VirtualNetworkGatewayIPConfiguration, VpnClientConfiguration,
     VpnClientRootCertificate) = cmd.get_models('AddressSpace', 'SubResource', 'VirtualNetworkGatewayIPConfiguration',
                                                'VpnClientConfiguration', 'VpnClientRootCertificate')
    multi_auth_supported = cmd.supported_api_version(min_api='2020-11-01')

    # Previously only the first four options triggered creation of the configuration, so
    # e.g. --aad-tenant or --root-cert-data on a gateway without one crashed on None.
    needs_vpn_config = any((address_prefixes, radius_server, radius_secret, client_protocol))
    if multi_auth_supported:
        needs_vpn_config = needs_vpn_config or any((aad_tenant, aad_audience, aad_issuer, vpn_auth_type,
                                                    root_cert_data))
    if needs_vpn_config and not instance.vpn_client_configuration:
        instance.vpn_client_configuration = VpnClientConfiguration()
    vpn_config = instance.vpn_client_configuration

    # If there is still no configuration, every option above was falsy (None, '' or empty),
    # so there is nothing to write; skipping avoids setattr on None.
    if vpn_config is not None:
        if address_prefixes is not None:
            if not vpn_config.vpn_client_address_pool:
                vpn_config.vpn_client_address_pool = AddressSpace()
            vpn_config.vpn_client_address_pool.address_prefixes = address_prefixes

        with cmd.update_context(vpn_config) as c:
            c.set_param('vpn_client_protocols', client_protocol)
            c.set_param('radius_server_address', radius_server)
            c.set_param('radius_server_secret', radius_secret)
            if multi_auth_supported:
                c.set_param('aad_tenant', aad_tenant)
                c.set_param('aad_audience', aad_audience)
                c.set_param('aad_issuer', aad_issuer)
                c.set_param('vpn_authentication_types', vpn_auth_type)

        if root_cert_data and multi_auth_supported:
            upsert_to_collection(vpn_config, 'vpn_client_root_certificates',
                                 VpnClientRootCertificate(name=root_cert_name, public_cert_data=root_cert_data),
                                 'name')

    with cmd.update_context(instance.sku) as c:
        c.set_param('name', sku)
        c.set_param('tier', sku)

    with cmd.update_context(instance) as c:
        c.set_param('gateway_default_site', SubResource(id=gateway_default_site) if gateway_default_site else None)
        c.set_param('vpn_type', vpn_type)
        c.set_param('tags', tags)

    subnet_id = '{}/subnets/GatewaySubnet'.format(virtual_network) if virtual_network else \
        instance.ip_configurations[0].subnet.id
    if virtual_network is not None:
        for config in instance.ip_configurations:
            config.subnet.id = subnet_id

    if public_ip_address is not None:
        instance.ip_configurations = []
        for i, public_ip in enumerate(public_ip_address):
            ip_configuration = VirtualNetworkGatewayIPConfiguration(
                subnet=SubResource(id=subnet_id),
                public_ip_address=SubResource(id=public_ip),
                private_ip_allocation_method='Dynamic', name='vnetGatewayConfig{}'.format(i))
            instance.ip_configurations.append(ip_configuration)

        # Update active-active/active-standby status
        active = len(public_ip_address) == 2
        if instance.active and not active:
            logger.info('Placing gateway in active-standby mode.')
        elif not instance.active and active:
            logger.info('Placing gateway in active-active mode.')
        instance.active = active

    if gateway_type is not None:
        instance.gateway_type = gateway_type

    if enable_bgp is not None:
        instance.enable_bgp = enable_bgp.lower() == 'true'

    if custom_routes and cmd.supported_api_version(min_api='2019-02-01'):
        if not instance.custom_routes:
            instance.custom_routes = AddressSpace()
        instance.custom_routes.address_prefixes = custom_routes

    _validate_bgp_peering(cmd, instance, asn, bgp_peering_address, peer_weight)

    return instance
6873
6874
def start_vnet_gateway_package_capture(cmd, client, resource_group_name, virtual_network_gateway_name,
                                       filter_data=None, no_wait=False):
    """Start a packet capture on a virtual network gateway.

    :param filter_data: optional capture filter, passed through unchanged.
    :return: result of ``sdk_no_wait`` around ``client.begin_start_packet_capture``.
    """
    start_params = cmd.get_models('VpnPacketCaptureStartParameters')(filter_data=filter_data)
    return sdk_no_wait(no_wait, client.begin_start_packet_capture, resource_group_name,
                       virtual_network_gateway_name, parameters=start_params)
6881
6882
def stop_vnet_gateway_package_capture(cmd, client, resource_group_name, virtual_network_gateway_name,
                                      sas_url, no_wait=False):
    """Stop a packet capture on a virtual network gateway.

    :param sas_url: SAS URL passed to the stop request.
    :return: result of ``sdk_no_wait`` around ``client.begin_stop_packet_capture``.
    """
    stop_params = cmd.get_models('VpnPacketCaptureStopParameters')(sas_url=sas_url)
    return sdk_no_wait(no_wait, client.begin_stop_packet_capture, resource_group_name,
                       virtual_network_gateway_name, parameters=stop_params)
6889
6890
def generate_vpn_client(cmd, client, resource_group_name, virtual_network_gateway_name, processor_architecture=None,
                        authentication_method=None, radius_server_auth_certificate=None, client_root_certificates=None,
                        use_legacy=False):
    """Generate VPN client configuration for a virtual network gateway.

    On API 2017-06-01+ (unless ``use_legacy`` is set), the authentication-related fields are
    filled in and ``begin_generate_vpn_profile`` is called. Otherwise only the processor
    architecture is sent to the legacy ``begin_generatevpnclientpackage``.
    """
    VpnClientParameters = cmd.get_models('VpnClientParameters')
    vpn_params = VpnClientParameters(processor_architecture=processor_architecture)

    use_profile_api = cmd.supported_api_version(min_api='2017-06-01') and not use_legacy
    if not use_profile_api:
        # legacy implementation
        return client.begin_generatevpnclientpackage(resource_group_name, virtual_network_gateway_name, vpn_params)

    vpn_params.authentication_method = authentication_method
    vpn_params.radius_server_auth_certificate = radius_server_auth_certificate
    vpn_params.client_root_certificates = client_root_certificates
    return client.begin_generate_vpn_profile(resource_group_name, virtual_network_gateway_name, vpn_params)
6905
6906
def set_vpn_client_ipsec_policy(cmd, client, resource_group_name, virtual_network_gateway_name,
                                sa_life_time_seconds, sa_data_size_kilobytes,
                                ipsec_encryption, ipsec_integrity,
                                ike_encryption, ike_integrity, dh_group, pfs_group, no_wait=False):
    """Set the point-to-site VPN client IPsec/IKE policy of a virtual network gateway.

    Every policy field is passed through unchanged to VpnClientIPsecParameters.
    :return: result of ``sdk_no_wait`` around ``client.begin_set_vpnclient_ipsec_parameters``.
    """
    VpnClientIPsecParameters = cmd.get_models('VpnClientIPsecParameters')
    policy_fields = {
        'sa_life_time_seconds': sa_life_time_seconds,
        'sa_data_size_kilobytes': sa_data_size_kilobytes,
        'ipsec_encryption': ipsec_encryption,
        'ipsec_integrity': ipsec_integrity,
        'ike_encryption': ike_encryption,
        'ike_integrity': ike_integrity,
        'dh_group': dh_group,
        'pfs_group': pfs_group,
    }
    ipsec_policy = VpnClientIPsecParameters(**policy_fields)
    return sdk_no_wait(no_wait, client.begin_set_vpnclient_ipsec_parameters, resource_group_name,
                       virtual_network_gateway_name, ipsec_policy)
6922
6923
def disconnect_vnet_gateway_vpn_connections(cmd, client, resource_group_name, virtual_network_gateway_name,
                                            vpn_connection_ids, no_wait=False):
    """Disconnect the given point-to-site VPN connections from a virtual network gateway.

    :return: result of ``sdk_no_wait`` around the SDK disconnect operation.
    """
    disconnect_request = cmd.get_models('P2SVpnConnectionRequest')(vpn_connection_ids=vpn_connection_ids)
    disconnect_op = client.begin_disconnect_virtual_network_gateway_vpn_connections
    return sdk_no_wait(no_wait, disconnect_op, resource_group_name, virtual_network_gateway_name, disconnect_request)
6930
6931# endregion
6932
6933
6934# region VirtualNetworkGatewayConnections
6935# pylint: disable=too-many-locals
def create_vpn_connection(cmd, resource_group_name, connection_name, vnet_gateway1,
                          location=None, tags=None, no_wait=False, validate=False,
                          vnet_gateway2=None, express_route_circuit2=None, local_gateway2=None,
                          authorization_key=None, enable_bgp=False, routing_weight=10,
                          connection_type=None, shared_key=None,
                          use_policy_based_traffic_selectors=False,
                          express_route_gateway_bypass=None, ingress_nat_rule=None, egress_nat_rule=None):
    """Create a virtual network gateway connection by deploying a generated ARM template.

    The peer of ``vnet_gateway1`` is the first of ``vnet_gateway2``, ``local_gateway2`` and
    ``express_route_circuit2`` that is set. ``shared_key`` and ``authorization_key`` are
    added to the template as secure parameters.

    :param validate: when true, log the template and only validate the deployment.
    :return: the validation result when ``validate`` is set; otherwise the deployment
        operation (via ``sdk_no_wait``).
    """
    from azure.cli.core.util import random_string
    from azure.cli.core.commands.arm import ArmTemplateBuilder
    from azure.cli.command_modules.network._template_builder import build_vpn_connection_resource

    # Previously a virtual_network_gateway_connections client was created here and never used
    # (it was overwritten by the deployments client below); it has been removed.
    DeploymentProperties = cmd.get_models('DeploymentProperties', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES)
    tags = tags or {}

    # Build up the ARM template
    master_template = ArmTemplateBuilder()
    vpn_connection_resource = build_vpn_connection_resource(
        cmd, connection_name, location, tags, vnet_gateway1,
        vnet_gateway2 or local_gateway2 or express_route_circuit2,
        connection_type, authorization_key, enable_bgp, routing_weight, shared_key,
        use_policy_based_traffic_selectors, express_route_gateway_bypass, ingress_nat_rule, egress_nat_rule)
    master_template.add_resource(vpn_connection_resource)
    master_template.add_output('resource', connection_name, output_type='object')
    if shared_key:
        master_template.add_secure_parameter('sharedKey', shared_key)
    if authorization_key:
        master_template.add_secure_parameter('authorizationKey', authorization_key)

    template = master_template.build()
    parameters = master_template.build_parameters()

    # deploy ARM template
    deployment_name = 'vpn_connection_deploy_' + random_string(32)
    deployments_client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES).deployments
    properties = DeploymentProperties(template=template, parameters=parameters, mode='incremental')
    Deployment = cmd.get_models('Deployment', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES)
    deployment = Deployment(properties=properties)

    if validate:
        _log_pprint_template(template)
        if cmd.supported_api_version(min_api='2019-10-01', resource_type=ResourceType.MGMT_RESOURCE_RESOURCES):
            from azure.cli.core.commands import LongRunningOperation
            validation_poller = deployments_client.begin_validate(resource_group_name, deployment_name, deployment)
            return LongRunningOperation(cmd.cli_ctx)(validation_poller)

        return deployments_client.validate(resource_group_name, deployment_name, deployment)

    return sdk_no_wait(no_wait, deployments_client.begin_create_or_update,
                       resource_group_name, deployment_name, deployment)
6985
6986
def update_vpn_connection(cmd, instance, routing_weight=None, shared_key=None, tags=None,
                          enable_bgp=None, use_policy_based_traffic_selectors=None,
                          express_route_gateway_bypass=None):
    """Apply optional changes to a VPN connection model and return it.

    After the scalar properties are updated, the referenced gateways are replaced by full
    objects fetched from the subscription in each gateway's resource ID.
    """
    with cmd.update_context(instance) as c:
        c.set_param('routing_weight', routing_weight)
        c.set_param('shared_key', shared_key)
        c.set_param('tags', tags)
        c.set_param('enable_bgp', enable_bgp)
        c.set_param('express_route_gateway_bypass', express_route_gateway_bypass)
        c.set_param('use_policy_based_traffic_selectors', use_policy_based_traffic_selectors)

    def _refetch(gateway_ref, operations_name):
        # Re-read the full gateway from the subscription/resource group encoded in its ID.
        id_parts = parse_resource_id(gateway_ref.id)
        scoped_client = network_client_factory(cmd.cli_ctx, subscription_id=id_parts['subscription'])
        operations = getattr(scoped_client, operations_name)
        return operations.get(id_parts['resource_group'], id_parts['name'])

    # TODO: Remove these when issue #1615 is fixed
    instance.virtual_network_gateway1 = _refetch(instance.virtual_network_gateway1, 'virtual_network_gateways')

    if instance.virtual_network_gateway2:
        instance.virtual_network_gateway2 = _refetch(instance.virtual_network_gateway2, 'virtual_network_gateways')

    if instance.local_network_gateway2:
        instance.local_network_gateway2 = _refetch(instance.local_network_gateway2, 'local_network_gateways')

    return instance
7018
7019
def list_vpn_connections(cmd, resource_group_name, virtual_network_gateway_name=None):
    """List VPN connections in a resource group.

    When ``virtual_network_gateway_name`` is given, only that gateway's connections are listed.
    """
    network_client = network_client_factory(cmd.cli_ctx)
    if not virtual_network_gateway_name:
        return network_client.virtual_network_gateway_connections.list(resource_group_name)
    return network_client.virtual_network_gateways.list_connections(resource_group_name,
                                                                     virtual_network_gateway_name)
7026
7027
def start_vpn_conn_package_capture(cmd, client, resource_group_name, virtual_network_gateway_connection_name,
                                   filter_data=None, no_wait=False):
    """Start a packet capture on a virtual network gateway connection.

    :param filter_data: optional capture filter, placed in ``VpnPacketCaptureStartParameters``.
    :param no_wait: forwarded to ``sdk_no_wait``.
    """
    start_request = cmd.get_models('VpnPacketCaptureStartParameters')(filter_data=filter_data)
    return sdk_no_wait(no_wait, client.begin_start_packet_capture, resource_group_name,
                       virtual_network_gateway_connection_name, parameters=start_request)
7034
7035
def stop_vpn_conn_package_capture(cmd, client, resource_group_name, virtual_network_gateway_connection_name,
                                  sas_url, no_wait=False):
    """Stop a packet capture on a virtual network gateway connection.

    :param sas_url: SAS URL placed in ``VpnPacketCaptureStopParameters``.
    :param no_wait: forwarded to ``sdk_no_wait``.
    """
    stop_request = cmd.get_models('VpnPacketCaptureStopParameters')(sas_url=sas_url)
    return sdk_no_wait(no_wait, client.begin_stop_packet_capture, resource_group_name,
                       virtual_network_gateway_connection_name, parameters=stop_request)
7042
7043
def show_vpn_connection_device_config_script(cmd, client, resource_group_name, virtual_network_gateway_connection_name,
                                             vendor, device_family, firmware_version):
    """Fetch the VPN device configuration script for a gateway connection.

    ``vendor``, ``device_family`` and ``firmware_version`` are packed into a
    ``VpnDeviceScriptParameters`` model and sent as ``parameters``.
    """
    script_request = cmd.get_models('VpnDeviceScriptParameters')(
        vendor=vendor, device_family=device_family, firmware_version=firmware_version)
    return client.vpn_device_configuration_script(
        resource_group_name, virtual_network_gateway_connection_name, parameters=script_request)
7054# endregion
7055
7056
7057# region IPSec Policy Commands
def add_vnet_gateway_ipsec_policy(cmd, resource_group_name, gateway_name,
                                  sa_life_time_seconds, sa_data_size_kilobytes,
                                  ipsec_encryption, ipsec_integrity,
                                  ike_encryption, ike_integrity, dh_group, pfs_group, no_wait=False):
    """Append an IPsec policy to a virtual network gateway's VPN client configuration.

    Fetches the gateway and appends the new policy. If the policy list is empty or None,
    the new policy becomes the only entry. The gateway is then PUT back via ``sdk_no_wait``.

    :raises CLIError: if the gateway has no VPN client configuration to attach the policy to.
    """
    IpsecPolicy = cmd.get_models('IpsecPolicy')
    new_policy = IpsecPolicy(sa_life_time_seconds=sa_life_time_seconds,
                             sa_data_size_kilobytes=sa_data_size_kilobytes,
                             ipsec_encryption=ipsec_encryption,
                             ipsec_integrity=ipsec_integrity,
                             ike_encryption=ike_encryption,
                             ike_integrity=ike_integrity,
                             dh_group=dh_group,
                             pfs_group=pfs_group)

    ncf = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = ncf.get(resource_group_name, gateway_name)
    try:
        if gateway.vpn_client_configuration.vpn_client_ipsec_policies:
            gateway.vpn_client_configuration.vpn_client_ipsec_policies.append(new_policy)
        else:
            gateway.vpn_client_configuration.vpn_client_ipsec_policies = [new_policy]
    except AttributeError as ex:
        # Most likely vpn_client_configuration is None; chain the cause for debugging.
        raise CLIError('VPN client configuration must first be set through '
                       '`az network vnet-gateway create/update`.') from ex
    return sdk_no_wait(no_wait, ncf.begin_create_or_update, resource_group_name, gateway_name, gateway)
7082
7083
def clear_vnet_gateway_ipsec_policies(cmd, resource_group_name, gateway_name, no_wait=False):
    """Remove all IPsec policies from a virtual network gateway's VPN client configuration.

    With ``no_wait``, the result of ``sdk_no_wait`` is returned as-is. Otherwise the operation
    is waited on, and the policy list reported by the updated gateway is returned.

    :raises CLIError: if the gateway has no VPN client configuration.
    """
    ncf = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = ncf.get(resource_group_name, gateway_name)
    try:
        gateway.vpn_client_configuration.vpn_client_ipsec_policies = None
    except AttributeError as ex:
        # Most likely vpn_client_configuration is None; chain the cause for debugging.
        raise CLIError('VPN client configuration must first be set through '
                       '`az network vnet-gateway create/update`.') from ex

    # Issue the PUT exactly once; only the blocking path needs to wait on it.
    poller = sdk_no_wait(no_wait, ncf.begin_create_or_update, resource_group_name, gateway_name, gateway)
    if no_wait:
        return poller

    from azure.cli.core.commands import LongRunningOperation
    return LongRunningOperation(cmd.cli_ctx)(poller).vpn_client_configuration.vpn_client_ipsec_policies
7097
7098
def list_vnet_gateway_ipsec_policies(cmd, resource_group_name, gateway_name):
    """Return the IPsec policies in a virtual network gateway's VPN client configuration.

    :raises CLIError: if the gateway has no VPN client configuration.
    """
    ncf = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = ncf.get(resource_group_name, gateway_name)
    # Keep the try minimal: only the attribute walk is expected to raise AttributeError.
    try:
        return gateway.vpn_client_configuration.vpn_client_ipsec_policies
    except AttributeError as ex:
        raise CLIError('VPN client configuration must first be set through '
                       '`az network vnet-gateway create/update`.') from ex
7105
7106
def add_vpn_conn_ipsec_policy(cmd, client, resource_group_name, connection_name,
                              sa_life_time_seconds, sa_data_size_kilobytes,
                              ipsec_encryption, ipsec_integrity,
                              ike_encryption, ike_integrity, dh_group, pfs_group, no_wait=False):
    """Append an IPsec policy to a VPN connection and PUT the connection back.

    If the connection's policy list is empty or None, the new policy becomes its only entry.
    ``no_wait`` is forwarded to ``sdk_no_wait``.
    """
    policy_fields = dict(sa_life_time_seconds=sa_life_time_seconds,
                         sa_data_size_kilobytes=sa_data_size_kilobytes,
                         ipsec_encryption=ipsec_encryption,
                         ipsec_integrity=ipsec_integrity,
                         ike_encryption=ike_encryption,
                         ike_integrity=ike_integrity,
                         dh_group=dh_group,
                         pfs_group=pfs_group)
    policy = cmd.get_models('IpsecPolicy')(**policy_fields)

    connection = client.get(resource_group_name, connection_name)
    current_policies = connection.ipsec_policies
    if not current_policies:
        connection.ipsec_policies = [policy]
    else:
        current_policies.append(policy)
    return sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, connection_name, connection)
7127
7128
def clear_vpn_conn_ipsec_policies(cmd, client, resource_group_name, connection_name, no_wait=False):
    """Remove all IPsec policies from a VPN connection and disable policy-based traffic selectors.

    With ``no_wait``, the result of ``sdk_no_wait`` is returned as-is. Otherwise the operation
    is waited on, and the policy list reported by the updated connection is returned.
    """
    connection = client.get(resource_group_name, connection_name)
    connection.ipsec_policies = None
    connection.use_policy_based_traffic_selectors = False

    poller = sdk_no_wait(no_wait, client.begin_create_or_update, resource_group_name, connection_name, connection)
    if no_wait:
        return poller

    from azure.cli.core.commands import LongRunningOperation
    updated = LongRunningOperation(cmd.cli_ctx)(poller)
    return updated.ipsec_policies
7139
7140
def list_vpn_conn_ipsec_policies(cmd, client, resource_group_name, connection_name):
    """Return the IPsec policies attached to a VPN connection."""
    connection = client.get(resource_group_name, connection_name)
    return connection.ipsec_policies
7143
7144
def assign_vnet_gateway_aad(cmd, resource_group_name, gateway_name,
                            aad_tenant, aad_audience, aad_issuer, no_wait=False):
    """Set the AAD tenant/audience/issuer on a gateway's VPN client configuration and PUT it back.

    :raises CLIError: if the gateway has no VPN client configuration.
    """
    gateways = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = gateways.get(resource_group_name, gateway_name)

    vpn_config = gateway.vpn_client_configuration
    if vpn_config is None:
        raise CLIError('VPN client configuration must be set first through `az network vnet-gateway create/update`.')

    for attr_name, attr_value in (('aad_tenant', aad_tenant),
                                  ('aad_audience', aad_audience),
                                  ('aad_issuer', aad_issuer)):
        setattr(vpn_config, attr_name, attr_value)

    return sdk_no_wait(no_wait, gateways.begin_create_or_update, resource_group_name, gateway_name, gateway)
7158
7159
def show_vnet_gateway_aad(cmd, resource_group_name, gateway_name):
    """Return a gateway's VPN client configuration (which carries the AAD settings).

    :raises CLIError: if the gateway has no VPN client configuration.
    """
    gateway = network_client_factory(cmd.cli_ctx).virtual_network_gateways.get(resource_group_name, gateway_name)
    vpn_config = gateway.vpn_client_configuration
    if vpn_config is not None:
        return vpn_config
    raise CLIError('VPN client configuration must be set first through `az network vnet-gateway create/update`.')
7168
7169
def remove_vnet_gateway_aad(cmd, resource_group_name, gateway_name, no_wait=False):
    """Clear the AAD settings from a gateway's VPN client configuration and PUT it back.

    On API version 2020-11-01 or later, ``vpn_authentication_types`` is cleared as well.

    :raises CLIError: if the gateway has no VPN client configuration.
    """
    gateways = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = gateways.get(resource_group_name, gateway_name)

    vpn_config = gateway.vpn_client_configuration
    if vpn_config is None:
        raise CLIError('VPN client configuration must be set first through `az network vnet-gateway create/update`.')

    # Chained assignment binds left to right: tenant, audience, issuer.
    vpn_config.aad_tenant = vpn_config.aad_audience = vpn_config.aad_issuer = None
    if cmd.supported_api_version(min_api='2020-11-01'):
        vpn_config.vpn_authentication_types = None

    return sdk_no_wait(no_wait, gateways.begin_create_or_update, resource_group_name, gateway_name, gateway)
7184
7185
def add_vnet_gateway_nat_rule(cmd, resource_group_name, gateway_name, name, internal_mappings, external_mappings,
                              rule_type=None, mode=None, ip_config_id=None, no_wait=False):
    """Append a NAT rule to a virtual network gateway and PUT the gateway back.

    :param internal_mappings: address spaces for the rule's internal side; empty/None yields no mappings.
    :param external_mappings: address spaces for the rule's external side; empty/None yields no mappings.
    :param rule_type: stored as the rule's ``type_properties_type``.
    :param no_wait: forwarded to ``sdk_no_wait``.
    """
    ncf = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = ncf.get(resource_group_name, gateway_name)

    VirtualNetworkGatewayNatRule, VpnNatRuleMapping = cmd.get_models('VirtualNetworkGatewayNatRule',
                                                                     'VpnNatRuleMapping')

    def _to_mappings(address_spaces):
        # Wrap each address space in a VpnNatRuleMapping; an empty/None input maps to None.
        if not address_spaces:
            return None
        return [VpnNatRuleMapping(address_space=space) for space in address_spaces]

    new_rule = VirtualNetworkGatewayNatRule(type_properties_type=rule_type, mode=mode, name=name,
                                            internal_mappings=_to_mappings(internal_mappings),
                                            external_mappings=_to_mappings(external_mappings),
                                            ip_configuration_id=ip_config_id)

    # Guard against nat_rules being None (e.g. if the service omitted the field) so the
    # append below cannot fail with AttributeError.
    if gateway.nat_rules is None:
        gateway.nat_rules = []
    gateway.nat_rules.append(new_rule)

    return sdk_no_wait(no_wait, ncf.begin_create_or_update, resource_group_name, gateway_name, gateway)
7200
7201
def show_vnet_gateway_nat_rule(cmd, resource_group_name, gateway_name):
    """Return the NAT rules configured on a virtual network gateway."""
    gateways = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    return gateways.get(resource_group_name, gateway_name).nat_rules
7207
7208
def remove_vnet_gateway_nat_rule(cmd, resource_group_name, gateway_name, name, no_wait=False):
    """Remove the NAT rule called ``name`` from a virtual network gateway and PUT the gateway back.

    :param no_wait: forwarded to ``sdk_no_wait``.
    :raises UnrecognizedArgumentError: if the gateway has no NAT rule with that name
        (including when it has no NAT rules at all).
    """
    ncf = network_client_factory(cmd.cli_ctx).virtual_network_gateways
    gateway = ncf.get(resource_group_name, gateway_name)

    # Treat a None rule list like an empty one, so the caller gets the not-found error
    # rather than a TypeError from iterating None.
    rules = gateway.nat_rules or []
    index = next((i for i, rule in enumerate(rules) if rule.name == name), None)
    if index is None:
        raise UnrecognizedArgumentError(f'Do not find nat_rules named {name}!!!')

    # Delete by position after the search, instead of mutating the list while iterating it.
    del rules[index]
    return sdk_no_wait(no_wait, ncf.begin_create_or_update, resource_group_name, gateway_name, gateway)
7219# endregion
7220
7221
7222# region VirtualHub
def create_virtual_hub(cmd, client,
                       resource_group_name,
                       virtual_hub_name,
                       hosted_subnet,
                       public_ip_address=None,
                       location=None,
                       tags=None):
    """Create a Standard-SKU virtual hub plus its 'Default' hub IP configuration.

    The hub is created first and waited on. Then an IP configuration is created, bound to
    ``hosted_subnet`` and, when given, ``public_ip_address``. If that step raises, the error
    is logged and deletion of the IP configuration (HTTP errors ignored) and of the hub is
    requested; the exception is then re-raised. On success the hub is re-read and returned.

    :raises CLIError: if the hub already exists in the resource group.
    """
    from azure.core.exceptions import HttpResponseError
    from azure.cli.core.commands import LongRunningOperation

    try:
        client.get(resource_group_name, virtual_hub_name)
    except HttpResponseError:
        pass  # lookup failed (typically not found): go ahead and create it
    else:
        raise CLIError('The VirtualHub "{}" under resource group "{}" exists'.format(
            virtual_hub_name, resource_group_name))

    SubResource = cmd.get_models('SubResource')
    VirtualHub, HubIpConfiguration = cmd.get_models('VirtualHub', 'HubIpConfiguration')

    new_hub = VirtualHub(tags=tags, location=location, virtual_wan=None, sku='Standard')
    hub_poller = client.begin_create_or_update(resource_group_name, virtual_hub_name, new_hub)
    LongRunningOperation(cmd.cli_ctx)(hub_poller)

    subnet_ref = SubResource(id=hosted_subnet)
    public_ip_ref = None
    if public_ip_address:
        public_ip_ref = SubResource(id=public_ip_address)
    ip_config = HubIpConfiguration(subnet=subnet_ref, public_ip_address=public_ip_ref)

    ip_config_ops = network_client_factory(cmd.cli_ctx).virtual_hub_ip_configuration
    try:
        ip_poller = ip_config_ops.begin_create_or_update(resource_group_name, virtual_hub_name, 'Default', ip_config)
        LongRunningOperation(cmd.cli_ctx)(ip_poller)
    except Exception as err:
        logger.error(err)
        # Best-effort rollback of the partially created hub.
        try:
            ip_config_ops.begin_delete(resource_group_name, virtual_hub_name, 'Default')
        except HttpResponseError:
            pass
        client.begin_delete(resource_group_name, virtual_hub_name)
        raise err

    return client.get(resource_group_name, virtual_hub_name)
7269
7270
def virtual_hub_update_setter(client, resource_group_name, virtual_hub_name, parameters):
    """Generic-update setter: PUT the modified VirtualHub ``parameters`` and return the poller."""
    poller = client.begin_create_or_update(resource_group_name, virtual_hub_name, parameters)
    return poller
7273
7274
def update_virtual_hub(cmd, instance,
                       tags=None,
                       allow_branch_to_branch_traffic=None):
    """Pass the requested tag and branch-to-branch settings to the update context; return the hub."""
    with cmd.update_context(instance) as updater:
        for field_name, field_value in (('tags', tags),
                                        ('allow_branch_to_branch_traffic', allow_branch_to_branch_traffic)):
            updater.set_param(field_name, field_value)
    return instance
7282
7283
def delete_virtual_hub(cmd, client, resource_group_name, virtual_hub_name, no_wait=False):
    """Delete a virtual hub, removing its hub IP configuration first.

    Deletion of the IP configuration is waited on. The hub delete is then issued via
    ``sdk_no_wait`` (``no_wait`` is forwarded).
    """
    from azure.cli.core.commands import LongRunningOperation
    ip_config_ops = network_client_factory(cmd.cli_ctx).virtual_hub_ip_configuration
    existing_configs = list(ip_config_ops.list(resource_group_name, virtual_hub_name))
    # Only the first entry is removed: a hub is expected to carry a single IP configuration.
    for ip_config in existing_configs[:1]:
        delete_poller = ip_config_ops.begin_delete(resource_group_name, virtual_hub_name, ip_config.name)
        LongRunningOperation(cmd.cli_ctx)(delete_poller)
    return sdk_no_wait(no_wait, client.begin_delete, resource_group_name, virtual_hub_name)
7293
7294
def list_virtual_hub(client, resource_group_name=None):
    """List virtual hubs in ``resource_group_name``, or in the whole subscription when it is None."""
    if resource_group_name is None:
        return client.list()
    return client.list_by_resource_group(resource_group_name)
7299
7300
def create_virtual_hub_bgp_connection(cmd, client, resource_group_name, virtual_hub_name, connection_name,
                                      peer_asn, peer_ip, no_wait=False):
    """Create or update a BGP connection on a virtual hub.

    ``no_wait`` is forwarded to ``sdk_no_wait``.
    """
    bgp_connection = cmd.get_models('BgpConnection')(name=connection_name, peer_asn=peer_asn, peer_ip=peer_ip)
    return sdk_no_wait(no_wait, client.begin_create_or_update,
                       resource_group_name, virtual_hub_name, connection_name, bgp_connection)
7307
7308
def virtual_hub_bgp_connection_update_setter(client, resource_group_name,
                                             virtual_hub_name, connection_name,
                                             parameters):
    """Generic-update setter: PUT the modified BgpConnection ``parameters`` and return the poller."""
    poller = client.begin_create_or_update(resource_group_name, virtual_hub_name, connection_name, parameters)
    return poller
7313
7314
def update_virtual_hub_bgp_connection(cmd, instance, peer_asn=None, peer_ip=None):
    """Pass the requested peer ASN/IP to the update context and return the BGP connection."""
    with cmd.update_context(instance) as updater:
        for field_name, field_value in (('peer_asn', peer_asn), ('peer_ip', peer_ip)):
            updater.set_param(field_name, field_value)
    return instance
7320
7321
def delete_virtual_hub_bgp_connection(client, resource_group_name,
                                      virtual_hub_name, connection_name, no_wait=False):
    """Delete a BGP connection from a virtual hub; ``no_wait`` is forwarded to ``sdk_no_wait``."""
    delete_args = (resource_group_name, virtual_hub_name, connection_name)
    return sdk_no_wait(no_wait, client.begin_delete, *delete_args)
7325
7326
def list_virtual_hub_bgp_connection_learned_routes(client, resource_group_name, virtual_hub_name, connection_name):
    """Start the long-running listing of routes learned over a virtual hub BGP connection."""
    poller = client.begin_list_learned_routes(resource_group_name, virtual_hub_name, connection_name)
    return poller
7329
7330
def list_virtual_hub_bgp_connection_advertised_routes(client, resource_group_name, virtual_hub_name, connection_name):
    """Start the long-running listing of routes advertised over a virtual hub BGP connection."""
    poller = client.begin_list_advertised_routes(resource_group_name, virtual_hub_name, connection_name)
    return poller
7333# endregion
7334
7335
7336# region VirtualRouter
def create_virtual_router(cmd,
                          resource_group_name,
                          virtual_router_name,
                          hosted_gateway=None,
                          hosted_subnet=None,
                          location=None,
                          tags=None):
    """Create a virtual router, either as a legacy VirtualRouter or as a VirtualHub.

    With ``hosted_gateway`` set, a legacy ``VirtualRouter`` is sent via ``begin_create_or_update``
    and its poller is returned. Otherwise a Standard-SKU ``VirtualHub`` named
    ``virtual_router_name`` is created and given a 'Default' IP configuration on
    ``hosted_subnet``; the hub is then re-read and returned. If the IP configuration step
    fails, the error is logged and deletion of the IP configuration and of the hub is
    requested as a best-effort rollback. The original exception is then re-raised.

    :raises CLIError: if a VirtualHub with this name already exists in the resource group.
    """
    vrouter_client = network_client_factory(cmd.cli_ctx).virtual_routers
    vhub_client = network_client_factory(cmd.cli_ctx).virtual_hubs

    from azure.core.exceptions import HttpResponseError
    # NOTE(review): the outcome of this lookup is ignored, so an existing legacy VirtualRouter
    # does not block creation (unlike the VirtualHub check below); confirm this is intended.
    try:
        vrouter_client.get(resource_group_name, virtual_router_name)
    except HttpResponseError:
        pass

    virtual_hub_name = virtual_router_name
    try:
        vhub_client.get(resource_group_name, virtual_hub_name)
        raise CLIError('The VirtualRouter "{}" under resource group "{}" exists'.format(virtual_hub_name,
                                                                                        resource_group_name))
    except HttpResponseError:
        pass

    SubResource = cmd.get_models('SubResource')

    # for old VirtualRouter
    if hosted_gateway is not None:
        VirtualRouter = cmd.get_models('VirtualRouter')
        virtual_router = VirtualRouter(virtual_router_asn=None,
                                       virtual_router_ips=[],
                                       hosted_subnet=None,
                                       hosted_gateway=SubResource(id=hosted_gateway),
                                       location=location,
                                       tags=tags)
        return vrouter_client.begin_create_or_update(resource_group_name, virtual_router_name, virtual_router)

    # for VirtualHub
    VirtualHub, HubIpConfiguration = cmd.get_models('VirtualHub', 'HubIpConfiguration')

    hub = VirtualHub(tags=tags, location=location, virtual_wan=None, sku='Standard')
    ip_config = HubIpConfiguration(subnet=SubResource(id=hosted_subnet))

    from azure.cli.core.commands import LongRunningOperation

    vhub_poller = vhub_client.begin_create_or_update(resource_group_name, virtual_hub_name, hub)
    LongRunningOperation(cmd.cli_ctx)(vhub_poller)

    vhub_ip_config_client = network_client_factory(cmd.cli_ctx).virtual_hub_ip_configuration
    try:
        vhub_ip_poller = vhub_ip_config_client.begin_create_or_update(resource_group_name,
                                                                      virtual_hub_name,
                                                                      'Default',
                                                                      ip_config)
        LongRunningOperation(cmd.cli_ctx)(vhub_ip_poller)
    except Exception as ex:
        logger.error(ex)
        # Best-effort rollback: a failing IP-config delete must neither skip the hub delete
        # nor replace the original error that is re-raised below.
        try:
            vhub_ip_config_client.begin_delete(resource_group_name, virtual_hub_name, 'Default')
        except HttpResponseError:
            pass
        vhub_client.begin_delete(resource_group_name, virtual_hub_name)
        raise ex

    return vhub_client.get(resource_group_name, virtual_hub_name)
7399
7400
def virtual_router_update_getter(cmd, resource_group_name, virtual_router_name):
    """Generic-update getter: return the legacy VirtualRouter, or else the same-named VirtualHub."""
    from azure.core.exceptions import HttpResponseError

    legacy_router = None
    try:
        legacy_router = network_client_factory(cmd.cli_ctx).virtual_routers.get(resource_group_name,
                                                                               virtual_router_name)
    except HttpResponseError:  # 404: the name refers to a VirtualHub instead
        pass
    else:
        return legacy_router

    hubs = network_client_factory(cmd.cli_ctx).virtual_hubs
    return hubs.get(resource_group_name, virtual_router_name)
7412
7413
def virtual_router_update_setter(cmd, resource_group_name, virtual_router_name, parameters):
    """Generic-update setter: PUT ``parameters`` back through the matching operations group.

    A ``Microsoft.Network/virtualHubs`` payload goes to the virtual hub operations (the router
    name then doubles as the hub name). Anything else goes to the legacy VirtualRouter operations.
    """
    is_hub = parameters.type == 'Microsoft.Network/virtualHubs'
    operations_name = 'virtual_hubs' if is_hub else 'virtual_routers'
    operations = getattr(network_client_factory(cmd.cli_ctx), operations_name)
    return operations.begin_create_or_update(resource_group_name, virtual_router_name, parameters)
7424
7425
def update_virtual_router(cmd, instance, tags=None):
    """Pass the requested tags to the update context; works for VirtualHub and VirtualRouter alike."""
    with cmd.update_context(instance) as updater:
        updater.set_param('tags', tags)
    return instance
7431
7432
def list_virtual_router(cmd, resource_group_name=None):
    """List legacy VirtualRouters followed by VirtualHubs, optionally scoped to a resource group."""
    router_ops = network_client_factory(cmd.cli_ctx).virtual_routers
    hub_ops = network_client_factory(cmd.cli_ctx).virtual_hubs

    if resource_group_name is None:
        routers = router_ops.list()
        hubs = hub_ops.list()
    else:
        routers = router_ops.list_by_resource_group(resource_group_name)
        hubs = hub_ops.list_by_resource_group(resource_group_name)

    return [*routers, *hubs]
7445
7446
def show_virtual_router(cmd, resource_group_name, virtual_router_name):
    """Return the legacy VirtualRouter, or the same-named VirtualHub if the router lookup fails."""
    router_ops = network_client_factory(cmd.cli_ctx).virtual_routers
    hub_ops = network_client_factory(cmd.cli_ctx).virtual_hubs

    from azure.core.exceptions import HttpResponseError
    try:
        return router_ops.get(resource_group_name, virtual_router_name)
    except HttpResponseError:
        # The router name doubles as the hub name.
        return hub_ops.get(resource_group_name, virtual_router_name)
7459
7460
def delete_virtual_router(cmd, resource_group_name, virtual_router_name):
    """Delete a virtual router, either a legacy VirtualRouter or a VirtualHub of the same name.

    For a legacy router, its ``begin_delete`` poller is returned. Otherwise the hub's 'Default'
    IP configuration is deleted (and waited on), and the hub's ``begin_delete`` poller is returned.
    """
    vrouter_client = network_client_factory(cmd.cli_ctx).virtual_routers
    vhub_client = network_client_factory(cmd.cli_ctx).virtual_hubs
    vhub_ip_config_client = network_client_factory(cmd.cli_ctx).virtual_hub_ip_configuration

    from azure.core.exceptions import HttpResponseError
    try:
        vrouter_client.get(resource_group_name, virtual_router_name)
    except HttpResponseError:
        pass  # no legacy router by that name: fall through to the VirtualHub path
    else:
        # Only the existence probe is guarded. An HTTP error from deleting an existing
        # legacy router propagates, instead of being mistaken for "not a router".
        return vrouter_client.begin_delete(resource_group_name, virtual_router_name)

    from azure.cli.core.commands import LongRunningOperation

    virtual_hub_name = virtual_router_name
    poller = vhub_ip_config_client.begin_delete(resource_group_name, virtual_hub_name, 'Default')
    LongRunningOperation(cmd.cli_ctx)(poller)

    return vhub_client.begin_delete(resource_group_name, virtual_hub_name)
7480
7481
def create_virtual_router_peering(cmd, resource_group_name, virtual_router_name, peering_name, peer_asn, peer_ip):
    """Create a peering on a virtual router.

    If a legacy VirtualRouter with that name exists, a VirtualRouterPeering is created on it.
    Otherwise the router is treated as a VirtualHub of the same name, and a BgpConnection
    named ``peering_name`` is created there. Returns the ``begin_create_or_update`` poller.

    :raises CLIError: if neither a VirtualRouter nor a VirtualHub with that name is found.
    """
    from azure.core.exceptions import HttpResponseError

    try:
        network_client_factory(cmd.cli_ctx).virtual_routers.get(resource_group_name, virtual_router_name)
        is_legacy_router = True
    except HttpResponseError:
        is_legacy_router = False

    if is_legacy_router:
        peering_ops = network_client_factory(cmd.cli_ctx).virtual_router_peerings
        VirtualRouterPeering = cmd.get_models('VirtualRouterPeering')
        new_peering = VirtualRouterPeering(peer_asn=peer_asn, peer_ip=peer_ip)
        return peering_ops.begin_create_or_update(resource_group_name, virtual_router_name,
                                                  peering_name, new_peering)

    # Fall back to a VirtualHub with the same name; it must exist.
    try:
        network_client_factory(cmd.cli_ctx).virtual_hubs.get(resource_group_name, virtual_router_name)
    except HttpResponseError:
        raise CLIError('The VirtualRouter "{}" under resource group "{}" was not found'.format(
            virtual_router_name, resource_group_name))

    BgpConnection = cmd.get_models('BgpConnection')
    new_connection = BgpConnection(name=peering_name, peer_asn=peer_asn, peer_ip=peer_ip)
    bgp_ops = network_client_factory(cmd.cli_ctx).virtual_hub_bgp_connection
    return bgp_ops.begin_create_or_update(resource_group_name, virtual_router_name, peering_name, new_connection)
7518
7519
def virtual_router_peering_update_getter(cmd, resource_group_name, virtual_router_name, peering_name):
    """Generic-update getter: return the legacy router peering, or else the same-named hub BGP connection."""
    peering_ops = network_client_factory(cmd.cli_ctx).virtual_router_peerings

    from azure.core.exceptions import HttpResponseError
    found = None
    try:
        found = peering_ops.get(resource_group_name, virtual_router_name, peering_name)
    except HttpResponseError:  # 404: look for a VirtualHub BGP connection instead
        pass
    else:
        return found

    bgp_ops = network_client_factory(cmd.cli_ctx).virtual_hub_bgp_connection
    return bgp_ops.get(resource_group_name, virtual_router_name, peering_name)
7534
7535
def virtual_router_peering_update_setter(cmd, resource_group_name, virtual_router_name, peering_name, parameters):
    """Generic-update setter: PUT ``parameters`` back through the matching operations group.

    A ``Microsoft.Network/virtualHubs/bgpConnections`` payload goes to the hub BGP connection
    operations; the router name then acts as the hub name and the peering name as the
    connection name. Anything else goes to the legacy VirtualRouter peering operations.
    """
    is_hub_connection = parameters.type == 'Microsoft.Network/virtualHubs/bgpConnections'
    operations_name = 'virtual_hub_bgp_connection' if is_hub_connection else 'virtual_router_peerings'
    operations = getattr(network_client_factory(cmd.cli_ctx), operations_name)
    return operations.begin_create_or_update(resource_group_name, virtual_router_name, peering_name, parameters)
7547
7548
def update_virtual_router_peering(cmd, instance, peer_asn=None, peer_ip=None):
    """Pass the requested peer ASN/IP to the update context; works for router peerings and hub BGP connections."""
    with cmd.update_context(instance) as updater:
        for field_name, field_value in (('peer_asn', peer_asn), ('peer_ip', peer_ip)):
            updater.set_param(field_name, field_value)
    return instance
7555
7556
def list_virtual_router_peering(cmd, resource_group_name, virtual_router_name):
    """List the peerings of a virtual router, covering both legacy routers and VirtualHubs.

    Returns legacy VirtualRouter peerings followed by VirtualHub BGP connections for the
    given name. A listing that fails with HttpResponseError contributes nothing.

    :raises CLIError: if neither a VirtualRouter nor a VirtualHub with that name is found.
    """
    from azure.core.exceptions import HttpResponseError

    try:
        network_client_factory(cmd.cli_ctx).virtual_routers.get(resource_group_name, virtual_router_name)
    except HttpResponseError:
        try:
            network_client_factory(cmd.cli_ctx).virtual_hubs.get(resource_group_name, virtual_router_name)
        except HttpResponseError:
            raise CLIError('The VirtualRouter "{}" under resource group "{}" was not found'.format(
                virtual_router_name, resource_group_name))

    def _list_or_empty(operations_name):
        # Materialize inside the try so paging errors are caught too; an HTTP error is
        # treated as an empty result.
        try:
            operations = getattr(network_client_factory(cmd.cli_ctx), operations_name)
            return list(operations.list(resource_group_name, virtual_router_name))
        except HttpResponseError:
            return []

    return _list_or_empty('virtual_router_peerings') + _list_or_empty('virtual_hub_bgp_connections')
7587
7588
def show_virtual_router_peering(cmd, resource_group_name, virtual_router_name, peering_name):
    """Return a peering of a virtual router (legacy router peering or VirtualHub BGP connection).

    :raises CLIError: if neither a VirtualRouter nor a VirtualHub with that name is found.
    """
    from azure.core.exceptions import HttpResponseError

    try:
        network_client_factory(cmd.cli_ctx).virtual_routers.get(resource_group_name, virtual_router_name)
        is_legacy_router = True
    except HttpResponseError:
        is_legacy_router = False

    if is_legacy_router:
        peering_ops = network_client_factory(cmd.cli_ctx).virtual_router_peerings
        return peering_ops.get(resource_group_name, virtual_router_name, peering_name)

    # Fall back to a VirtualHub with the same name; it must exist.
    try:
        network_client_factory(cmd.cli_ctx).virtual_hubs.get(resource_group_name, virtual_router_name)
    except HttpResponseError:
        raise CLIError('The VirtualRouter "{}" under resource group "{}" was not found'.format(
            virtual_router_name, resource_group_name))

    bgp_ops = network_client_factory(cmd.cli_ctx).virtual_hub_bgp_connection
    return bgp_ops.get(resource_group_name, virtual_router_name, peering_name)
7614
7615
def delete_virtual_router_peering(cmd, resource_group_name, virtual_router_name, peering_name):
    """Delete a peering from a virtual router (legacy VirtualRouter or VirtualHub-backed).

    Returns the ``begin_delete`` poller of whichever resource kind owns the peering.

    :raises CLIError: if neither a VirtualRouter nor a VirtualHub with that name is found.
    """
    from azure.core.exceptions import HttpResponseError
    try:
        vrouter_client = network_client_factory(cmd.cli_ctx).virtual_routers
        vrouter_client.get(resource_group_name, virtual_router_name)
    except HttpResponseError:
        # Catch only HTTP failures (e.g. 404), consistent with the sibling commands, so that
        # interrupts and programming errors are not mistaken for "no legacy router".
        pass
    else:
        vrouter_peering_client = network_client_factory(cmd.cli_ctx).virtual_router_peerings
        return vrouter_peering_client.begin_delete(resource_group_name, virtual_router_name, peering_name)

    virtual_hub_name = virtual_router_name
    bgp_conn_name = peering_name

    # try VirtualHub then if the virtual router doesn't exist
    try:
        vhub_client = network_client_factory(cmd.cli_ctx).virtual_hubs
        vhub_client.get(resource_group_name, virtual_hub_name)
    except HttpResponseError:
        msg = 'The VirtualRouter "{}" under resource group "{}" was not found'.format(virtual_hub_name,
                                                                                      resource_group_name)
        raise CLIError(msg)

    vhub_bgp_conn_client = network_client_factory(cmd.cli_ctx).virtual_hub_bgp_connection
    return vhub_bgp_conn_client.begin_delete(resource_group_name, virtual_hub_name, bgp_conn_name)
7641# endregion
7642
7643
7644# region service aliases
def list_service_aliases(cmd, location, resource_group_name=None):
    """List the service aliases available in ``location``.

    The list is scoped to ``resource_group_name`` when one is given, otherwise it covers
    the whole subscription.
    """
    alias_ops = network_client_factory(cmd.cli_ctx).available_service_aliases
    if resource_group_name is None:
        return alias_ops.list(location=location)
    return alias_ops.list_by_resource_group(resource_group_name=resource_group_name, location=location)
7650# endregion
7651
7652
7653# region bastion
def create_bastion_host(cmd, resource_group_name, bastion_host_name, virtual_network_name,
                        public_ip_address, location=None, subnet='AzureBastionSubnet', tags=None):
    """Create or update a Bastion host with a single IP configuration named ``bastion_ip_config``.

    ``subnet`` and ``public_ip_address`` are wrapped as-is into ``SubResource(id=...)``
    references, so they are expected to be full resource IDs (presumably resolved by the
    argument validators; confirm). ``virtual_network_name`` is not used in this function.

    Returns:
        The result of ``bastion_hosts.begin_create_or_update``.
    """
    bastion_ops = network_client_factory(cmd.cli_ctx).bastion_hosts
    BastionHost, BastionHostIPConfiguration, SubResource = cmd.get_models(
        'BastionHost', 'BastionHostIPConfiguration', 'SubResource')

    ip_config = BastionHostIPConfiguration(
        name="bastion_ip_config",
        subnet=SubResource(id=subnet),
        public_ip_address=SubResource(id=public_ip_address),
    )
    host_model = BastionHost(ip_configurations=[ip_config], location=location, tags=tags)

    return bastion_ops.begin_create_or_update(
        resource_group_name=resource_group_name,
        bastion_host_name=bastion_host_name,
        parameters=host_model,
    )
7673
7674
def list_bastion_host(cmd, resource_group_name=None):
    """List Bastion hosts in ``resource_group_name``, or in the whole subscription when it is None."""
    bastion_ops = network_client_factory(cmd.cli_ctx).bastion_hosts
    if resource_group_name is None:
        return bastion_ops.list()
    return bastion_ops.list_by_resource_group(resource_group_name=resource_group_name)
7680
7681
# Name of the Azure CLI "ssh" extension that `az network bastion ssh` depends on
# (checked by _test_extension and loaded by _get_azext_module).
SSH_EXTENSION_NAME = 'ssh'
# Module inside that extension whose helpers are used for AAD certificate-based login.
SSH_EXTENSION_MODULE = 'azext_ssh.custom'
# Minimum extension version accepted by _test_extension.
SSH_EXTENSION_VERSION = '0.1.3'
7685
7686
def _get_azext_module(extension_name, module_name):
    """Import and return ``module_name`` from the installed CLI extension ``extension_name``.

    The installed extension is first added to the import path so that ``module_name``
    can be imported by name.

    Raises:
        CLIError: if any of the imports fails with ImportError.
    """
    try:
        # Adding the installed extension in the path
        from azure.cli.core.extension.operations import add_extension_to_path
        add_extension_to_path(extension_name)
        # Import the extension module
        from importlib import import_module
        return import_module(module_name)
    except ImportError as ie:
        # Chain the original error so its traceback is kept for debugging.
        raise CLIError(ie) from ie
7698
7699
def _test_extension(extension_name):
    """Ensure the CLI extension ``extension_name`` is installed at version >= SSH_EXTENSION_VERSION.

    Raises:
        CLIError: if the extension is not installed, or its installed version is older
            than SSH_EXTENSION_VERSION.
    """
    from azure.cli.core.extension import get_extension, ExtensionNotInstalledException
    from pkg_resources import parse_version
    requirement_msg = 'SSH Extension (version >= "{}") must be installed'.format(SSH_EXTENSION_VERSION)
    try:
        ext = get_extension(extension_name)
    except ExtensionNotInstalledException as ex:
        # A missing extension gets the same actionable message as an outdated one,
        # instead of escaping as a raw extension-lookup exception.
        raise CLIError(requirement_msg) from ex
    if parse_version(ext.version) < parse_version(SSH_EXTENSION_VERSION):
        raise CLIError(requirement_msg)
7706
7707
def _get_ssh_path(ssh_command="ssh"):
    """Return the absolute path of the Windows OpenSSH client executable ``<ssh_command>.exe``.

    The executable is looked up under ``%SystemRoot%\\<System32|SysNative>\\openSSH``.

    Raises:
        UnrecognizedArgumentError: on any platform other than Windows.
        CLIError: if the executable does not exist at the computed path.
    """
    import os

    if platform.system() != 'Windows':
        raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows")

    arch_data = platform.architecture()
    is_32bit = arch_data[0] == '32bit'
    # A 32-bit interpreter uses the SysNative alias to reach the native System32 folder
    # (WOW64 file-system redirection would otherwise send it to SysWOW64).
    sys_path = 'SysNative' if is_32bit else 'System32'
    system_root = os.environ['SystemRoot']
    system32_path = os.path.join(system_root, sys_path)
    ssh_path = os.path.join(system32_path, "openSSH", (ssh_command + ".exe"))
    logger.debug("Platform architecture: %s", str(arch_data))
    logger.debug("System Root: %s", system_root)
    logger.debug("Attempting to run ssh from path %s", ssh_path)

    if not os.path.isfile(ssh_path):
        raise CLIError("Could not find " + ssh_command + ".exe. Is the OpenSSH client installed?")

    return ssh_path
7729
7730
def _get_rdp_path(rdp_command="mstsc"):
    """Return the absolute path of the Windows RDP client ``%SystemRoot%\\System32\\<rdp_command>.exe``.

    Raises:
        UnrecognizedArgumentError: on any platform other than Windows.
        CLIError: if the executable does not exist at the computed path.
    """
    import os

    if platform.system() != 'Windows':
        raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows")

    # The architecture is gathered only for diagnostic logging.
    arch_data = platform.architecture()
    sys_path = 'System32'
    system_root = os.environ['SystemRoot']
    system32_path = os.path.join(system_root, sys_path)
    rdp_path = os.path.join(system32_path, (rdp_command + ".exe"))
    logger.debug("Platform architecture: %s", str(arch_data))
    logger.debug("System Root: %s", system_root)
    logger.debug("Attempting to run rdp from path %s", rdp_path)

    if not os.path.isfile(rdp_path):
        raise CLIError("Could not find " + rdp_command + ".exe. Is the rdp client installed?")

    return rdp_path
7751
7752
7753def _get_host(username, ip):
7754    return username + "@" + ip
7755
7756
7757def _build_args(cert_file, private_key_file):
7758    private_key = []
7759    certificate = []
7760    if private_key_file:
7761        private_key = ["-i", private_key_file]
7762    if cert_file:
7763        certificate = ["-o", "CertificateFile=" + cert_file]
7764    return private_key + certificate
7765
7766
def _build_ssh_login_command(cmd, auth_type, username, ssh_key):
    """Validate the auth arguments for ``auth_type`` and return the ssh path, destination and identity args."""
    auth = auth_type.lower()
    if auth == 'password':
        if username is None:
            raise RequiredArgumentMissingError("Please enter username with --username.")
        return [_get_ssh_path(), _get_host(username, 'localhost')]
    if auth == 'aad':
        azssh = _get_azext_module(SSH_EXTENSION_NAME, SSH_EXTENSION_MODULE)
        public_key_file, private_key_file = azssh._check_or_create_public_private_files(None, None)  # pylint: disable=protected-access
        cert_file, username = azssh._get_and_write_certificate(cmd, public_key_file, private_key_file + '-cert.pub')  # pylint: disable=protected-access
        return [_get_ssh_path(), _get_host(username, 'localhost')] + _build_args(cert_file, private_key_file)
    if auth == 'ssh-key':
        if username is None or ssh_key is None:
            raise RequiredArgumentMissingError("Please enter username --username and ssh cert location --ssh-key.")
        return [_get_ssh_path(), _get_host(username, 'localhost')] + _build_args(None, ssh_key)
    raise UnrecognizedArgumentError("Unknown auth type. Use one of password, aad or ssh-key.")


def ssh_bastion_host(cmd, auth_type, target_resource_id, resource_group_name, bastion_host_name,
                     resource_port=None, username=None, ssh_key=None):
    """Open an interactive ssh session to a VM through a Bastion host tunnel on localhost.

    Args:
        auth_type: 'password', 'aad' or 'ssh-key' (case-insensitive).
        target_resource_id: resource ID of the target VM.
        resource_port: port on the target VM; defaults to 22.
        username: required for 'password' and 'ssh-key'. For 'aad' it comes from the
            certificate written by the ssh extension.
        ssh_key: private key path, required for 'ssh-key'.

    Raises:
        InvalidArgumentValueError: if ``target_resource_id`` is not a valid resource ID.
        RequiredArgumentMissingError: if the chosen auth type is missing arguments.
        UnrecognizedArgumentError: for an unknown auth type or a non-Windows platform.
        CLIInternalError: if launching the ssh client fails.
    """
    _test_extension(SSH_EXTENSION_NAME)

    if not resource_port:
        resource_port = 22
    if not is_valid_resource_id(target_resource_id):
        raise InvalidArgumentValueError("Please enter a valid Virtual Machine resource Id.")

    # Validate the auth arguments and locate ssh *before* opening the tunnel. Bad input
    # then fails fast instead of leaving a tunnel thread behind.
    command = _build_ssh_login_command(cmd, auth_type, username, ssh_key)

    tunnel_server = get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port)
    t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
    t.daemon = True
    t.start()

    command = command + ["-p", str(tunnel_server.local_port)]
    command = command + ['-o', "StrictHostKeyChecking=no", '-o', "UserKnownHostsFile=/dev/null"]
    command = command + ['-o', "LogLevel=Error"]
    logger.debug("Running ssh command %s", ' '.join(command))
    try:
        subprocess.call(command, shell=platform.system() == 'Windows')
    except Exception as ex:
        raise CLIInternalError(ex) from ex
    finally:
        # Same as rdp_bastion_host: always run the tunnel cleanup once the client exits.
        tunnel_server.cleanup()
7805
7806
def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_name, resource_port=None):
    """Launch the Windows RDP client against a VM through a Bastion host tunnel on localhost.

    Args:
        target_resource_id: resource ID of the target VM.
        resource_port: port on the target VM; defaults to 3389.

    Raises:
        InvalidArgumentValueError: if ``target_resource_id`` is not a valid resource ID.
        UnrecognizedArgumentError / CLIError: from _get_rdp_path on an unsupported
            platform or when mstsc.exe is missing.
    """
    if not resource_port:
        resource_port = 3389
    if not is_valid_resource_id(target_resource_id):
        raise InvalidArgumentValueError("Please enter a valid Virtual Machine resource Id.")

    # Resolve the RDP client before opening the tunnel, so an unsupported platform or a
    # missing client fails without leaving a tunnel behind.
    rdp_path = _get_rdp_path()

    tunnel_server = get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port)
    t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
    t.daemon = True
    t.start()
    command = [rdp_path, "/v:localhost:{0}".format(tunnel_server.local_port)]
    logger.debug("Running rdp command %s", ' '.join(command))
    try:
        subprocess.call(command, shell=platform.system() == 'Windows')
    finally:
        # Run the tunnel cleanup even if launching the client raises.
        tunnel_server.cleanup()
7821
7822
def get_tunnel(cmd, resource_group_name, name, vm_id, resource_port, port=None):
    """Build (but do not start) a TunnelServer on localhost to ``vm_id`` through Bastion ``name``.

    Args:
        vm_id: resource ID of the target VM.
        resource_port: port on the target VM.
        port: local port to listen on. None is passed to TunnelServer as 0, which the
            original code documents as "auto-select a free port".
    """
    from .tunnel import TunnelServer
    bastion = network_client_factory(cmd.cli_ctx).bastion_hosts.get(resource_group_name, name)
    local_port = 0 if port is None else port
    return TunnelServer(cmd.cli_ctx, 'localhost', local_port, bastion, vm_id, resource_port)
7831
7832
def create_bastion_tunnel(cmd, target_resource_id, resource_group_name, bastion_host_name, resource_port, port,
                          timeout=None):
    """Open a Bastion tunnel on local ``port`` to ``target_resource_id`` and keep it open.

    With a truthy ``timeout`` the tunnel is kept for ``int(timeout)`` seconds. Otherwise it
    stays open until the serving thread ends or the user presses Ctrl+C. In every case the
    tunnel's cleanup runs on exit.

    Raises:
        InvalidArgumentValueError: if ``target_resource_id`` is not a valid resource ID.
    """
    if not is_valid_resource_id(target_resource_id):
        raise InvalidArgumentValueError("Please enter a valid Virtual Machine resource Id.")
    tunnel_server = get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port, port)
    t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
    t.daemon = True
    t.start()
    logger.warning('Opening tunnel on port: %s', tunnel_server.local_port)
    logger.warning('Tunnel is ready, connect on port %s', tunnel_server.local_port)
    logger.warning('Ctrl + C to close')

    try:
        if timeout:
            time.sleep(int(timeout))
        else:
            # Poll until the daemon serving thread exits. Ctrl+C interrupts the sleep.
            while t.is_alive():
                time.sleep(5)
    finally:
        # Same as rdp_bastion_host: clean up the tunnel on timeout, thread exit or Ctrl+C.
        tunnel_server.cleanup()
7849
7850
7851def _start_tunnel(tunnel_server):
7852    tunnel_server.start_server()
7853# endregion
7854
7855
7856# region security partner provider
def create_security_partner_provider(cmd, resource_group_name, security_partner_provider_name,
                                     security_provider_name, virtual_hub, location=None, tags=None):
    """Create or update a security partner provider attached to ``virtual_hub``.

    ``virtual_hub`` is wrapped as-is in ``SubResource(id=...)``, so it is expected to be a
    resource ID.

    Returns:
        The result of ``security_partner_providers.begin_create_or_update``.
    """
    provider_ops = network_client_factory(cmd.cli_ctx).security_partner_providers
    SecurityPartnerProvider, SubResource = cmd.get_models('SecurityPartnerProvider', 'SubResource')

    hub_ref = SubResource(id=virtual_hub)
    provider_model = SecurityPartnerProvider(security_provider_name=security_provider_name,
                                             virtual_hub=hub_ref,
                                             location=location,
                                             tags=tags)
    return provider_ops.begin_create_or_update(
        resource_group_name=resource_group_name,
        security_partner_provider_name=security_partner_provider_name,
        parameters=provider_model,
    )
7869
7870
def update_security_partner_provider(instance, cmd, security_provider_name=None, virtual_hub=None, tags=None):
    """Apply provider name, virtual hub and tags to ``instance`` via ``cmd.update_context``; return it."""
    updates = (
        ('security_provider_name', security_provider_name),
        ('virtual_hub', virtual_hub),
        ('tags', tags),
    )
    with cmd.update_context(instance) as ctx:
        for prop, val in updates:
            ctx.set_param(prop, val)
    return instance
7877
7878
def list_security_partner_provider(cmd, resource_group_name=None):
    """List security partner providers in ``resource_group_name``, or subscription-wide when None."""
    provider_ops = network_client_factory(cmd.cli_ctx).security_partner_providers
    if resource_group_name is None:
        return provider_ops.list()
    return provider_ops.list_by_resource_group(resource_group_name=resource_group_name)
7884# endregion
7885
7886
7887# region network gateway connection
def reset_shared_key(cmd, client, virtual_network_gateway_connection_name, key_length, resource_group_name=None):
    """Reset a gateway connection's shared key to a new key of ``key_length``.

    Returns:
        The result of ``client.begin_reset_shared_key``.
    """
    reset_model = cmd.get_models('ConnectionResetSharedKey')
    request = {
        'resource_group_name': resource_group_name,
        'virtual_network_gateway_connection_name': virtual_network_gateway_connection_name,
        'parameters': reset_model(key_length=key_length),
    }
    return client.begin_reset_shared_key(**request)
7894
7895
def update_shared_key(cmd, instance, value):
    """Set ``instance.value`` to the new shared key via ``cmd.update_context``; return ``instance``."""
    context = cmd.update_context(instance)
    with context as updater:
        updater.set_param('value', value)
    return instance
7900# endregion
7901
7902
7903# region network virtual appliance
def create_network_virtual_appliance(cmd, client, resource_group_name, network_virtual_appliance_name,
                                     vendor, bundled_scale_unit, market_place_version,
                                     virtual_hub, boot_strap_configuration_blobs=None,
                                     cloud_init_configuration_blobs=None,
                                     cloud_init_configuration=None, asn=None,
                                     location=None, tags=None, no_wait=False):
    """Create or update a network virtual appliance in ``virtual_hub``.

    The vendor, scale unit and marketplace version form the appliance's ``nva_sku``, and
    ``asn`` is stored as ``virtual_appliance_asn``. ``virtual_hub`` is wrapped as-is in
    ``SubResource(id=...)``.

    Returns:
        The result of ``sdk_no_wait(no_wait, client.begin_create_or_update, ...)``.
    """
    NetworkVirtualAppliance, SubResource, VirtualApplianceSkuProperties = cmd.get_models(
        'NetworkVirtualAppliance', 'SubResource', 'VirtualApplianceSkuProperties')

    hub_ref = SubResource(id=virtual_hub)
    sku = VirtualApplianceSkuProperties(vendor=vendor,
                                        bundled_scale_unit=bundled_scale_unit,
                                        market_place_version=market_place_version)
    appliance = NetworkVirtualAppliance(
        boot_strap_configuration_blobs=boot_strap_configuration_blobs,
        cloud_init_configuration_blobs=cloud_init_configuration_blobs,
        cloud_init_configuration=cloud_init_configuration,
        virtual_appliance_asn=asn,
        virtual_hub=hub_ref,
        nva_sku=sku,
        location=location,
        tags=tags,
    )

    return sdk_no_wait(no_wait, client.begin_create_or_update,
                       resource_group_name, network_virtual_appliance_name, appliance)
7931
7932
def update_network_virtual_appliance(instance, cmd, cloud_init_configuration=None, asn=None):
    """Apply the ASN and cloud-init configuration to ``instance`` via ``cmd.update_context``; return it."""
    updates = (
        ('virtual_appliance_asn', asn),
        ('cloud_init_configuration', cloud_init_configuration),
    )
    with cmd.update_context(instance) as ctx:
        for prop, val in updates:
            ctx.set_param(prop, val)
    return instance
7938
7939
def list_network_virtual_appliance(cmd, client, resource_group_name=None):
    """List network virtual appliances, scoped to ``resource_group_name`` when it is non-empty."""
    if not resource_group_name:
        return client.list()
    return client.list_by_resource_group(resource_group_name=resource_group_name)
7944
7945
def create_network_virtual_appliance_site(cmd, client, resource_group_name, network_virtual_appliance_name,
                                          site_name, address_prefix, allow=None, optimize=None, default=None,
                                          no_wait=False):
    """Create or update a site on a network virtual appliance.

    ``allow``, ``optimize`` and ``default`` populate the site's Office 365 break-out
    category policies.

    Returns:
        The result of ``sdk_no_wait(no_wait, client.begin_create_or_update, ...)``.
    """
    BreakOutCategoryPolicies, Office365PolicyProperties, VirtualApplianceSite = cmd.get_models(
        'BreakOutCategoryPolicies', 'Office365PolicyProperties', 'VirtualApplianceSite')

    # Build the nested policy from the innermost model outward.
    categories = BreakOutCategoryPolicies(allow=allow, optimize=optimize, default=default)
    o365_policy = Office365PolicyProperties(break_out_categories=categories)
    site = VirtualApplianceSite(address_prefix=address_prefix, o365_policy=o365_policy)

    return sdk_no_wait(no_wait, client.begin_create_or_update,
                       resource_group_name, network_virtual_appliance_name, site_name, site)
7964
7965
def update_network_virtual_appliance_site(instance, cmd, address_prefix, allow=None, optimize=None, default=None):
    """Apply the address prefix and O365 break-out policies to ``instance`` via ``cmd.update_context``.

    Returns:
        ``instance``.
    """
    category_path = 'o365_policy.break_out_categories.'
    updates = (
        ('address_prefix', address_prefix),
        (category_path + 'allow', allow),
        (category_path + 'optimize', optimize),
        (category_path + 'default', default),
    )
    with cmd.update_context(instance) as ctx:
        for prop, val in updates:
            ctx.set_param(prop, val)
    return instance
7973# endregion
7974