# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import os
import azure.batch.models
from azure.cli.core.util import get_file_json
from six.moves.urllib.parse import urlsplit  # pylint: disable=import-error


# TYPES VALIDATORS

def datetime_format(value):
    """Validate the correct format of a datetime string and deserialize.

    :param str value: An ISO-8601 datetime string.
    :returns: The datetime object produced by msrest's ISO deserializer.
    :raises ValueError: If `value` is not a valid ISO-8601 datetime.
    """
    from msrest.serialization import Deserializer
    from msrest.exceptions import DeserializationError
    try:
        datetime_obj = Deserializer.deserialize_iso(value)
    except DeserializationError:
        message = "Argument {} is not a valid ISO-8601 datetime format"
        raise ValueError(message.format(value))
    return datetime_obj


def disk_encryption_target_format(value):
    """Convert a single disk encryption target name to its enum value.

    Accepted values are exactly 'OsDisk' and 'TemporaryDisk'.

    :param str value: The target name.
    :returns: The matching `azure.batch.models.DiskEncryptionTarget` member.
    :raises ValueError: If `value` is not one of the accepted names.
    """
    if value == 'OsDisk':
        return azure.batch.models.DiskEncryptionTarget.os_disk
    if value == 'TemporaryDisk':
        return azure.batch.models.DiskEncryptionTarget.temporary_disk
    message = 'Argument {} is not a valid disk_encryption_target'
    raise ValueError(message.format(value))


def disk_encryption_configuration_format(value):
    """Parse a space-separated list of disk encryption targets.

    Each item is validated and converted by `disk_encryption_target_format`.

    :param str value: Targets separated by single spaces, e.g. 'OsDisk TemporaryDisk'.
    :returns: A list of `DiskEncryptionTarget` members, in input order.
    :raises ValueError: If any item is not a valid target name.
    """
    # Return the converted values; the raw strings were previously returned and
    # the parsed list discarded.
    return [disk_encryption_target_format(target) for target in value.split(' ')]


def duration_format(value):
    """Validate the correct format of a timespan string and deserialize.

    :param str value: An ISO-8601 duration string.
    :returns: The duration object produced by msrest's duration deserializer.
    :raises ValueError: If `value` is not a valid ISO-8601 duration.
    """
    from msrest.serialization import Deserializer
    from msrest.exceptions import DeserializationError
    try:
        duration_obj = Deserializer.deserialize_duration(value)
    except DeserializationError:
        message = "Argument {} is not in a valid ISO-8601 duration format"
        raise ValueError(message.format(value))
    return duration_obj


def metadata_item_format(value):
    """Parse one metadata item given in 'key=value' format.

    On the command line these are passed space-separated (a=b c=d). Only the
    first '=' separates the name from the value, so values may contain '='.

    :returns: A dict with 'name' and 'value' keys.
    :raises ValueError: If `value` contains no '='.
    """
    try:
        data_name, data_value = value.split('=', 1)
    except ValueError:
        message = ("Incorrectly formatted metadata. "
                   "Argument values should be in the format a=b c=d")
        raise ValueError(message)
    return {'name': data_name, 'value': data_value}


def environment_setting_format(value):
    """Parse one environment setting given in 'key=value' format.

    On the command line these are passed space-separated (a=b c=d). Only the
    first '=' separates the name from the value, so values may contain '='.

    :returns: A dict with 'name' and 'value' keys.
    :raises ValueError: If `value` contains no '='.
    """
    try:
        env_name, env_value = value.split('=', 1)
    except ValueError:
        message = ("Incorrectly formatted environment settings. "
                   "Argument values should be in the format a=b c=d")
        raise ValueError(message)
    return {'name': env_name, 'value': env_value}


def application_package_reference_format(value):
    """Parse one application reference in 'id[#version]' format.

    :returns: A dict with 'application_id' and, when a '#version' suffix is
        present, 'version'.
    """
    app_reference = value.split('#', 1)
    package = {'application_id': app_reference[0]}
    try:
        package['version'] = app_reference[1]
    except IndexError:  # No specified version - ignore
        pass
    return package


def certificate_reference_format(value):
    """Build a certificate reference from a thumbprint (algorithm fixed to sha1)."""
    cert = {'thumbprint': value, 'thumbprint_algorithm': 'sha1'}
    return cert


