1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import re
7from six import string_types
8
9from knack.arguments import CLICommandArgument, IgnoreAction
10from knack.introspection import extract_full_summary_from_signature, extract_args_from_signature
11
12from azure.cli.command_modules.batch import _validators as validators
13from azure.cli.command_modules.batch import _format as transformers
14from azure.cli.command_modules.batch import _parameter_format as pformat
15
16from azure.cli.core import EXCLUDED_PARAMS
17from azure.cli.core.commands import CONFIRM_PARAM_NAME
18from azure.cli.core.commands import AzCommandGroup
19from azure.cli.core.util import get_file_json
20
21
22_CLASS_NAME = re.compile(r"~(.*)")  # Strip model name from class docstring
23_UNDERSCORE_CASE = re.compile('(?!^)([A-Z]+)')  # Convert from CamelCase to underscore_case
24
25
26def _load_model(name):
27    """Load a model class from the SDK in order to inspect for
28    parameter names and whether they're required.
29    :param str name: The model class name to load.
30    :returns: Model class
31    """
32    if name.startswith('azure.'):
33        namespace = name.split('.')
34    else:
35        namespace = ['azure', 'batch', 'models', name]
36    model = __import__(namespace[0])
37    for level in namespace[1:]:
38        model = getattr(model, level)
39    return model
40
41
42def _join_prefix(prefix, name):
43    """Filter certain superflous parameter name suffixes
44    from argument names.
45    :param str prefix: The potential prefix that will be filtered.
46    :param str name: The arg name to be prefixed.
47    :returns: Combined name with prefix.
48    """
49    if prefix.endswith("_specification"):
50        return prefix[:-14] + "_" + name
51    if prefix.endswith("_patch_parameter"):
52        return prefix[:-16] + "_" + name
53    if prefix.endswith("_update_parameter"):
54        return prefix[:-17] + "_" + name
55    return prefix + "_" + name
56
57
58def _build_prefix(arg, param, path):
59    """Recursively build a command line argument prefix from the request
60    parameter object to avoid name conflicts.
61    :param str arg: Current argument name.
62    :param str param: Original request parameter name.
63    :param str path: Request parameter namespace.
64    """
65    prefix_list = path.split('.')
66    if len(prefix_list) == 1:
67        return arg
68    resolved_name = _join_prefix(prefix_list[0], param)
69    if arg == resolved_name:
70        return arg
71    for prefix in prefix_list[1:]:
72        new_name = _join_prefix(prefix, param)
73        if new_name == arg:
74            return resolved_name
75        resolved_name = new_name
76    return resolved_name
77
78
def find_param_type(model, param):
    """Parse the parameter type from the model docstring.

    Finds the ``:type <param>:`` field and returns its text with wrapped lines
    joined. The field ends at the next ``:param``, ``:rtype:`` or ``:raises:``
    field, at a closing triple quote, or at the end of the docstring. The
    end-of-docstring case matters because ``__doc__`` never contains the closing
    quotes, so without it the last documented parameter could not be found.
    :param class model: Model class.
    :param str param: The name of the parameter.
    :returns: str
    :raises AttributeError: If the docstring has no ``:type`` field for ``param``.
    """
    # Search for the :type param_name: in the docstring
    pattern = (r":type {}:(.*?)\n(\s*:param |\s*:rtype:|\s*:raises:|\s*\"{{3}}|\s*$)"
               .format(re.escape(param)))
    param_type = re.search(pattern, model.__doc__, re.DOTALL)
    if param_type is None:
        # Same exception type as the previous implicit failure on `None.group`,
        # but with a message that says what was actually missing.
        raise AttributeError(
            "No ':type {}:' entry found in docstring of {!r}".format(param, model))
    return re.sub(r"\n\s*", "", param_type.group(1).strip())
89
90
def find_param_help(model, param):
    """Extract a parameter's help text from a model docstring.

    Returns the text between ``:param <param>:`` and the following ``:type``
    field, trimmed. Each line break, together with the indentation after it,
    is collapsed into a single space.
    :param class model: Model class.
    :param str param: The name of the parameter.
    :returns: str
    """
    docstring = model.__doc__
    match = re.search(r":param {}:(.*?)\n\s*:type ".format(param), docstring, re.DOTALL)
    raw_help = match.group(1).strip()
    return re.sub(r"\n\s*", " ", raw_help)
101
102
def find_return_type(model):
    """Parse the return type from a docstring's ``:rtype:`` field.

    Captures the rest of the ``:rtype:`` line, dropping a trailing `` or``
    (as in ``:rtype: ~X or`` followed by an alternative type on the next line).
    :param class model: Model class (or any object with a docstring).
    :returns: str, or None when there is no newline-terminated ``:rtype:`` line.
    """
    # Search for :rtype: in the docstring
    found = re.search(r':rtype: (.*?)( or)?\n.*(:raises:)?', model.__doc__, re.DOTALL)
    if not found:
        return None
    return re.sub(r"\n\s*", "", found.group(1))
