1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import os
7import azure.batch.models
8from azure.cli.core.util import get_file_json
9from six.moves.urllib.parse import urlsplit  # pylint: disable=import-error
10
11
# ARGUMENT TYPE VALIDATORS
13
def datetime_format(value):
    """Parse an ISO-8601 datetime string supplied on the command line.

    :param str value: Datetime text to deserialize.
    :returns: The deserialized datetime object.
    :raises ValueError: If ``value`` is not a valid ISO-8601 datetime.
    """
    from msrest.exceptions import DeserializationError
    from msrest.serialization import Deserializer
    try:
        return Deserializer.deserialize_iso(value)
    except DeserializationError:
        raise ValueError("Argument {} is not a valid ISO-8601 datetime format".format(value))
24
25
def disk_encryption_target_format(value):
    """Map one disk-encryption target name to its ``DiskEncryptionTarget`` member.

    Accepted values (compared case-sensitively) are 'OsDisk' and 'TemporaryDisk'.

    :param str value: A single target name.
    :raises ValueError: For any other value.
    """
    cli_to_member = (('OsDisk', 'os_disk'), ('TemporaryDisk', 'temporary_disk'))
    for cli_name, member_name in cli_to_member:
        if value == cli_name:
            return getattr(azure.batch.models.DiskEncryptionTarget, member_name)
    raise ValueError('Argument {} is not a valid disk_encryption_target'.format(value))
34
35
def disk_encryption_configuration_format(value):
    """Parse a space-separated list of disk-encryption targets.

    Each token is converted with ``disk_encryption_target_format`` (accepted
    values: 'OsDisk', 'TemporaryDisk').

    :param str value: e.g. ``"OsDisk TemporaryDisk"``.
    :returns: list of the converted targets, in input order.
    :raises ValueError: If no target is given or any target is not recognised.
    """
    # split() with no argument ignores repeated/leading/trailing whitespace, so
    # "OsDisk  TemporaryDisk" does not produce an empty (invalid) token.
    targets = value.split()
    if not targets:
        raise ValueError("Argument '{}' does not contain any disk_encryption_target".format(value))
    # Return the converted members; previously these were computed and then
    # discarded in favour of the raw input strings.
    return [disk_encryption_target_format(target) for target in targets]
42
43
def duration_format(value):
    """Parse an ISO-8601 duration (timespan) string supplied on the command line.

    :param str value: Duration text such as ``PT1H30M``.
    :returns: The deserialized duration object.
    :raises ValueError: If ``value`` is not a valid ISO-8601 duration.
    """
    from msrest.exceptions import DeserializationError
    from msrest.serialization import Deserializer
    try:
        return Deserializer.deserialize_duration(value)
    except DeserializationError:
        raise ValueError("Argument {} is not in a valid ISO-8601 duration format".format(value))
54
55
def metadata_item_format(value):
    """Parse one metadata item given in 'key=value' format.

    Only the first '=' separates the name from the value, so values may
    themselves contain '=' (e.g. base64 padding or query strings).

    :param str value: A single 'key=value' token.
    :returns: dict with 'name' and 'value' keys.
    :raises ValueError: If ``value`` contains no '='.
    """
    try:
        data_name, data_value = value.split('=', 1)
    except ValueError:
        message = ("Incorrectly formatted metadata. "
                   "Argument values should be in the format a=b c=d")
        raise ValueError(message)
    return {'name': data_name, 'value': data_value}
65
66
def environment_setting_format(value):
    """Parse one environment setting given in 'key=value' format.

    Only the first '=' separates the variable name from its value, so values
    such as ``JAVA_OPTS=-Dfoo=bar`` are preserved intact.

    :param str value: A single 'key=value' token.
    :returns: dict with 'name' and 'value' keys.
    :raises ValueError: If ``value`` contains no '='.
    """
    try:
        env_name, env_value = value.split('=', 1)
    except ValueError:
        message = ("Incorrectly formatted environment settings. "
                   "Argument values should be in the format a=b c=d")
        raise ValueError(message)
    return {'name': env_name, 'value': env_value}