def task_id_ranges_format(value):
    """Parse one task ID range in 'start-end' format.

    :returns: A dict with integer 'start' and 'end' keys.
    :raises ValueError: If `value` is not two integers joined by a single '-'.
    """
    try:
        start, end = [int(i) for i in value.split('-')]
    except ValueError:
        message = ("Incorrectly formatted task ID range. "
                   "Argument values should be numbers in the format 'start-end'")
        raise ValueError(message)
    return {'start': start, 'end': end}


def resource_file_format(value):
    """Parse one resource file reference in 'filename=httpurl' format.

    Only the first '=' is a separator, so the URL may contain '='.

    :returns: A dict with 'file_path' and 'http_url' keys.
    :raises ValueError: If `value` contains no '='.
    """
    try:
        file_name, http_url = value.split('=', 1)
    except ValueError:
        message = ("Incorrectly formatted resource reference. "
                   "Argument values should be in the format filename=httpurl")
        raise ValueError(message)
    return {'file_path': file_name, 'http_url': http_url}


# COMMAND NAMESPACE VALIDATORS

def validate_required_parameter(namespace, parser):
    """Validates required parameters in Batch complex objects.

    Runs `parser.parse(namespace)` unless the parser reports it is already done.
    """
    if not parser.done:
        parser.parse(namespace)


def validate_storage_account_note():  # pragma: no cover - placeholder never used
    """Unused placeholder removed below; kept out of public API."""


del validate_storage_account_note


def storage_account_id(cmd, namespace):
    """Resolve a storage account name to its full resource ID.

    If `namespace.storage_account` is set and does not already look like a
    Microsoft.Storage or Microsoft.ClassicStorage resource ID, the account is
    looked up in `namespace.resource_group_name` and replaced by its ID.

    :raises ValueError: If the lookup returns no account.
    """
    from azure.cli.core.profiles import ResourceType
    from azure.cli.core.commands.client_factory import get_mgmt_service_client

    if (namespace.storage_account and not
            ('/providers/Microsoft.ClassicStorage/storageAccounts/' in namespace.storage_account or
             '/providers/Microsoft.Storage/storageAccounts/' in namespace.storage_account)):
        storage_client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_STORAGE)
        acc = storage_client.storage_accounts.get_properties(namespace.resource_group_name,
                                                             namespace.storage_account)
        if not acc:
            raise ValueError("Storage account named '{}' not found in the resource group '{}'.".
                             format(namespace.storage_account, namespace.resource_group_name))
        namespace.storage_account = acc.id  # pylint: disable=no-member


def keyvault_id(cmd, namespace):
    """Resolve a KeyVault name or resource ID to the vault's ID and URI.

    A value containing '/providers/Microsoft.KeyVault/vaults/' is split into
    vault name and resource group; any other value is treated as a vault name
    in `namespace.resource_group_name`. On success, `namespace.keyvault` is set
    to the vault ID and `namespace.keyvault_url` to its vault URI.

    :raises ValueError: If the vault cannot be retrieved (any lookup error,
        including "not found", is wrapped into this exception).
    """
    from azure.cli.core.profiles import ResourceType
    from azure.cli.core.commands.client_factory import get_mgmt_service_client
    if not namespace.keyvault:
        return
    if '/providers/Microsoft.KeyVault/vaults/' in namespace.keyvault:
        resource = namespace.keyvault.split('/')
        kv_name = resource[resource.index('Microsoft.KeyVault') + 2]
        kv_rg = resource[resource.index('resourceGroups') + 1]
    else:
        kv_name = namespace.keyvault
        kv_rg = namespace.resource_group_name
    try:
        keyvault_client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_KEYVAULT)
        vault = keyvault_client.vaults.get(kv_rg, kv_name)
        if not vault:
            raise ValueError("KeyVault named '{}' not found in the resource group '{}'.".
                             format(kv_name, kv_rg))
        namespace.keyvault = vault.id  # pylint: disable=no-member
        namespace.keyvault_url = vault.properties.vault_uri
    except Exception as exp:
        raise ValueError('Invalid KeyVault reference: {}\n{}'.format(namespace.keyvault, exp))