114
115
def enum_value(enum_str):
    """Normalise one entry of a "Possible values include" list.

    Leading/trailing spaces and single quotes are removed and the result is
    lower-cased.
    :param str enum_str: Enum value.
    :returns: str
    """
    trimmed = enum_str.strip(" '")
    return trimmed.lower()
121
122
def class_name(type_str):
    """Return the class reference that follows the first "~" in a docstring type.

    :param str type_str: Parameter type docstring.
    :returns: class name
    :raises IndexError: If ``type_str`` contains no "~".
    """
    matches = _CLASS_NAME.findall(type_str)
    return matches[0]
129
130
def operations_name(class_str):
    """Convert an operations class name into Python (snake) case.

    A trailing "Operations" is dropped first, so "JobScheduleOperations"
    becomes "job_schedule".
    :param str class_str: The class name.
    :returns: str
    """
    suffix = 'Operations'
    base = class_str[:-len(suffix)] if class_str.endswith(suffix) else class_str
    return _UNDERSCORE_CASE.sub(r'_\1', base).lower()
138
139
def full_name(arg_details):
    """Return the dotted path of the request attribute behind an argument.

    For example path 'job.job_manager_task' and root 'id' give
    'job.job_manager_task.id'.
    :param dict arg_details: The details of the argument (its 'path' and 'root').
    :returns: str
    """
    path, root = arg_details['path'], arg_details['root']
    return '.'.join((path, root))
147
148
def group_title(path):
    """Create a group title from the argument path.

    Each dot-separated segment is processed as follows:
    - The suffixes "_patch_parameter", "_update_parameter" and "_parameter" are
      stripped (checked in that order).
    - The segment is split on underscores and each word is title-cased.
    The segments are then joined with ": ", so 'job.job_manager_task' becomes
    'Job: Job Manager Task'.
    :param str path: The complex object path of the argument.
    :returns: str
    """

    def filter_group(group):
        for suffix in ["_patch_parameter", "_update_parameter", "_parameter"]:
            if group.endswith(suffix):
                group = group[:0 - len(suffix)]
        return group

    def title_case(group):
        return " ".join([n.title() for n in group.split('_')])

    # Title-case each segment on its own. Running str.replace over the joined
    # title could match a substring inside a segment that was already converted.
    return ': '.join(title_case(filter_group(group)) for group in path.split('.'))
167
168
def arg_name(name):
    """Convert a snake_case parameter name into its long command line option,
    e.g. 'job_manager_task_id' -> '--job-manager-task-id'.

    :param str name: The argument parameter name.
    :returns: str
    """
    dashed = name.replace('_', '-')
    return '--' + dashed
175
176
def format_options_name(operation):
    """Derive the name of an operation's request options parameter.

    The text after the last '#' in ``operation`` must be '<Class>.<function>'.
    The result is '<class>_<function>_options', with the class part converted
    by ``operations_name``; e.g. 'JobOperations.add' -> 'job_add_options'.
    :param str operation: Operation path
    :returns: str - options parameter name.
    """
    qualified = operation.rsplit('#', 1)[-1]
    op_class, op_function = qualified.split('.')
    return '{}_{}_options'.format(operations_name(op_class), op_function)