76
77
def application_package_reference_format(value):
    """Parse an application package reference written as 'id[#version]'.

    Everything after the first '#' is taken as the version; when there is no
    '#', the returned dict has no 'version' key.
    """
    app_id, separator, version = value.partition('#')
    package = {'application_id': app_id}
    if separator:
        package['version'] = version
    return package
87
88
def certificate_reference_format(value):
    """Build a certificate reference dict from a thumbprint; the algorithm is always 'sha1'."""
    return dict(thumbprint=value, thumbprint_algorithm='sha1')
93
94
def task_id_ranges_format(value):
    """Parse a task ID range written as 'start-end' into integer bounds.

    :returns: dict with integer 'start' and 'end' keys.
    :raises ValueError: Unless the text is exactly two integers joined by '-'.
    """
    bounds = value.split('-')
    try:
        start, end = map(int, bounds)
    except ValueError:
        raise ValueError("Incorrectly formatted task ID range. "
                         "Argument values should be numbers in the format 'start-end'")
    return {'start': start, 'end': end}
104
105
def resource_file_format(value):
    """Parse a resource file reference written as 'filename=httpurl'.

    Only the first '=' is a separator, so the URL may contain '=' (e.g. SAS query strings).

    :returns: dict with 'file_path' and 'http_url' keys.
    :raises ValueError: If ``value`` contains no '='.
    """
    file_name, separator, http_url = value.partition('=')
    if not separator:
        raise ValueError("Incorrectly formatted resource reference. "
                         "Argument values should be in the format filename=httpurl")
    return {'file_path': file_name, 'http_url': http_url}
115
116
117# COMMAND NAMESPACE VALIDATORS
118
def validate_required_parameter(namespace, parser):
    """Run the Batch complex-object ``parser`` over ``namespace`` unless it has already finished."""
    if parser.done:
        return
    parser.parse(namespace)
123
124
def storage_account_id(cmd, namespace):
    """Resolve ``namespace.storage_account`` from a name to its resource ID.

    The value is left untouched when it is empty or already contains a
    classic or ARM storage-account provider path. Otherwise the account is
    fetched from ``namespace.resource_group_name`` and replaced by its ``id``.

    :raises ValueError: If the lookup returns no account.
    """
    from azure.cli.core.profiles import ResourceType
    from azure.cli.core.commands.client_factory import get_mgmt_service_client

    account = namespace.storage_account
    if not account:
        return
    id_markers = ('/providers/Microsoft.ClassicStorage/storageAccounts/',
                  '/providers/Microsoft.Storage/storageAccounts/')
    if any(marker in account for marker in id_markers):
        # Already a full resource ID; nothing to resolve.
        return
    storage_client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_STORAGE)
    properties = storage_client.storage_accounts.get_properties(
        namespace.resource_group_name, account)
    if not properties:
        raise ValueError("Storage account named '{}' not found in the resource group '{}'.".
                         format(namespace.storage_account, namespace.resource_group_name))
    namespace.storage_account = properties.id  # pylint: disable=no-member