def application_enabled(cmd, namespace):
    """Validate that the Batch account exists and has auto-storage enabled.

    :raises ValueError: If the account is not returned, or has no auto-storage
        storage account ID.
    """
    from azure.mgmt.batch import BatchManagementClient
    from azure.cli.core.commands.client_factory import get_mgmt_service_client

    client = get_mgmt_service_client(cmd.cli_ctx, BatchManagementClient)
    acc = client.batch_account.get(namespace.resource_group_name, namespace.account_name)
    if not acc:
        raise ValueError("Batch account '{}' not found.".format(namespace.account_name))
    if not acc.auto_storage or not acc.auto_storage.storage_account_id:  # pylint: disable=no-member
        raise ValueError("Batch account '{}' needs auto-storage enabled.".
                         format(namespace.account_name))


def validate_pool_resize_parameters(namespace):
    """Validate pool resize parameters.

    Unless the resize is being aborted, a dedicated node target is required.
    A target of 0 is valid (scaling dedicated nodes down to zero), so only a
    missing value is rejected.

    :raises ValueError: If neither `abort` nor `target_dedicated_nodes` is given.
    """
    if not namespace.abort and namespace.target_dedicated_nodes is None:
        raise ValueError("The target-dedicated-nodes parameter is required to resize the pool.")


def validate_json_file(namespace):
    """Validate that the given JSON request file exists and parses as JSON.

    :raises ValueError: If the file cannot be read or is not valid JSON.
    """
    if namespace.json_file:
        try:
            get_file_json(namespace.json_file)
        except EnvironmentError:
            raise ValueError("Cannot access JSON request file: " + namespace.json_file)
        except ValueError as err:
            raise ValueError("Invalid JSON file: {}".format(err))


def validate_cert_file(namespace):
    """Validate that the given certificate file exists and can be opened.

    :raises ValueError: If the file cannot be opened for binary reading.
    """
    try:
        with open(namespace.certificate_file, "rb"):
            pass
    except EnvironmentError:
        raise ValueError("Cannot access certificate file: " + namespace.certificate_file)


def validate_options(namespace):
    """Validate any flattened request header option arguments.

    Collapses `start_range`/`end_range` (when the command defines them) into a
    single `ocp_range` header value of the form 'bytes=start-end'. A missing
    start defaults to 0 and a missing end is left open.
    """
    try:
        start = namespace.start_range
        end = namespace.end_range
    except AttributeError:
        return
    else:
        namespace.ocp_range = None
        del namespace.start_range
        del namespace.end_range
        if start or end:
            start = start if start else 0
            end = end if end else ""
            namespace.ocp_range = "bytes={}-{}".format(start, end)


def validate_file_destination(namespace):
    """Validate the destination path for a file download.

    If the destination is an existing directory, the file is saved under that
    directory using the base name of `namespace.file_name`. Otherwise the
    destination is the target file path, and its parent directory (if one is
    named) is created when missing. Existing files are never overwritten.

    :raises ValueError: If the parent directory cannot be created or the
        target file already exists.
    """
    try:
        path = namespace.destination
    except AttributeError:
        return
    else:
        file_path = path
        file_dir = os.path.dirname(path)
        if os.path.isdir(path):
            file_name = os.path.basename(namespace.file_name)
            file_path = os.path.join(path, file_name)
        elif file_dir and not os.path.isdir(file_dir):
            # A bare file name has an empty dirname (current directory); only
            # try to create a directory when one is actually named.
            try:
                os.makedirs(file_dir)
            except EnvironmentError as exp:
                message = "Directory {} does not exist, and cannot be created: {}"
                raise ValueError(message.format(file_dir, exp))
        if os.path.isfile(file_path):
            raise ValueError("File {} already exists.".format(file_path))
        namespace.destination = file_path

# CUSTOM REQUEST VALIDATORS