187
188
class BatchArgumentTree:
    """Dependency tree parser for arguments of complex objects.

    Pending command line arguments are kept in ``self._arg_tree``, keyed by
    argument name. Each entry records:
    - ``path``: the dotted request-object path the argument belongs to;
    - ``root``: the original attribute name;
    - ``options``: the kwargs for ``CLICommandArgument``;
    - ``type``: the attribute type string;
    - ``dependencies``: the full paths of the attributes at that path that are
      required if any of them is set.
    """

    def __init__(self, validator):
        """
        :param validator: Optional callable, invoked as ``validator(namespace, tree)``
         at the start of ``parse``.
        """
        self._arg_tree = {}
        self._request_param = {}
        self._custom_validator = validator
        self.done = False

    def __iter__(self):
        """Iterate over (argument name, details) pairs of the pending arguments."""
        for arg, details in self._arg_tree.items():
            yield arg, details

    def _is_silent(self, name):
        """Whether the argument should be hidden: its path or its full path is
        listed in ``pformat.SILENT_PARAMETERS``."""
        arg = self._arg_tree[name]
        full_path = full_name(arg)
        return arg['path'] in pformat.SILENT_PARAMETERS or full_path in pformat.SILENT_PARAMETERS

    def _is_bool(self, name):
        """Whether argument value is a boolean"""
        return self._arg_tree[name]['type'] == 'bool'

    def _is_list(self, name):
        """Whether argument value is a list"""
        return self._arg_tree[name]['type'].startswith('[')

    def _is_datetime(self, name):
        """Whether argument value is a timestamp"""
        return self._arg_tree[name]['type'] in ['iso-8601', 'rfc-1123']

    def _is_duration(self, name):
        """Whether argument value is a duration"""
        return self._arg_tree[name]['type'] == 'duration'

    def _help(self, name, text):
        """Append phrase to existing help text"""
        self._arg_tree[name]['options']['help'] += ' ' + text

    def _get_children(self, group):
        """Find all the arguments at or below a specific complex argument group.

        Paths are compared segment by segment, so the group 'pool.start_task'
        does not pick up a sibling path that merely starts with the same text
        (e.g. 'pool.start_task_info').
        :param str group: The namespace of the complex parameter.
        :returns: The names of the related arguments.
        """
        nested_prefix = group + '.'
        return [arg for arg, value in self._arg_tree.items()
                if value['path'] == group or value['path'].startswith(nested_prefix)]

    def _get_siblings(self, group):
        """Find all the arguments directly at a specific complex argument group.

        :param str group: The namespace of the complex parameter.
        :returns: The names of the related arguments.
        """
        return [arg for arg, value in self._arg_tree.items() if value['path'] == group]

    def _parse(self, namespace, path, required):
        """Parse dependency tree to list all required command line arguments based on
        current inputs.

        :param namespace: The namespace containing all current argument inputs
        :param path: The current complex object path
        :param required: Whether the args in this object path are required
        :returns: set of required argument names
        :raises ValueError: If no argument sits directly at ``path``.
        """
        required_args = []
        children = self._get_children(path)
        if not required:
            # An optional object only imposes requirements once one of its args is set.
            if not any(getattr(namespace, n) for n in children):
                return set()
        siblings = self._get_siblings(path)
        if not siblings:
            raise ValueError("Invalid argument dependency tree")  # TODO
        # Only the first sibling's list is consulted; this assumes every argument
        # at one path was queued with the same required-attribute list.
        dependencies = self._arg_tree[siblings[0]]['dependencies']
        for child_arg in children:
            if child_arg in required_args:
                continue
            details = self._arg_tree[child_arg]
            if full_name(details) in dependencies:
                required_args.append(child_arg)
            elif details['path'] in dependencies:
                # A required nested object: its own requirements apply unconditionally.
                required_args.extend(self._parse(namespace, details['path'], True))
            elif details['path'] == path:
                continue
            else:
                # An optional nested object: its requirements apply only if it is used.
                required_args.extend(self._parse(namespace, details['path'], False))
        return set(required_args)

    def set_request_param(self, name, model):
        """Set the name of the parameter that will be serialized for the
        request body.

        :param str name: The name of the parameter
        :param str model: The name of the class (only the part after the last '.' is kept)
        """
        self._request_param['name'] = name
        self._request_param['model'] = model.split('.')[-1]

    def deserialize_json(self, kwargs, json_obj):
        """Deserialize the contents of a JSON file into the request body
        parameter.

        :param dict kwargs: The request kwargs
        :param dict json_obj: The loaded JSON content
        :raises ValueError: If deserialization fails or produces no object.
        """
        from msrest.exceptions import DeserializationError
        model_name = self._request_param['model']
        message = "Failed to deserialize JSON file into object {}".format(model_name)
        try:
            import azure.batch.models
            model_type = getattr(azure.batch.models, model_name)
            # Use from_dict in order to deserialize with case insensitive
            kwargs[self._request_param['name']] = model_type.from_dict(json_obj)
        except DeserializationError as error:
            # Append the error text only after formatting: it may contain braces
            # that str.format would otherwise try to interpret.
            raise ValueError("{}: {}".format(message, error))
        else:
            if kwargs[self._request_param['name']] is None:
                raise ValueError(message)

    def queue_argument(self, name=None, path=None, root=None,
                       options=None, type=None,  # pylint: disable=redefined-builtin
                       dependencies=None):
        """Add pending command line argument

        :param str name: The name of the command line argument.
        :param str path: The complex object path to the parameter.
        :param str root: The original name of the parameter.
        :param dict options: The kwargs to be used to instantiate CLICommandArgument.
        :param str type: The attribute type string.
        :param list dependencies: Names of the attributes at ``path`` that are
         required if this parameter is set; stored as '<path>.<name>'. None means
         no dependencies.
        """
        self._arg_tree[name] = {
            'path': path,
            'root': root,
            'options': options,
            'type': type,
            'dependencies': [".".join([path, arg]) for arg in (dependencies or [])]
        }

    def dequeue_argument(self, name):
        """Remove pending command line argument for modification

        :param str name: The command line argument to remove.
        :returns: The details of the argument, or an empty dict if it is not queued.
        """
        return self._arg_tree.pop(name, {})

    def compile_args(self):
        """Generator to convert pending arguments into CLICommandArgument
        objects.

        Options are adjusted by type before conversion (the branches are exclusive):
        - booleans: a store_true flag, except for patch requests, which get
          explanatory help instead;
        - lists: accept one or more values;
        - timestamps and durations: parsed by the matching validator;
        - silent arguments: hidden and ignored.
        NOTE(review): because the branches are exclusive, a silent argument of
        one of the earlier types is not hidden.
        """
        for name, details in self._arg_tree.items():
            if self._is_bool(name):
                if self._request_param['name'].endswith('patch_parameter'):
                    self._help(name, "Specify either 'true' or 'false' to update the property.")
                else:
                    details['options']['action'] = 'store_true'
                    self._help(name, "True if flag present.")
            elif self._is_list(name):
                details['options']['nargs'] = '+'
            elif self._is_datetime(name):
                details['options']['type'] = validators.datetime_format
                self._help(name, "Expected format is an ISO-8601 timestamp.")
            elif self._is_duration(name):
                details['options']['type'] = validators.duration_format
                self._help(name, "Expected format is an ISO-8601 duration.")
            elif self._is_silent(name):
                import argparse
                details['options']['nargs'] = '?'
                details['options']['help'] = argparse.SUPPRESS
                details['options']['required'] = False
                details['options']['action'] = IgnoreAction
            yield (name, CLICommandArgument(dest=name, **details['options']))

    def existing(self, name):
        """Whether the argument name is already used by a pending
        argument.

        :param str name: The name of the argument to check.
        :returns: bool
        """
        return name in self._arg_tree

    def parse_mutually_exclusive(self, namespace, required, params):
        """Validate whether two or more mutually exclusive arguments or
        argument groups have been set correctly.

        :param namespace: The namespace containing all current argument inputs.
        :param bool required: Whether one of the parameters must be set.
        :param list params: List of namespace paths for mutually exclusive
         request properties.
        :raises ValueError: If none is set while required, or more than one is set.
        """
        argtree = self._arg_tree.items()
        ex_arg_names = [a for a, v in argtree if full_name(v) in params]
        ex_args = [getattr(namespace, a) for a, v in argtree if a in ex_arg_names]
        ex_args = [x for x in ex_args if x is not None]
        ex_group_names = []
        ex_groups = []
        for arg_group in params:
            child_args = self._get_children(arg_group)
            if child_args:
                ex_group_names.append(group_title(arg_group))
                if any(getattr(namespace, arg) for arg in child_args):
                    ex_groups.append(ex_group_names[-1])

        message = None
        if not ex_groups and not ex_args and required:
            message = "One of the following arguments, or argument groups are required: \n"
        elif len(ex_groups) > 1 or len(ex_args) > 1 or (ex_groups and ex_args):
            message = ("The following arguments or argument groups are mutually "
                       "exclusive and cannot be combined: \n")
        if message:
            missing = [arg_name(n) for n in ex_arg_names] + ex_group_names
            message += '\n'.join(missing)
            raise ValueError(message)

    def parse(self, namespace):
        """Parse all arguments in the namespace to validate whether all required
        arguments have been set.

        The custom validator runs first. If the namespace has a truthy
        ``json_file``, the file is loaded into ``namespace.json_file`` and no other
        pending argument may be set. Otherwise every argument required by the
        dependency tree must be set.
        :param namespace: The namespace object.
        :raises: ValueError if a required argument was not set, the custom
         validator has the wrong signature, or the JSON file cannot be used.
        """
        if self._custom_validator:
            try:
                self._custom_validator(namespace, self)
            except TypeError:
                raise ValueError("Custom validator must be a function that takes two arguments.")
        # The namespace may have no json_file attribute at all. A default here
        # replaces a broad `except AttributeError: pass` around this whole branch,
        # which also hid unrelated AttributeErrors.
        json_file = getattr(namespace, 'json_file', None)
        if json_file:
            try:
                namespace.json_file = get_file_json(json_file)
            except EnvironmentError:
                raise ValueError("Cannot access JSON request file: " + json_file)
            except ValueError as err:
                raise ValueError("Invalid JSON file: {}".format(err))
            other_values = [arg_name(n) for n in self._arg_tree if getattr(namespace, n)]
            if other_values:
                message = "--json-file cannot be combined with:\n"
                raise ValueError(message + '\n'.join(other_values))
            self.done = True
            return
        required_args = self._parse(namespace, self._request_param['name'], True)
        missing_args = [n for n in required_args if not getattr(namespace, n)]
        if missing_args:
            message = "The following additional arguments are required:\n"
            message += "\n".join([arg_name(m) for m in missing_args])
            raise ValueError(message)
        self.done = True