140
141
def keyvault_id(cmd, namespace):
    """Validate a KeyVault reference and resolve its resource ID and vault URI.

    ``namespace.keyvault`` may be a bare vault name, looked up in
    ``namespace.resource_group_name``, or a full resource ID containing
    ``/providers/Microsoft.KeyVault/vaults/``. On success
    ``namespace.keyvault`` is replaced by the vault's resource ID and
    ``namespace.keyvault_url`` is set to the vault URI. Nothing happens when
    no KeyVault was given.

    :raises ValueError: If the resource ID is malformed or the vault cannot
        be retrieved.
    """
    from azure.cli.core.profiles import ResourceType
    from azure.cli.core.commands.client_factory import get_mgmt_service_client
    if not namespace.keyvault:
        return
    # Keep the user's input for error messages; namespace.keyvault is rewritten below.
    reference = namespace.keyvault
    if '/providers/Microsoft.KeyVault/vaults/' in reference:
        # Expected layout:
        # /subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.KeyVault/vaults/{name}
        resource = reference.split('/')
        # Match the resource-group segment case-insensitively so IDs spelled
        # with 'resourcegroups' are accepted too.
        lowered = [segment.lower() for segment in resource]
        try:
            kv_name = resource[resource.index('Microsoft.KeyVault') + 2]
            kv_rg = resource[lowered.index('resourcegroups') + 1]
        except (ValueError, IndexError):
            raise ValueError("Invalid KeyVault reference: {}\n"
                             "Expected a resource ID of the form /subscriptions/<id>/"
                             "resourceGroups/<group>/providers/Microsoft.KeyVault/vaults/<name>".
                             format(reference))
    else:
        kv_name = reference
        kv_rg = namespace.resource_group_name
    try:
        keyvault_client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_KEYVAULT)
        vault = keyvault_client.vaults.get(kv_rg, kv_name)
        if not vault:
            raise ValueError("KeyVault named '{}' not found in the resource group '{}'.".
                             format(kv_name, kv_rg))
        namespace.keyvault = vault.id  # pylint: disable=no-member
        namespace.keyvault_url = vault.properties.vault_uri
    except Exception as exp:  # pylint: disable=broad-except
        # Any failure, including the not-found case above, is reported as an
        # invalid reference with the underlying error appended.
        raise ValueError('Invalid KeyVault reference: {}\n{}'.format(reference, exp))
165
166
def application_enabled(cmd, namespace):
    """Ensure the Batch account exists and has an auto-storage account linked.

    :raises ValueError: If the account is not found or no auto-storage
        account ID is configured on it.
    """
    from azure.mgmt.batch import BatchManagementClient
    from azure.cli.core.commands.client_factory import get_mgmt_service_client

    account_name = namespace.account_name
    client = get_mgmt_service_client(cmd.cli_ctx, BatchManagementClient)
    account = client.batch_account.get(namespace.resource_group_name, account_name)
    if not account:
        raise ValueError("Batch account '{}' not found.".format(account_name))
    auto_storage = account.auto_storage  # pylint: disable=no-member
    if auto_storage and auto_storage.storage_account_id:
        return
    raise ValueError("Batch account '{}' needs auto-storage enabled.".format(account_name))
179
180
def validate_pool_resize_parameters(namespace):
    """Ensure a target dedicated node count is given unless the resize is being aborted.

    A target of ``0`` is a legitimate value and must not be treated as
    missing, so only an absent (``None``) value is rejected.

    :raises ValueError: If ``abort`` is not set and ``target_dedicated_nodes`` is None.
    """
    if not namespace.abort and namespace.target_dedicated_nodes is None:
        raise ValueError("The target-dedicated-nodes parameter is required to resize the pool.")
185
186
def validate_json_file(namespace):
    """Check that the JSON request file, if one was given, can be read and parsed.

    :raises ValueError: If the file cannot be accessed or does not contain valid JSON.
    """
    json_path = namespace.json_file
    if not json_path:
        return
    try:
        get_file_json(json_path)
    except EnvironmentError:
        raise ValueError("Cannot access JSON request file: " + json_path)
    except ValueError as err:
        raise ValueError("Invalid JSON file: {}".format(err))
196
197
def validate_cert_file(namespace):
    """Check that ``namespace.certificate_file`` can be opened for binary reading.

    :raises ValueError: If the file is missing or cannot be opened.
    """
    cert_path = namespace.certificate_file
    try:
        cert_handle = open(cert_path, "rb")
    except EnvironmentError:
        raise ValueError("Cannot access certificate file: " + cert_path)
    else:
        # Only openability is being checked; release the handle immediately.
        cert_handle.close()
205
206
def validate_options(namespace):
    """Fold the flattened ``start_range``/``end_range`` options into ``ocp_range``.

    Commands lacking either option are left untouched. Otherwise both options
    are removed from the namespace, and ``ocp_range`` is set to an HTTP
    ``bytes=start-end`` header value (start defaulting to 0, end left open),
    or to None when neither bound is given.
    """
    try:
        first_byte, last_byte = namespace.start_range, namespace.end_range
    except AttributeError:
        return
    del namespace.start_range
    del namespace.end_range
    namespace.ocp_range = None
    if not (first_byte or last_byte):
        return
    namespace.ocp_range = "bytes={}-{}".format(first_byte or 0, last_byte or "")