def validate_pool_settings(namespace, parser):
    """Custom parsing to enforce that either PaaS or IaaS instances are configured
    in the add pool request body.

    When no JSON file is given, `--image` is split into publisher/offer/sku[/version]
    ('publisher:offer:sku[:version]') or, when it contains '/', used as an ARM
    image ID. Cloud service and virtual machine configuration are then checked
    to be mutually exclusive.

    :raises ValueError: On a missing or malformed image, or a PaaS VM size used
        without an OS family.
    """
    if not namespace.json_file:
        if namespace.node_agent_sku_id and not namespace.image:
            raise ValueError("Missing required argument: --image")
        if namespace.image:
            try:
                namespace.publisher, namespace.offer, namespace.sku = namespace.image.split(':', 2)
                try:
                    namespace.sku, namespace.version = namespace.sku.split(':', 1)
                except ValueError:
                    pass
            except ValueError:
                if '/' not in namespace.image:
                    message = ("Incorrect format for VM image. Should be in the format: \n"
                               "'publisher:offer:sku[:version]' OR a URL to an ARM image.")
                    raise ValueError(message)

                namespace.virtual_machine_image_id = namespace.image
            del namespace.image
            if namespace.disk_encryption_targets:
                namespace.targets = namespace.disk_encryption_targets
                del namespace.disk_encryption_targets
        groups = ['pool.cloud_service_configuration', 'pool.virtual_machine_configuration']
        parser.parse_mutually_exclusive(namespace, True, groups)

        paas_sizes = ['small', 'medium', 'large', 'extralarge']
        if namespace.vm_size and namespace.vm_size.lower() in paas_sizes and not namespace.os_family:
            message = ("The selected VM size is incompatible with Virtual Machine Configuration. "
                       "Please swap for the equivalent: Standard_A1 (small), Standard_A2 "
                       "(medium), Standard_A3 (large), or Standard_A4 (extra large).")
            raise ValueError(message)
        if namespace.auto_scale_formula:
            namespace.enable_auto_scale = True


def validate_cert_settings(namespace):
    """Custom parsing for certificate commands - adds default thumbprint
    algorithm ('sha1').
    """
    namespace.thumbprint_algorithm = 'sha1'


def validate_client_parameters(cmd, namespace):
    """Resolve Batch data-plane connection parameters.

    Missing account name, key and endpoint are filled from the CLI 'batch'
    configuration section. The endpoint is normalized to an http(s) URL with
    no trailing slash. With shared-key auth and no key supplied, the key is
    looked up through the management API. With 'aad' auth the key is cleared.

    :raises ValueError: If the account or endpoint cannot be determined, or
        the account is not found when looking up its key.
    """
    from azure.mgmt.batch import BatchManagementClient
    from azure.cli.core.commands.client_factory import get_mgmt_service_client

    # Fill any missing values from the CLI configuration.
    if not namespace.account_name:
        namespace.account_name = cmd.cli_ctx.config.get('batch', 'account', None)
    if not namespace.account_key:
        namespace.account_key = cmd.cli_ctx.config.get('batch', 'access_key', None)
    if not namespace.account_endpoint:
        namespace.account_endpoint = cmd.cli_ctx.config.get('batch', 'endpoint', None)

    # Simple validation for account_endpoint. Only normalize when an endpoint
    # exists; a missing endpoint is reported by the checks below instead of
    # failing here with an AttributeError on None.
    if namespace.account_endpoint:
        if not (namespace.account_endpoint.startswith('https://') or
                namespace.account_endpoint.startswith('http://')):
            namespace.account_endpoint = 'https://' + namespace.account_endpoint
        namespace.account_endpoint = namespace.account_endpoint.rstrip('/')
    # if account name is specified but no key, attempt to query if we use shared key auth
    if namespace.account_name and namespace.account_endpoint and not namespace.account_key:
        if cmd.cli_ctx.config.get('batch', 'auth_mode', 'shared_key') == 'shared_key':
            endpoint = urlsplit(namespace.account_endpoint)
            host = endpoint.netloc
            client = get_mgmt_service_client(cmd.cli_ctx, BatchManagementClient)
            acc = next((x for x in client.batch_account.list()
                        if x.name == namespace.account_name and x.account_endpoint == host), None)
            if acc:
                from msrestazure.tools import parse_resource_id
                rg = parse_resource_id(acc.id)['resource_group']
                namespace.account_key = \
                    client.batch_account.get_keys(rg,  # pylint: disable=no-member
                                                  namespace.account_name).primary
            else:
                raise ValueError("Batch account '{}' not found.".format(namespace.account_name))
    else:
        if not namespace.account_name:
            raise ValueError("Specify batch account in command line or environment variable.")
        if not namespace.account_endpoint:
            raise ValueError("Specify batch endpoint in command line or environment variable.")

    if cmd.cli_ctx.config.get('batch', 'auth_mode', 'shared_key') == 'aad':
        namespace.account_key = None