428
429
430class AzureBatchDataPlaneCommand:
431    # pylint: disable=too-many-instance-attributes, too-few-public-methods, too-many-statements
    def __init__(self, operation, command_loader, client_factory=None, validator=None, **kwargs):
        """Wrap a Batch SDK operation as a CLI command.

        :param str operation: Path of the SDK operation; the text after the last '#'
         must be '<Class>.<function>' (it is passed to format_options_name).
        :param command_loader: Object whose ``get_op_handler(operation)`` returns the
         SDK function; called lazily, on first use.
        :param client_factory: Called as ``client_factory(cli_ctx, kwargs)`` on each
         invocation to create the client passed to the SDK function.
        :param validator: Optional custom validator; stored on ``self.validator``.
        :param kwargs: ``flatten`` (number of object levels to flatten, default
         ``pformat.FLATTEN``) is consumed here; the rest is kept in ``merged_kwargs``
         and merged into ``get_kwargs()``.
        :raises ValueError: If ``operation`` is not a string.
        """

        if not isinstance(operation, string_types):
            raise ValueError("Operation must be a string. Got '{}'".format(operation))

        self._flatten = kwargs.pop('flatten', pformat.FLATTEN)  # Number of object levels to flatten
        # When True the request is sent with raw=True and response headers are returned.
        # NOTE(review): only initialised here; presumably set elsewhere for HEAD operations.
        self._head_cmd = False

        # The BatchArgumentTree used by the handler; not set in this method
        # (presumably assigned while loading arguments — verify).
        self.parser = None
        self.validator = validator
        self.client_factory = client_factory
        # Any operation path containing 'delete' asks for confirmation.
        self.confirmation = 'delete' in operation
        # Cache for the resolved SDK function (see _get_operation).
        self._operation_func = None

        # The name of the request options parameter
        self._options_param = format_options_name(operation)
        # Arguments used for request options
        self._options_attrs = []
        # The loaded options model to populate for the request
        self._options_model = None

        def _get_operation():
            # Resolve the SDK function on first use and cache it on the instance.
            if not self._operation_func:
                self._operation_func = command_loader.get_op_handler(operation)

            return self._operation_func

        def _load_arguments():
            return self._load_transformed_arguments(_get_operation())

        def _load_descriptions():
            return extract_full_summary_from_signature(_get_operation())

        # pylint: disable=inconsistent-return-statements
        def _execute_command(kwargs):
            """Run the SDK operation with arguments parsed from the command line.

            Returns the response headers for head commands, None after a file
            download, a list for paged results, and the raw result otherwise.
            Batch, validation and client-request errors are re-raised as CLIError.
            """
            from msrest.paging import Paged
            from msrest.exceptions import ValidationError, ClientRequestError
            from azure.batch.models import BatchErrorException
            from knack.util import CLIError
            cmd = kwargs.pop('cmd')

            try:
                client = self.client_factory(cmd.cli_ctx, kwargs)
                # Moves the request-option arguments out of kwargs into the options model.
                self._build_options(kwargs)

                stream_output = kwargs.pop('destination', None)
                # NOTE(review): by now this presumably holds the JSON already loaded by
                # BatchArgumentTree.parse rather than a file path — confirm.
                json_file = kwargs.pop('json_file', None)

                # Build the request parameters from command line arguments
                if json_file:
                    self.parser.deserialize_json(kwargs, json_file)
                    # The flattened per-attribute arguments are unused with a JSON body.
                    for arg, _ in self.parser:
                        del kwargs[arg]
                else:
                    for arg, details in self.parser:
                        try:
                            param_value = kwargs.pop(arg)
                            if param_value is None:
                                continue
                            self._build_parameters(
                                details['path'],
                                kwargs,
                                details['root'],
                                param_value)
                        # Skips args absent from kwargs (note: also any KeyError
                        # raised inside _build_parameters).
                        except KeyError:
                            continue

                # Make request
                if self._head_cmd:
                    kwargs['raw'] = True
                result = _get_operation()(client, **kwargs)

                # Head output
                if self._head_cmd:
                    return transformers.transform_response_headers(result)

                # File download
                if stream_output:
                    with open(stream_output, "wb") as file_handle:
                        for data in result:
                            file_handle.write(data)
                    return

                # Otherwise handle based on return type of results
                if isinstance(result, Paged):
                    return list(result)

                return result
            except BatchErrorException as ex:
                try:
                    # Build a readable message from the service error and its details;
                    # fall back to the exception itself if the error body lacks them.
                    message = ex.error.message.value
                    if ex.error.values:
                        for detail in ex.error.values:
                            message += "\n{}: {}".format(detail.key, detail.value)
                    raise CLIError(message)
                except AttributeError:
                    raise CLIError(ex)
            except (ValidationError, ClientRequestError) as ex:
                raise CLIError(ex)

        # Use '<function>_table_format' from the _format module as the table
        # transformer when it exists.
        self.table_transformer = None
        try:
            transform_func = operation.split('.')[-1].replace('-', '_')
            self.table_transformer = getattr(transformers, transform_func + "_table_format")
        except AttributeError:
            pass

        self.handler = _execute_command
        self.argument_loader = _load_arguments
        self.description_loader = _load_descriptions
        self.merged_kwargs = kwargs