222
223
def validate_file_destination(namespace):
    """Resolve and validate the local destination path for a file download.

    If ``namespace.destination`` is an existing directory, the file is placed
    inside it under the base name of ``namespace.file_name``. Otherwise the
    destination is treated as a file path and any missing parent directories
    are created. Commands without a ``destination`` argument are left untouched.

    :raises ValueError: If the parent directory cannot be created or the
        target file already exists.
    """
    try:
        path = namespace.destination
    except AttributeError:
        return
    file_path = path
    file_dir = os.path.dirname(path)
    if os.path.isdir(path):
        file_path = os.path.join(path, os.path.basename(namespace.file_name))
    elif file_dir and not os.path.isdir(file_dir):
        # A bare file name has an empty dirname (the current directory), so a
        # parent is only created when one is actually named. makedirs also
        # creates intermediate directories that os.mkdir could not.
        try:
            os.makedirs(file_dir)
        except EnvironmentError as exp:
            message = "Directory {} does not exist, and cannot be created: {}"
            raise ValueError(message.format(file_dir, exp))
    if os.path.isfile(file_path):
        raise ValueError("File {} already exists.".format(file_path))
    namespace.destination = file_path
246
247# CUSTOM REQUEST VALIDATORS
248
249
def validate_pool_settings(namespace, parser):
    """Custom parsing to enforce that either PaaS or IaaS instances are configured
    in the add pool request body.

    Only runs when no JSON request file is given. In that case:

    * ``image`` is split into ``publisher``, ``offer``, ``sku`` and an
      optional ``version``. If it does not have that form but contains '/',
      it is stored as ``virtual_machine_image_id`` instead.
    * ``image`` is then removed from the namespace.
    * Disk-encryption targets supplied alongside an image are moved to
      ``namespace.targets``.
    * The cloud service and virtual machine configuration groups are parsed
      as mutually exclusive.
    * A PaaS-only VM size without ``os_family`` is rejected.
    * ``enable_auto_scale`` is set whenever an auto-scale formula is given.

    :param namespace: Parsed command arguments; mutated in place.
    :param parser: Batch request-body parser providing ``parse_mutually_exclusive``.
    :raises ValueError: If ``node_agent_sku_id`` is set without ``image``, the
        image format is invalid, or a PaaS VM size is used without ``os_family``.
    """
    # NOTE(review): presumably the request body comes entirely from the JSON
    # file in that case, so none of the argument-level checks apply.
    if not namespace.json_file:
        if namespace.node_agent_sku_id and not namespace.image:
            raise ValueError("Missing required argument: --image")
        if namespace.image:
            try:
                # maxsplit=2 leaves any ':version' attached to sku for the inner split.
                namespace.publisher, namespace.offer, namespace.sku = namespace.image.split(':', 2)
                try:
                    namespace.sku, namespace.version = namespace.sku.split(':', 1)
                except ValueError:
                    # No version component; sku stays as parsed.
                    pass
            except ValueError:
                # Fewer than three ':'-separated parts: accept the value as an
                # image resource ID/URL only if it contains a '/'.
                if '/' not in namespace.image:
                    message = ("Incorrect format for VM image. Should be in the format: \n"
                               "'publisher:offer:sku[:version]' OR a URL to an ARM image.")
                    raise ValueError(message)

                namespace.virtual_machine_image_id = namespace.image
            # The parsed pieces above replace the raw image argument.
            del namespace.image
            if namespace.disk_encryption_targets:
                namespace.targets = namespace.disk_encryption_targets
                del namespace.disk_encryption_targets
        groups = ['pool.cloud_service_configuration', 'pool.virtual_machine_configuration']
        # NOTE(review): the True argument presumably makes one of the groups
        # required - confirm against the parser implementation.
        parser.parse_mutually_exclusive(namespace, True, groups)

        # Without os_family these cloud-service size names are rejected; the
        # message below suggests Standard_A* equivalents for VM configuration.
        paas_sizes = ['small', 'medium', 'large', 'extralarge']
        if namespace.vm_size and namespace.vm_size.lower() in paas_sizes and not namespace.os_family:
            message = ("The selected VM size is incompatible with Virtual Machine Configuration. "
                       "Please swap for the equivalent: Standard_A1 (small), Standard_A2 "
                       "(medium), Standard_A3 (large), or Standard_A4 (extra large).")
            raise ValueError(message)
        if namespace.auto_scale_formula:
            namespace.enable_auto_scale = True
