# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import json
import os
import re

from azure.cli.core.commands.arm import ArmTemplateBuilder

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse  # pylint: disable=import-error

from knack.log import get_logger
from knack.util import CLIError

logger = get_logger(__name__)


MSI_LOCAL_ID = '[system]'


def get_target_network_api(cli_ctx):
    """ Since most compute calls don't need advanced network functionality, we can target a supported, but not
        necessarily latest, network API version in order to avoid having to re-record every test that uses VM create
        (of which there are a lot) whenever NRP bumps its API version (which is often)!
    """
    from azure.cli.core.profiles import get_api_version, ResourceType, AD_HOC_API_VERSIONS
    version = get_api_version(cli_ctx, ResourceType.MGMT_NETWORK)
    if cli_ctx.cloud.profile == 'latest':
        version = AD_HOC_API_VERSIONS[ResourceType.MGMT_NETWORK]['vm_default_target_network']
    return version


def read_content_if_is_file(string_or_file):
    content = string_or_file
    if os.path.exists(string_or_file):
        with open(string_or_file, 'r') as f:
            content = f.read()
    return content

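# Illustrative usage (hypothetical values): the argument is treated as a file path only
# when it exists on disk; otherwise the string itself is returned unchanged:
#   >>> read_content_if_is_file('{"key": "value"}')   # no such path -> returned as-is
#   >>> read_content_if_is_file('./params.json')      # existing file -> its contents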

def _resolve_api_version(cli_ctx, provider_namespace, resource_type, parent_path):
    from azure.cli.core.commands.client_factory import get_mgmt_service_client
    from azure.cli.core.profiles import ResourceType
    client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES)
    provider = client.providers.get(provider_namespace)

    # If available, we will use parent resource's api-version
    resource_type_str = (parent_path.split('/')[0] if parent_path else resource_type)

    rt = [t for t in provider.resource_types  # pylint: disable=no-member
          if t.resource_type.lower() == resource_type_str.lower()]
    if not rt:
        raise CLIError('Resource type {} not found.'.format(resource_type_str))
    if len(rt) == 1 and rt[0].api_versions:
        npv = [v for v in rt[0].api_versions if 'preview' not in v.lower()]
        return npv[0] if npv else rt[0].api_versions[0]
    raise CLIError(
        'API version is required and could not be resolved for resource {}'
        .format(resource_type))

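# Note on the version selection above: stable API versions are preferred over previews.
# For example (hypothetical values), api_versions of ['2021-03-01-preview', '2020-06-01']
# resolve to '2020-06-01'; a preview version is returned only when no stable one exists.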

def log_pprint_template(template):
    logger.info('==== BEGIN TEMPLATE ====')
    logger.info(json.dumps(template, indent=2))
    logger.info('==== END TEMPLATE ====')


def check_existence(cli_ctx, value, resource_group, provider_namespace, resource_type,
                    parent_name=None, parent_type=None):
    # check for name or ID and set the type flags
    from azure.cli.core.commands.client_factory import get_mgmt_service_client
    from azure.core.exceptions import HttpResponseError
    from msrestazure.tools import parse_resource_id
    from azure.cli.core.profiles import ResourceType
    id_parts = parse_resource_id(value)
    resource_client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES,
                                              subscription_id=id_parts.get('subscription', None)).resources
    rg = id_parts.get('resource_group', resource_group)
    ns = id_parts.get('namespace', provider_namespace)

    if parent_name and parent_type:
        parent_path = '{}/{}'.format(parent_type, parent_name)
        resource_name = id_parts.get('child_name_1', value)
        resource_type = id_parts.get('child_type_1', resource_type)
    else:
        parent_path = ''
        resource_name = id_parts['name']
        resource_type = id_parts.get('type', resource_type)
    api_version = _resolve_api_version(cli_ctx, provider_namespace, resource_type, parent_path)

    try:
        resource_client.get(rg, ns, parent_path, resource_type, resource_name, api_version)
        return True
    except HttpResponseError:
        return False

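# Illustrative call (hypothetical names): existence is probed with a GET and reported
# as a boolean; child resources are addressed via parent_name/parent_type:
#   >>> check_existence(cli_ctx, 'myVM', 'myRG', 'Microsoft.Compute', 'virtualMachines')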

def create_keyvault_data_plane_client(cli_ctx):
    from azure.cli.command_modules.keyvault._client_factory import keyvault_data_plane_factory
    return keyvault_data_plane_factory(cli_ctx)


def get_key_vault_base_url(cli_ctx, vault_name):
    suffix = cli_ctx.cloud.suffixes.keyvault_dns
    return 'https://{}{}'.format(vault_name, suffix)

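# Illustrative result (hypothetical vault name): on the public cloud, where
# cli_ctx.cloud.suffixes.keyvault_dns is '.vault.azure.net':
#   >>> get_key_vault_base_url(cli_ctx, 'myvault')
#   'https://myvault.vault.azure.net'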

def list_sku_info(cli_ctx, location=None):
    from ._client_factory import _compute_client_factory

    def _match_location(loc, locations):
        return next((x for x in locations if x.lower() == loc.lower()), None)

    client = _compute_client_factory(cli_ctx)
    result = client.resource_skus.list()
    if location:
        result = [r for r in result if _match_location(location, r.locations)]
    return result

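# Illustrative usage (hypothetical location): SKUs are filtered client-side by a
# case-insensitive match against each SKU's locations list:
#   >>> skus = list_sku_info(cli_ctx, location='westus2')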

# pylint: disable=too-many-statements
def normalize_disk_info(image_data_disks=None,
                        data_disk_sizes_gb=None, attach_data_disks=None, storage_sku=None,
                        os_disk_caching=None, data_disk_cachings=None, size='', ephemeral_os_disk=False,
                        data_disk_delete_option=None):
    from msrestazure.tools import is_valid_resource_id, parse_resource_id
    from ._validators import validate_delete_options
    is_lv_size = re.search('_L[0-9]+s', size, re.I)
    # we should return a dictionary with info like below
    # {
    #   'os': { caching: 'Read', write_accelerator: None},
    #   0: { caching: 'None', write_accelerator: True},
    #   1: { caching: 'None', write_accelerator: True},
    # }
    info = {}
    used_luns = set()

    attach_data_disks = attach_data_disks or []
    data_disk_sizes_gb = data_disk_sizes_gb or []
    image_data_disks = image_data_disks or []

    data_disk_delete_option = validate_delete_options(attach_data_disks, data_disk_delete_option)
    info['os'] = {}
    # update os diff disk settings
    if ephemeral_os_disk:
        info['os']['diffDiskSettings'] = {'option': 'Local'}
        # local os disks require readonly caching, default to ReadOnly if os_disk_caching not specified.
        if not os_disk_caching:
            os_disk_caching = 'ReadOnly'

    # add managed image data disks
    for data_disk in image_data_disks:
        i = data_disk['lun']
        info[i] = {
            'lun': i,
            'managedDisk': {'storageAccountType': None},
            'createOption': 'fromImage'
        }
        used_luns.add(i)

    # add empty data disks, do not use existing luns
    i = 0
    sizes_copy = list(data_disk_sizes_gb)
    while sizes_copy:
        # get free lun
        while i in used_luns:
            i += 1

        used_luns.add(i)

        info[i] = {
            'lun': i,
            'managedDisk': {'storageAccountType': None},
            'createOption': 'empty',
            'diskSizeGB': sizes_copy.pop(0),
            'deleteOption': data_disk_delete_option if isinstance(data_disk_delete_option, str) else None
        }

    # update storage skus for managed data disks
    if storage_sku is not None:
        update_disk_sku_info(info, storage_sku)

    # check that os storage account type is not UltraSSD_LRS
    if info['os'].get('storageAccountType', "").lower() == 'ultrassd_lrs':
        logger.warning("Managed os disk storage account sku cannot be UltraSSD_LRS. Using service default.")
        info['os']['storageAccountType'] = None

    # add attached data disks
    i = 0
    attach_data_disks_copy = list(attach_data_disks)
    while attach_data_disks_copy:
        # get free lun
        while i in used_luns:
            i += 1

        used_luns.add(i)

        # use free lun
        info[i] = {
            'lun': i,
            'createOption': 'attach'
        }

        d = attach_data_disks_copy.pop(0)

        if is_valid_resource_id(d):
            info[i]['managedDisk'] = {'id': d}
            if data_disk_delete_option:
                # the delete-option dict is keyed by disk name; info[i] has no 'name' entry
                # in this branch, so parse the name from the resource ID instead
                info[i]['deleteOption'] = data_disk_delete_option if isinstance(data_disk_delete_option, str) \
                    else data_disk_delete_option.get(parse_resource_id(d)['name'], None)
        else:
            info[i]['vhd'] = {'uri': d}
            info[i]['name'] = d.split('/')[-1].split('.')[0]
            if data_disk_delete_option:
                info[i]['deleteOption'] = data_disk_delete_option if isinstance(data_disk_delete_option, str) \
                    else data_disk_delete_option.get(info[i]['name'], None)

    # fill in data disk caching
    if data_disk_cachings:
        update_disk_caching(info, data_disk_cachings)

    # default os disk caching to 'ReadWrite' (or 'None' for Lv-series sizes) unless set otherwise
    if os_disk_caching:
        info['os']['caching'] = os_disk_caching
    else:
        info['os']['caching'] = 'None' if is_lv_size else 'ReadWrite'

    # error out if a caching mode other than 'None' is requested on an Lv-series size
    if is_lv_size:
        for v in info.values():
            if v.get('caching', 'None').lower() != 'none':
                raise CLIError('usage error: for Lv series of machines, "None" is the only supported caching mode')

    result_info = {'os': info['os']}

    # In Python 3, dicts preserve insertion order, so adding keys in sorted order ensures luns are iterated numerically
    for key in sorted([key for key in info if key != 'os']):
        result_info[key] = info[key]

    return result_info

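# Illustrative result (hypothetical sizes, and assuming validate_delete_options returns
# None for empty inputs): two empty data disks land on the first free luns in order:
#   >>> normalize_disk_info(data_disk_sizes_gb=[10, 20])
#   {'os': {'caching': 'ReadWrite'},
#    0: {'lun': 0, 'managedDisk': {'storageAccountType': None}, 'createOption': 'empty',
#        'diskSizeGB': 10, 'deleteOption': None},
#    1: {'lun': 1, 'managedDisk': {'storageAccountType': None}, 'createOption': 'empty',
#        'diskSizeGB': 20, 'deleteOption': None}}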

def update_disk_caching(model, caching_settings):

    def _update(model, lun, value):
        if isinstance(model, dict):
            luns = model.keys() if lun is None else [lun]
            for lun_item in luns:
                if lun_item not in model:
                    raise CLIError("Data disk with lun of '{}' doesn't exist. Existing luns: {}."
                                   .format(lun_item, list(model.keys())))
                model[lun_item]['caching'] = value
        else:
            if lun is None:
                disks = [model.os_disk] + (model.data_disks or [])
            elif lun == 'os':
                disks = [model.os_disk]
            else:
                disk = next((d for d in model.data_disks if d.lun == lun), None)
                if not disk:
                    raise CLIError("data disk with lun of '{}' doesn't exist".format(lun))
                disks = [disk]
            for disk in disks:
                disk.caching = value

    if len(caching_settings) == 1 and '=' not in caching_settings[0]:
        _update(model, None, caching_settings[0])
    else:
        for x in caching_settings:
            if '=' not in x:
                raise CLIError("usage error: please use 'LUN=VALUE' to configure caching on an individual disk")
            lun, value = x.split('=', 1)
            lun = lun.lower()
            lun = int(lun) if lun != 'os' else lun
            _update(model, lun, value)

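# Illustrative settings (hypothetical values): a single bare value applies to the OS disk
# and every data disk, while 'LUN=VALUE' pairs ('os' addresses the OS disk) target
# individual disks:
#   >>> update_disk_caching(info, ['ReadOnly'])                # all disks
#   >>> update_disk_caching(info, ['os=ReadWrite', '0=None'])  # per-disk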

def update_write_accelerator_settings(model, write_accelerator_settings):

    def _update(model, lun, value):
        if isinstance(model, dict):
            luns = model.keys() if lun is None else [lun]
            for lun_item in luns:
                if lun_item not in model:
                    raise CLIError("data disk with lun of '{}' doesn't exist".format(lun_item))
                model[lun_item]['writeAcceleratorEnabled'] = value
        else:
            if lun is None:
                disks = [model.os_disk] + (model.data_disks or [])
            elif lun == 'os':
                disks = [model.os_disk]
            else:
                disk = next((d for d in model.data_disks if d.lun == lun), None)
                if not disk:
                    raise CLIError("data disk with lun of '{}' doesn't exist".format(lun))
                disks = [disk]
            for disk in disks:
                disk.write_accelerator_enabled = value

    if len(write_accelerator_settings) == 1 and '=' not in write_accelerator_settings[0]:
        _update(model, None, write_accelerator_settings[0].lower() == 'true')
    else:
        for x in write_accelerator_settings:
            if '=' not in x:
                raise CLIError("usage error: please use 'LUN=VALUE' to configure write accelerator"
                               " on an individual disk")
            lun, value = x.split('=', 1)
            lun = lun.lower()
            lun = int(lun) if lun != 'os' else lun
            _update(model, lun, value.lower() == 'true')

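# Illustrative settings (hypothetical values): values parse case-insensitively as
# booleans, with the same bare-value vs 'LUN=VALUE' convention as caching settings:
#   >>> update_write_accelerator_settings(model, ['true'])                # all disks
#   >>> update_write_accelerator_settings(model, ['os=false', '1=true'])  # per-disk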

def get_storage_blob_uri(cli_ctx, storage):
    from azure.cli.core.profiles._shared import ResourceType
    from azure.cli.core.commands.client_factory import get_mgmt_service_client
    if urlparse(storage).scheme:
        storage_uri = storage
    else:
        storage_mgmt_client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_STORAGE)
        storage_accounts = storage_mgmt_client.storage_accounts.list()
        storage_account = next((a for a in list(storage_accounts)
                                if a.name.lower() == storage.lower()), None)
        if storage_account is None:
            raise CLIError("{} doesn't exist.".format(storage))
        storage_uri = storage_account.primary_endpoints.blob
    return storage_uri

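# Illustrative inputs (hypothetical account name): anything with a URL scheme passes
# through unchanged; a bare account name is resolved to its primary blob endpoint:
#   >>> get_storage_blob_uri(cli_ctx, 'https://mystore.blob.core.windows.net/')  # as-is
#   >>> get_storage_blob_uri(cli_ctx, 'mystore')  # looked up via the storage client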

def update_disk_sku_info(info_dict, skus):
    usage_msg = 'Usage:\n\t[--storage-sku SKU | --storage-sku ID=SKU ID=SKU ID=SKU...]\n' \
                'where each ID is "os" or a 0-indexed lun.'

    def _update(info, lun, value):
        luns = info.keys()
        if lun not in luns:
            raise CLIError("Data disk with lun of '{}' doesn't exist. Existing luns: {}.".format(lun, list(luns)))
        if lun == 'os':
            info[lun]['storageAccountType'] = value
        else:
            info[lun]['managedDisk']['storageAccountType'] = value

    if len(skus) == 1 and '=' not in skus[0]:
        for lun in info_dict.keys():
            _update(info_dict, lun, skus[0])
    else:
        for sku in skus:
            if '=' not in sku:
                raise CLIError("A sku's format is incorrect.\n{}".format(usage_msg))

            lun, value = sku.split('=', 1)
            lun = lun.lower()
            try:
                lun = int(lun) if lun != "os" else lun
            except ValueError:
                raise CLIError("A sku ID is incorrect.\n{}".format(usage_msg))
            _update(info_dict, lun, value)

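# Illustrative skus (hypothetical values): a single bare SKU applies to every entry,
# while 'ID=SKU' pairs target 'os' or individual luns:
#   >>> update_disk_sku_info(info, ['Premium_LRS'])
#   >>> update_disk_sku_info(info, ['os=StandardSSD_LRS', '0=Premium_LRS'])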

def is_shared_gallery_image_id(image_reference):
    if not image_reference:
        return False

    shared_gallery_id_pattern = re.compile(r'^/SharedGalleries/[^/]*/Images/[^/]*/Versions/.*$', re.IGNORECASE)
    if shared_gallery_id_pattern.match(image_reference):
        return True

    return False

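# Illustrative IDs (hypothetical names): only IDs of the documented shape match:
#   '/SharedGalleries/myGallery/Images/myImage/Versions/1.0.0'  -> True
#   '/subscriptions/.../providers/Microsoft.Compute/images/img' -> False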

def parse_shared_gallery_image_id(image_reference):
    from azure.cli.core.azclierror import InvalidArgumentValueError

    if not image_reference:
        raise InvalidArgumentValueError(
            'Please pass in the shared gallery image id through the parameter --image')

    image_info = re.search(r'^/SharedGalleries/([^/]*)/Images/([^/]*)/Versions/.*$', image_reference, re.IGNORECASE)
    if not image_info or len(image_info.groups()) < 2:
        raise InvalidArgumentValueError(
            'The shared gallery image id is invalid. The valid format should be '
            '"/SharedGalleries/{gallery_unique_name}/Images/{gallery_image_name}/Versions/{image_version}"')

    # Return the gallery unique name and gallery image name parsed from shared gallery image id
    return image_info.group(1), image_info.group(2)

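# Illustrative parse (hypothetical names): the gallery unique name and image name are
# returned as a tuple:
#   >>> parse_shared_gallery_image_id('/SharedGalleries/myGallery/Images/myImage/Versions/1.0.0')
#   ('myGallery', 'myImage')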

class ArmTemplateBuilder20190401(ArmTemplateBuilder):

    def __init__(self):
        super().__init__()
        self.template['$schema'] = 'https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#'