543
544    def get_kwargs(self):
545        args = {
546            'handler': self.handler,
547            'argument_loader': self.argument_loader,
548            'description_loader': self.description_loader,
549            'table_transformer': self.table_transformer,
550            'confirmation': self.confirmation,
551            'client_factory': self.client_factory
552        }
553        args.update(self.merged_kwargs)
554        return args
555
556    def _build_parameters(self, path, kwargs, param, value):
557        """Recursively build request parameter dictionary from command line args.
558        :param str path: Current parameter namespace.
559        :param dict kwargs: The request arguments being built.
560        :param param: The name of the request parameter.
561        :param value: The value of the request parameter.
562        """
563        keys = path.split('.')
564        if keys[0] not in kwargs:
565            kwargs[keys[0]] = {}
566        if len(keys) < 2:
567            kwargs[keys[0]][param] = value
568        else:
569            self._build_parameters('.'.join(keys[1:]), kwargs[keys[0]], param, value)
570
571        path = param[0]
572        return path.split('.')[0]
573
574    def _build_options(self, kwargs):
575        """Build request options model from command line arguments.
576        :param dict kwargs: The request arguments being built.
577        """
578        kwargs[self._options_param] = self._options_model
579        for param in self._options_attrs:
580            if param in pformat.IGNORE_OPTIONS:
581                continue
582            param_value = kwargs.pop(param)
583            if param_value is None:
584                continue
585            setattr(kwargs[self._options_param], param, param_value)
586
587    def _load_options_model(self, func_obj):
588        """Load the request headers options model to gather arguments.
589        :param func func_obj: The request function.
590        """
591        option_type = find_param_type(func_obj, self._options_param)
592        option_type = class_name(option_type)
593        self._options_model = _load_model(option_type)()
594        self._options_attrs = list(self._options_model.__dict__.keys())
595
596    def _should_flatten(self, param):
597        """Check whether the current parameter object should be flattened.
598        :param str param: The parameter name with complete namespace.
599        :returns: bool
600        """
601        return param.count('.') < self._flatten and param not in pformat.IGNORE_PARAMETERS
602
603    def _get_attrs(self, model, path):
604        """Get all the attributes from the complex parameter model that should
605        be exposed as command line arguments.
606        :param class model: The parameter model class.
607        :param str path: Request parameter namespace.
608        """
609        for attr, details in model._attribute_map.items():  # pylint: disable=protected-access
610            conditions = []
611            full_path = '.'.join([self.parser._request_param['name'], path, attr])  # pylint: disable=protected-access
612            conditions.append(
613                model._validation.get(attr, {}).get('readonly'))  # pylint: disable=protected-access
614            conditions.append(
615                model._validation.get(attr, {}).get('constant'))  # pylint: disable=protected-access
616            conditions.append(any(i for i in pformat.IGNORE_PARAMETERS if i in full_path))
617            conditions.append(details['type'][0] in ['{'])
618            if not any(conditions):
619                yield attr, details
620
621    def _process_options(self):
622        """Process the request options parameter to expose as arguments."""
623        for param in [o for o in self._options_attrs if o not in pformat.IGNORE_OPTIONS]:
624            options = {}
625            options['required'] = False
626            options['arg_group'] = 'Pre-condition and Query'
627            if param in ['if_modified_since', 'if_unmodified_since']:
628                options['type'] = validators.datetime_format
629            if param in pformat.FLATTEN_OPTIONS:
630                for f_param, f_docstring in pformat.FLATTEN_OPTIONS[param].items():
631                    options['default'] = None
632                    options['help'] = f_docstring
633                    options['options_list'] = [arg_name(f_param)]
634                    options['validator'] = validators.validate_options
635                    yield (f_param, CLICommandArgument(f_param, **options))
636            else:
637                options['default'] = getattr(self._options_model, param)
638                options['help'] = find_param_help(self._options_model, param)
639                options['options_list'] = [arg_name(param)]
640                yield (param, CLICommandArgument(param, **options))
641
642    def _resolve_conflict(self,
643                          arg, param, path, options, typestr, dependencies, conflicting):
644        """Resolve conflicting command line arguments.
645        :param str arg: Name of the command line argument.
646        :param str param: Original request parameter name.
647        :param str path: Request parameter namespace.
648        :param dict options: The kwargs to be used to instantiate CLICommandArgument.
649        :param list dependencies: A list of complete paths to other parameters that are required
650         if this parameter is set.
651        :param list conflicting: A list of the argument names that have already conflicted.
652        """
653        if self.parser.existing(arg):
654            conflicting.append(arg)
655            existing = self.parser.dequeue_argument(arg)
656            existing['name'] = _build_prefix(arg, existing['root'], existing['path'])
657            existing['options']['options_list'] = [arg_name(existing['name'])]
658            self.parser.queue_argument(**existing)
659            new = _build_prefix(arg, param, path)
660            options['options_list'] = [arg_name(new)]
661            self._resolve_conflict(new, param, path, options, typestr, dependencies, conflicting)
662        elif arg in conflicting or arg in pformat.QUALIFIED_PROPERTIES:
663            new = _build_prefix(arg, param, path)
664            if new in conflicting or new in pformat.QUALIFIED_PROPERTIES and '.' not in path:
665                self.parser.queue_argument(arg, path, param, options, typestr, dependencies)
666            else:
667                options['options_list'] = [arg_name(new)]
668                self._resolve_conflict(new, param, path, options,
669                                       typestr, dependencies, conflicting)
670        else:
671            self.parser.queue_argument(arg, path, param, options, typestr, dependencies)
672
    def _flatten_object(self, path, param_model, conflict_names=None):
        """Flatten a complex parameter object into command line arguments.

        Does nothing unless ``_should_flatten(path)`` holds. Each attribute
        yielded by ``_get_attrs`` is handled as follows:
        - basic types: queued through ``_resolve_conflict``;
        - lists of basic types: queued as space-separated values;
        - lists of other types: queued only when ``validators`` defines a
          matching '<type>_format' function;
        - enums (models without an ``_attribute_map``): queued with choices
          parsed from the help text;
        - any other model type: flattened recursively one level deeper.

        :param str path: The complex parameter namespace.
        :param class param_model: The complex parameter class.
        :param list conflict_names: List of argument names that conflict.
        """
        conflict_names = conflict_names or []

        if self._should_flatten(path):
            # Attributes the model's validation marks as required; passed as the
            # dependencies of every argument queued at this path.
            validations = param_model._validation.items()  # pylint: disable=protected-access
            required_attrs = [key for key, val in validations if val.get('required')]

            for param_attr, details in self._get_attrs(param_model, path):
                options = {}
                options['options_list'] = [arg_name(param_attr)]
                options['required'] = False
                options['arg_group'] = group_title(path)
                options['help'] = find_param_help(param_model, param_attr)
                # The lambda only captures self, so late binding in the loop is harmless.
                options['validator'] = \
                    lambda ns: validators.validate_required_parameter(ns, self.parser)
                options['default'] = None  # Extract details from signature

                if details['type'] in pformat.BASIC_TYPES:
                    self._resolve_conflict(param_attr, param_attr, path, options,
                                           details['type'], required_attrs, conflict_names)
                elif details['type'].startswith('['):
                    # We only expose a list arg if there's a validator for it
                    # This will fail for 2D arrays - though Batch doesn't have any yet
                    inner_type = details['type'][1:-1]
                    if inner_type in pformat.BASIC_TYPES:
                        options['help'] += " Space-separated values."
                        self._resolve_conflict(
                            param_attr, param_attr, path, options,
                            details['type'], required_attrs, conflict_names)
                    else:
                        inner_type = operations_name(inner_type)
                        try:
                            validator = getattr(validators, inner_type + "_format")
                            options['help'] += ' ' + validator.__doc__
                            options['type'] = validator
                            self._resolve_conflict(
                                param_attr, param_attr, path, options,
                                details['type'], required_attrs, conflict_names)
                        # NOTE(review): this also swallows AttributeErrors raised inside
                        # _resolve_conflict, not only a missing '<type>_format' validator.
                        except AttributeError:
                            continue
                else:
                    attr_model = _load_model(details['type'])
                    if not hasattr(attr_model, '_attribute_map'):  # Must be an enum
                        # 25 == len(' Possible values include:'); the slice starts just
                        # after the colon and the values are split on ', '.
                        values_index = options['help'].find(' Possible values include')
                        if values_index >= 0:
                            choices = options['help'][values_index + 25:].split(', ')
                            options['choices'] = [enum_value(c)
                                                  for c in choices if enum_value(c) != "unmapped"]
                            options['help'] = options['help'][0:values_index]
                        self._resolve_conflict(param_attr, param_attr, path, options,
                                               details['type'], required_attrs, conflict_names)
                    else:
                        # NOTE(review): conflict_names is not passed down, so each nested
                        # level starts with an empty conflict list.
                        self._flatten_object('.'.join([path, param_attr]), attr_model)