286
287
def validate_cert_settings(namespace):
    """Set the thumbprint algorithm used by the certificate commands.

    The value is always 'sha1' and overwrites anything already on the namespace.
    """
    setattr(namespace, 'thumbprint_algorithm', 'sha1')
293
294
def validate_client_parameters(cmd, namespace):
    """Fill in and validate Batch data-plane connection parameters.

    Missing ``account_name``, ``account_key`` and ``account_endpoint`` values
    are read from the CLI configuration ``batch`` section (``account``,
    ``access_key`` and ``endpoint``).

    The endpoint is normalized to carry a scheme (``https://`` unless
    ``http://`` or ``https://`` is already present) and no trailing slash.

    In ``shared_key`` auth mode (the default), a missing key is looked up
    through the Batch management API. In ``aad`` mode the key is cleared.

    :raises ValueError: If the account name or endpoint cannot be determined,
        or the account cannot be found while looking up its key.
    """
    from azure.mgmt.batch import BatchManagementClient
    from azure.cli.core.commands.client_factory import get_mgmt_service_client

    # Fall back to the CLI configuration for anything not given on the command line.
    if not namespace.account_name:
        namespace.account_name = cmd.cli_ctx.config.get('batch', 'account', None)
    if not namespace.account_key:
        namespace.account_key = cmd.cli_ctx.config.get('batch', 'access_key', None)
    if not namespace.account_endpoint:
        namespace.account_endpoint = cmd.cli_ctx.config.get('batch', 'endpoint', None)

    # Normalize the endpoint only when one is present.
    # - startswith() on None would crash with AttributeError.
    # - Prefixing '' would yield the truthy 'https://' and mask the
    #   "Specify batch endpoint" error below.
    if namespace.account_endpoint:
        if not namespace.account_endpoint.startswith(('https://', 'http://')):
            namespace.account_endpoint = 'https://' + namespace.account_endpoint
        namespace.account_endpoint = namespace.account_endpoint.rstrip('/')
    # If the account name is specified but no key, attempt to look the key up
    # when shared-key auth is in use.
    if namespace.account_name and namespace.account_endpoint and not namespace.account_key:
        if cmd.cli_ctx.config.get('batch', 'auth_mode', 'shared_key') == 'shared_key':
            endpoint = urlsplit(namespace.account_endpoint)
            host = endpoint.netloc
            client = get_mgmt_service_client(cmd.cli_ctx, BatchManagementClient)
            # Match on both name and endpoint host to select the exact account.
            acc = next((x for x in client.batch_account.list()
                        if x.name == namespace.account_name and x.account_endpoint == host), None)
            if acc:
                from msrestazure.tools import parse_resource_id
                rg = parse_resource_id(acc.id)['resource_group']
                namespace.account_key = \
                    client.batch_account.get_keys(rg,  # pylint: disable=no-member
                                                  namespace.account_name).primary
            else:
                raise ValueError("Batch account '{}' not found.".format(namespace.account_name))
    else:
        if not namespace.account_name:
            raise ValueError("Specify batch account in command line or environment variable.")
        if not namespace.account_endpoint:
            raise ValueError("Specify batch endpoint in command line or environment variable.")

    if cmd.cli_ctx.config.get('batch', 'auth_mode', 'shared_key') == 'aad':
        namespace.account_key = None
337