731
732    def _load_transformed_arguments(self, handler):
733        """Load all the command line arguments from the request parameters.
734        :param func handler: The operation function.
735        """
736        from azure.cli.core.commands.parameters import file_type
737        from argcomplete.completers import FilesCompleter, DirectoriesCompleter
738        self.parser = BatchArgumentTree(self.validator)
739        self._load_options_model(handler)
740        args = []
741        for arg in extract_args_from_signature(handler, excluded_params=EXCLUDED_PARAMS):
742            arg_type = find_param_type(handler, arg[0])
743            if arg[0] == self._options_param:
744                for option_arg in self._process_options():
745                    args.append(option_arg)
746            elif arg_type.startswith("str or"):
747                docstring = find_param_help(handler, arg[0])
748                choices = []
749                values_index = docstring.find(' Possible values include')
750                if values_index >= 0:
751                    choices = docstring[values_index + 25:].split(', ')
752                    choices = [enum_value(c) for c in choices if enum_value(c) != "'unmapped'"]
753                    docstring = docstring[0:values_index]
754                args.append(((arg[0], CLICommandArgument(arg[0],
755                                                         options_list=[arg_name(arg[0])],
756                                                         required=False,
757                                                         default=None,
758                                                         choices=choices,
759                                                         help=docstring))))
760            elif arg_type.startswith("~"):  # TODO: could add handling for enums
761                param_type = class_name(arg_type)
762                self.parser.set_request_param(arg[0], param_type)
763                param_model = _load_model(param_type)
764                self._flatten_object(arg[0], param_model)
765                for flattened_arg in self.parser.compile_args():
766                    args.append(flattened_arg)
767                param = 'json_file'
768                docstring = "A file containing the {} specification in JSON " \
769                            "(formatted to match the respective REST API body). " \
770                            "If this parameter is specified, all '{} Arguments'" \
771                            " are ignored.".format(arg[0].replace('_', ' '), group_title(arg[0]))
772                args.append((param, CLICommandArgument(param,
773                                                       options_list=[arg_name(param)],
774                                                       required=False,
775                                                       default=None,
776                                                       type=file_type,
777                                                       completer=FilesCompleter(),
778                                                       help=docstring)))
779            elif arg[0] not in pformat.IGNORE_PARAMETERS:
780                args.append(arg)
781        return_type = find_return_type(handler)
782        if return_type and return_type.startswith('Generator'):
783            param = 'destination'
784            docstring = "The path to the destination file or directory."
785            args.append((param, CLICommandArgument(param,
786                                                   options_list=[arg_name(param)],
787                                                   required=True,
788                                                   default=None,
789                                                   completer=DirectoriesCompleter(),
790                                                   type=file_type,
791                                                   validator=validators.validate_file_destination,
792                                                   help=docstring)))
793        if return_type == 'None' and handler.__name__.startswith('get'):
794            self._head_cmd = True
795        if self.confirmation:
796            param = CONFIRM_PARAM_NAME
797            docstring = 'Do not prompt for confirmation.'
798            args.append((param, CLICommandArgument(param,
799                                                   options_list=['--yes', '-y'],
800                                                   required=False,
801                                                   action='store_true',
802                                                   help=docstring)))
803        auth_group_name = 'Batch Account'
804        args.append(('cmd', CLICommandArgument('cmd', action=IgnoreAction)))
805        args.append(('account_name', CLICommandArgument(
806            'account_name', options_list=['--account-name'], required=False, default=None,
807            validator=validators.validate_client_parameters, arg_group=auth_group_name,
808            help='Batch account name. Alternatively, set by environment variable: AZURE_BATCH_ACCOUNT')))
809        args.append(('account_key', CLICommandArgument(
810            'account_key', options_list=['--account-key'], required=False, default=None, arg_group=auth_group_name,
811            help='Batch account key. Alternatively, set by environment variable: AZURE_BATCH_ACCESS_KEY')))
812        args.append(('account_endpoint', CLICommandArgument(
813            'account_endpoint', options_list=['--account-endpoint'], required=False,
814            default=None, arg_group=auth_group_name,
815            help='Batch service endpoint. Alternatively, set by environment variable: AZURE_BATCH_ENDPOINT')))
816        return args
817
818
class BatchCommandGroup(AzCommandGroup):
    """Command group that registers commands backed by AzureBatchDataPlaneCommand."""

    def batch_command(self, name, method_name=None, command_type=None, **kwargs):
        """Register a Batch data-plane command under this group.

        Settings are layered as follows:
        1. the group's kwargs;
        2. the ``settings`` of the explicit ``command_type``, or, failing that, of
           the group's own ``command_type``;
        3. the per-call ``kwargs``.

        :param str name: Command name relative to the group.
        :param str method_name: Substituted into ``operations_tmpl``, when present.
        :param command_type: Optional command type whose ``settings`` override the group's.
        """
        self._check_stale()
        settings = self.group_kwargs.copy()

        # The explicit command type wins; otherwise fall back to the group's one.
        effective_type = command_type or settings.get('command_type', None)
        if effective_type:
            settings.update(effective_type.settings)
        settings.update(kwargs)

        template = settings.get('operations_tmpl')
        operation = template.format(method_name) if template else None
        if self.group_name:
            full_name = '{} {}'.format(self.group_name, name)
        else:
            full_name = name

        batch_cmd = AzureBatchDataPlaneCommand(operation, self.command_loader, **settings)
        self.command_loader._cli_command(full_name, **batch_cmd.get_kwargs())  # pylint: disable=protected-access
837