1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6
7import argparse
8import platform
9
10from azure.cli.core import EXCLUDED_PARAMS
11from azure.cli.core.commands.constants import CLI_PARAM_KWARGS, CLI_POSITIONAL_PARAM_KWARGS
12from azure.cli.core.commands.validators import validate_tag, validate_tags, generate_deployment_name
13from azure.cli.core.decorators import Completer
14from azure.cli.core.profiles import ResourceType
15from azure.cli.core.local_context import LocalContextAttribute, LocalContextAction, ALL
16
17from knack.arguments import (
18    CLIArgumentType, CaseInsensitiveList, ignore_type, ArgumentsContext)
19from knack.log import get_logger
20from knack.util import CLIError
21
22logger = get_logger(__name__)
23
24
def get_subscription_locations(cli_ctx):
    """ Return the locations reported by ``subscriptions.list_locations`` as a list.

    :param cli_ctx: CLI context used to build the subscription service client.
    :return: list of location objects for the client's subscription id.
    """
    from azure.cli.core.commands.client_factory import get_subscription_service_client
    client, sub_id = get_subscription_service_client(cli_ctx)
    locations = client.subscriptions.list_locations(sub_id)
    return list(locations)
29
30
@Completer
def get_location_completion_list(cmd, prefix, namespace, **kwargs):  # pylint: disable=unused-argument
    """ Tab-completion source: the names of all locations available to the subscription. """
    locations = get_subscription_locations(cmd.cli_ctx)
    return [loc.name for loc in locations]
35
36
37# pylint: disable=redefined-builtin
def get_datetime_type(help=None, date=True, time=True, timezone=True):
    """ Build an argument type that parses a free-form date/time into an ISO 8601 string.

    :param help: optional help text; a description of the accepted formats is appended to it.
    :param date: whether the date part of the value is used. When False, a warning is logged
        if the input nevertheless specifies a date component.
    :param time: whether the time part of the value is used. When False, a warning is logged
        if the parsed time is non-zero.
    :param timezone: whether timezone info is used. When True, naive values get the local
        timezone attached; when False, a warning is logged if the value carries a timezone.
    :return: a CLIArgumentType with ``nargs='+'`` whose action stores ``isoformat()`` of the
        parsed value on the namespace.
    """

    help_string = help + ' ' if help else ''
    accepted_formats = []
    if date:
        accepted_formats.append('date (yyyy-mm-dd)')
    if time:
        accepted_formats.append('time (hh:mm:ss.xxxxx)')
    if timezone:
        accepted_formats.append('timezone (+/-hh:mm)')
    help_string = help_string + 'Format: ' + ' '.join(accepted_formats)

    def _date_was_supplied(value_string):
        """ Return True if ``value_string`` itself specifies any date component.

        dateutil fills fields missing from the input from a default datetime (today at midnight
        when none is given), so the parsed year/month/day are always set and cannot tell whether
        the user supplied them. Parsing against two defaults that differ in every date field can:
        any field that comes out equal in both results must have come from the input.
        """
        import datetime
        import dateutil.parser

        try:
            first = dateutil.parser.parse(value_string, default=datetime.datetime(2000, 1, 1))
            second = dateutil.parser.parse(value_string, default=datetime.datetime(2004, 3, 3))
        except (ValueError, OverflowError):
            # be conservative: keep warning if the value cannot be re-parsed with these defaults
            return True
        return first.year == second.year or first.month == second.month or first.day == second.day

    # pylint: disable=too-few-public-methods
    class DatetimeAction(argparse.Action):

        def __call__(self, parser, namespace, values, option_string=None):
            """ Parse a date value and return the ISO8601 string. """
            import dateutil.parser
            import dateutil.tz

            value_string = ' '.join(values)
            dt_val = None
            try:
                # attempt to parse ISO 8601
                dt_val = dateutil.parser.parse(value_string)
            except (ValueError, OverflowError):
                # dateutil reports unparseable text with ValueError (ParserError) and out-of-range
                # numeric fields with OverflowError; both are turned into the CLIError below
                pass

            # TODO: custom parsing attempts here
            if not dt_val:
                raise CLIError("Unable to parse: '{}'. Expected format: {}".format(value_string, help_string))

            if not dt_val.tzinfo and timezone:
                # naive input is interpreted in the local timezone
                dt_val = dt_val.replace(tzinfo=dateutil.tz.tzlocal())

            # Issue warning if any supplied data will be ignored.
            # The parsed year/month/day are always populated (from today when absent), so check
            # whether the input itself carried a date instead of testing those fields.
            if not date and _date_was_supplied(value_string):
                logger.warning('Date info will be ignored in %s.', value_string)

            if not time and any([dt_val.hour, dt_val.minute, dt_val.second, dt_val.microsecond]):
                logger.warning('Time info will be ignored in %s.', value_string)

            if not timezone and dt_val.tzinfo:
                logger.warning('Timezone info will be ignored in %s.', value_string)

            iso_string = dt_val.isoformat()
            setattr(namespace, self.dest, iso_string)

    return CLIArgumentType(action=DatetimeAction, nargs='+', help=help_string)
87
88
def file_type(path):
    """ argparse ``type`` callable: expand a leading ``~`` / ``~user`` in ``path``. """
    from os.path import expanduser
    return expanduser(path)
92
93
def get_location_name_type(cli_ctx):
    """ Build an argparse ``type`` callable that normalizes location names.

    A value containing a space is treated as a display name and replaced by the ``name`` of the
    first subscription location whose ``display_name`` matches it case-insensitively. Values
    without a space, and display names with no match, are returned unchanged.
    """
    def location_name_type(name):
        if ' ' not in name:
            return name
        wanted = name.lower()
        for location in get_subscription_locations(cli_ctx):
            if location.display_name.lower() == wanted:
                return location.name
        return name
    return location_name_type
102
103
def get_one_of_subscription_locations(cli_ctx):
    """ Pick one location name from the current subscription, preferring 'westus'.

    :return: the name of the location matching 'westus' case-insensitively if there is one,
        otherwise the name of the first location returned.
    :raises CLIError: if no locations are returned.
    """
    locations = get_subscription_locations(cli_ctx)
    if not locations:
        raise CLIError('Current subscription does not have valid location list')
    for location in locations:
        if location.name.lower() == 'westus':
            return location.name
    return locations[0].name
109
110
def get_resource_groups(cli_ctx):
    """ Return all resource groups listed by the resource management client, as a list. """
    from azure.cli.core.commands.client_factory import get_mgmt_service_client
    resource_client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES)
    groups = resource_client.resource_groups.list()
    return list(groups)
115
116
@Completer
def get_resource_group_completion_list(cmd, prefix, namespace, **kwargs):  # pylint: disable=unused-argument
    """ Tab-completion source: the names of all resource groups. """
    groups = get_resource_groups(cmd.cli_ctx)
    return [group.name for group in groups]
121
122
def get_resources_in_resource_group(cli_ctx, resource_group_name, resource_type=None):
    """ List the resources in one resource group, optionally restricted to a resource type.

    :param resource_group_name: the group whose resources are listed.
    :param resource_type: when truthy, adds a ``resourceType eq '<type>'`` filter.
    :return: list of resource objects.
    """
    from azure.cli.core.commands.client_factory import get_mgmt_service_client
    from azure.cli.core.profiles import supported_api_version

    client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES)
    odata_filter = None
    if resource_type:
        odata_filter = "resourceType eq '{}'".format(resource_type)

    # presumably true when the active resources API version is no newer than 2016-09-01;
    # that version exposes the listing on resource_groups, later ones on resources
    legacy_api = supported_api_version(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, max_api='2016-09-01')
    lister = client.resource_groups.list_resources if legacy_api else client.resources.list_by_resource_group
    return list(lister(resource_group_name, filter=odata_filter))
132
133
def get_resources_in_subscription(cli_ctx, resource_type=None):
    """ List resources via ``resources.list``, optionally restricted to a resource type.

    :param resource_type: when truthy, adds a ``resourceType eq '<type>'`` filter.
    :return: list of resource objects.
    """
    from azure.cli.core.commands.client_factory import get_mgmt_service_client
    client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES)
    odata_filter = None
    if resource_type:
        odata_filter = "resourceType eq '{}'".format(resource_type)
    return list(client.resources.list(filter=odata_filter))
139
140
def get_resource_name_completion_list(resource_type=None):
    """ Build a completer that lists resource names, optionally of a single resource type.

    If the namespace already holds a ``resource_group_name``, only resources in that group
    are offered; otherwise the names come from ``get_resources_in_subscription``.
    """

    @Completer
    def completer(cmd, prefix, namespace, **kwargs):  # pylint: disable=unused-argument
        resource_group = getattr(namespace, 'resource_group_name', None)
        if not resource_group:
            resources = get_resources_in_subscription(cmd.cli_ctx, resource_type)
        else:
            resources = get_resources_in_resource_group(cmd.cli_ctx, resource_group, resource_type=resource_type)
        return [res.name for res in resources]

    return completer
151
152
def get_generic_completion_list(generic_list):
    """ Build a completer that always offers ``generic_list``, regardless of what was typed. """

    @Completer
    def completer(cmd, prefix, namespace, **kwargs):  # pylint: disable=unused-argument
        return generic_list

    return completer
159
160
def get_three_state_flag(positive_label='true', negative_label='false', invert=False, return_label=False):
    """ Creates an argument usable both as a bare flag and with an explicit positive/negative value.

    This keeps create commands (which typically take flags) consistent with update commands
    (which need explicit positive/negative values) without breaking changes. Given as a bare
    flag, the argument means the positive label; ``invert=True`` flips the resulting logic.
    - positive_label: label for the positive value (ex: 'enabled')
    - negative_label: label for the negative value (ex: 'disabled')
    - invert: invert the boolean logic for the flag
    - return_label: if true, store the corresponding label instead of a boolean
    """

    # pylint: disable=too-few-public-methods
    class ThreeStateAction(argparse.Action):

        def __call__(self, parser, namespace, values, option_string=None):
            # a bare flag (no value, or an empty one) counts as the positive label
            supplied = positive_label if not values else values
            matched_positive = supplied.lower() == positive_label.lower()
            if invert:
                matched_positive = not matched_positive

            if return_label:
                result = positive_label if matched_positive else negative_label
            else:
                result = matched_positive
            setattr(namespace, self.dest, result)

    return CLIArgumentType(
        choices=CaseInsensitiveList([positive_label, negative_label]),
        nargs='?',
        action=ThreeStateAction)
193
194
def get_enum_type(data, default=None):
    """ Creates the argparse choices and type kwargs for a supplied enum type or list of strings.

    :param data: an Enum class (the members' ``value`` attributes become the choices) or an
        iterable of string choices.
    :param default: optional default; matched case-insensitively against the choices and
        stored using the matching choice's own casing.
    :return: a CLIArgumentType with case-insensitive choices, or None when ``data`` is empty.
    :raises CLIError: if ``default`` matches none of the choices.
    """
    if not data:
        return None

    # materialize once so one-shot iterables are not partially consumed by the probe below
    members = list(data)

    # transform enum types, otherwise assume list of string choices
    try:
        choices = [x.value for x in members]
    except AttributeError:
        choices = members

    # pylint: disable=too-few-public-methods
    class DefaultAction(argparse.Action):

        def __call__(self, parser, args, values, option_string=None):

            def _get_value(val):
                # map user input onto the casing of the matching choice; unknown values pass through
                return next((x for x in self.choices if x.lower() == val.lower()), val)

            if isinstance(values, list):
                values = [_get_value(v) for v in values]
            else:
                values = _get_value(values)
            setattr(args, self.dest, values)

    arg_kwargs = {'choices': CaseInsensitiveList(choices), 'action': DefaultAction}
    if default:
        default_value = next((x for x in choices if x.lower() == default.lower()), None)
        if not default_value:
            raise CLIError("Command authoring exception: unrecognized default '{}' from choices '{}'"
                           .format(default, choices))
        arg_kwargs['default'] = default_value
    return CLIArgumentType(**arg_kwargs)
233
234
235# GLOBAL ARGUMENT DEFINITIONS
236
# --resource-group/-g: defaults from the 'group' config entry and is registered with
# local-context SET and GET actions for ALL scopes.
resource_group_name_type = CLIArgumentType(
    options_list=['--resource-group', '-g'],
    completer=get_resource_group_completion_list,
    id_part='resource_group',
    help="Name of resource group. You can configure the default group using `az configure --defaults group=<name>`",
    configured_default='group',
    local_context_attribute=LocalContextAttribute(
        name='resource_group_name',
        actions=[LocalContextAction.SET, LocalContextAction.GET],
        scopes=[ALL]
    ))

name_type = CLIArgumentType(options_list=['--name', '-n'], help='the primary resource name')

# options_list is given as a list, matching every other argument type in this module.
edge_zone_type = CLIArgumentType(options_list=['--edge-zone'], help='The name of edge zone.', is_preview=True)
252
253
def get_location_type(cli_ctx):
    """ Build the shared ``--location/-l`` argument type.

    Values are passed through ``get_location_name_type(cli_ctx)``. The default can come from
    the 'location' config entry, and the argument is registered with local-context SET and GET
    actions for ALL scopes.
    """
    location_context = LocalContextAttribute(
        name='location',
        actions=[LocalContextAction.SET, LocalContextAction.GET],
        scopes=[ALL])
    return CLIArgumentType(
        options_list=['--location', '-l'],
        metavar='LOCATION',
        help="Location. Values from: `az account list-locations`. "
             "You can configure the default location using `az configure --defaults location=<location>`.",
        type=get_location_name_type(cli_ctx),
        completer=get_location_completion_list,
        configured_default='location',
        local_context_attribute=location_context)
269
270
# Optional argument with suppressed help; generate_deployment_name runs as its validator.
deployment_name_type = CLIArgumentType(
    validator=generate_deployment_name,
    required=False,
    help=argparse.SUPPRESS)

# How an empty argument is spelled in the tag help text: "" on Windows, '' elsewhere.
if platform.system() == 'Windows':
    quotes = '""'
else:
    quotes = "''"
quote_text = 'Use {} to clear existing tags.'.format(quotes)

# Zero or more space-separated key[=value] tags, checked by validate_tags.
tags_type = CLIArgumentType(
    nargs='*',
    validator=validate_tags,
    help="space-separated tags: key[=value] [key[=value] ...]. {}".format(quote_text))

# A single optional key[=value] tag converted by validate_tag; given without a value,
# the option falls back to const ''.
tag_type = CLIArgumentType(
    nargs='?',
    const='',
    type=validate_tag,
    help="a single tag in 'key[=value]' format. {}".format(quote_text))

no_wait_type = CLIArgumentType(
    options_list=['--no-wait'],
    action='store_true',
    help='do not wait for the long-running operation to finish')

zones_type = CLIArgumentType(
    options_list=['--zones', '-z'],
    choices=['1', '2', '3'],
    nargs='+',
    help='Space-separated list of availability zones into which to provision the resource.')

# nargs=1 makes argparse store a one-element list rather than a bare string.
zone_type = CLIArgumentType(
    options_list=['--zone', '-z'],
    choices=['1', '2', '3'],
    nargs=1,
    help='Availability zone into which to provision the resource.')

# vnet/subnet names are registered with only the local-context GET action.
vnet_name_type = CLIArgumentType(
    local_context_attribute=LocalContextAttribute(name='vnet_name', actions=[LocalContextAction.GET]))

subnet_name_type = CLIArgumentType(
    local_context_attribute=LocalContextAttribute(name='subnet_name', actions=[LocalContextAction.GET]))
319
320
def patch_arg_make_required(argument):
    """ Patch callable (e.g. a ``patches`` entry for ``AzArgumentContext.expand``): mark ``argument`` required. """
    argument.settings.update(required=True)
323
324
def patch_arg_make_optional(argument):
    """ Patch callable (e.g. a ``patches`` entry for ``AzArgumentContext.expand``): mark ``argument`` optional. """
    argument.settings.update(required=False)
327
328
def patch_arg_update_description(description):
    """ Return a patch callable that replaces an argument's help text with ``description``. """
    def _apply(argument):
        argument.settings.update(help=description)

    return _apply
334
335
class AzArgumentContext(ArgumentsContext):
    """ Azure CLI flavour of knack's ArgumentsContext.

    On top of knack it adds:
    - group-level default kwargs, merged with the command loader's ``module_kwargs``;
    - API-version gating of each registration via ``resource_type``/``min_api``/``max_api``/
      ``operation_group`` (checked with ``command_loader.supported_api_version``);
    - ``expand``, which turns a model type's constructor parameters into individual arguments;
    - ``extra``, which registers command-specific arguments in ``extra_argument_registry``.
    """

    def __init__(self, command_loader, scope, **kwargs):
        from azure.cli.core.commands import _merge_kwargs as merge_kwargs
        super(AzArgumentContext, self).__init__(command_loader, scope)
        self.scope = scope  # this is called "command" in knack, but that is not an accurate name
        # defaults merged into every argument registered through this context
        self.group_kwargs = merge_kwargs(kwargs, command_loader.module_kwargs, CLI_PARAM_KWARGS)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # registrations after the `with` block ends go through _check_stale with this flag set
        self.is_stale = True

    def _flatten_kwargs(self, kwargs, arg_type):
        """ Merge ``kwargs`` over the group defaults, then layer the result over ``arg_type``'s settings. """
        merged_kwargs = self._merge_kwargs(kwargs)
        if arg_type:
            arg_type_copy = arg_type.settings.copy()
            arg_type_copy.update(merged_kwargs)
            return arg_type_copy
        return merged_kwargs

    def _merge_kwargs(self, kwargs, base_kwargs=None):
        """ Merge ``kwargs`` over ``base_kwargs`` (the group defaults when not given). """
        from azure.cli.core.commands import _merge_kwargs as merge_kwargs
        base = base_kwargs if base_kwargs is not None else self.group_kwargs
        return merge_kwargs(kwargs, base, CLI_PARAM_KWARGS)

    def _api_version_supported(self, merged_kwargs):
        """ Return whether the API-version constraints found in ``merged_kwargs`` are satisfied. """
        return self.command_loader.supported_api_version(
            resource_type=merged_kwargs.get('resource_type', None),
            min_api=merged_kwargs.get('min_api', None),
            max_api=merged_kwargs.get('max_api', None),
            operation_group=merged_kwargs.get('operation_group', None))

    def _ignore_if_not_registered(self, dest):
        """ Register ``dest`` as ignored unless this scope already has an argument named ``dest``. """
        arg_registry = self.command_loader.argument_registry
        if not arg_registry.arguments[self.scope].get(dest, {}):
            super(AzArgumentContext, self).argument(dest, arg_type=ignore_type)

    # pylint: disable=arguments-differ
    def argument(self, dest, arg_type=None, **kwargs):
        """ Register an argument for this scope, or ignore it if its API constraints are not met. """
        self._check_stale()
        if not self._applicable():
            return

        merged_kwargs = self._flatten_kwargs(kwargs, arg_type)
        # an explicitly empty options_list is dropped rather than passed on
        if merged_kwargs.get('options_list', None) == []:
            del merged_kwargs['options_list']

        if self._api_version_supported(merged_kwargs):
            super(AzArgumentContext, self).argument(dest, **merged_kwargs)
        else:
            self._ignore_if_not_registered(dest)

    def positional(self, dest, arg_type=None, **kwargs):
        """ Register a positional argument, or ignore it if its API constraints are not met. """
        self._check_stale()
        if not self._applicable():
            return

        merged_kwargs = self._flatten_kwargs(kwargs, arg_type)
        # keep only the kwargs allowed for positional arguments; the API check below reads the
        # filtered dict, as before
        merged_kwargs = {k: v for k, v in merged_kwargs.items() if k in CLI_POSITIONAL_PARAM_KWARGS}
        merged_kwargs['options_list'] = []

        if self._api_version_supported(merged_kwargs):
            super(AzArgumentContext, self).positional(dest, **merged_kwargs)
        else:
            self._ignore_if_not_registered(dest)

    def expand(self, dest, model_type, group_name=None, patches=None):
        """ Expand the constructor parameters of ``model_type`` into individual CLI arguments.

        Every parameter of ``model_type.__init__`` (minus EXCLUDED_PARAMS) is registered via
        ``extra``. ``dest`` itself becomes a hidden argument whose validator builds a
        ``model_type`` instance from those arguments and stores it on the namespace as ``dest``.

        :param dest: name of the argument that receives the constructed model.
        :param model_type: model class whose constructor parameters are expanded.
        :param group_name: optional ``arg_group`` assigned to every expanded argument.
        :param patches: optional mapping of parameter name to a callable that mutates that argument.
        """
        # TODO:
        # two private symbols are imported here. they should be made public or this utility class
        # should be moved into azure.cli.core
        from knack.introspection import extract_args_from_signature, option_descriptions

        self._check_stale()
        if not self._applicable():
            return

        if not patches:
            patches = {}

        # fetch the documentation for model parameters first. for models, which are the classes
        # derived from msrest.serialization.Model and used in the SDK API to carry parameters, the
        # documentation of their properties is attached to the classes instead of the constructors.
        parameter_docs = option_descriptions(model_type)

        def get_complex_argument_processor(expanded_arguments, assigned_arg, model_type):
            """ Return a validator that aggregates several arguments into one complex argument. """

            def _expansion_validator_impl(namespace):
                """ Build ``model_type`` from the expanded arguments and store it as ``assigned_arg``.

                :param namespace: the argparse namespace holding the parsed CLI arguments.
                """
                ns = vars(namespace)
                expanded_names = set(expanded_arguments)  # built once, not per namespace key
                kwargs = {k: ns[k] for k in ns if k in expanded_names}

                setattr(namespace, assigned_arg, model_type(**kwargs))

            return _expansion_validator_impl

        expanded_arguments = []
        for name, arg in extract_args_from_signature(model_type.__init__, excluded_params=EXCLUDED_PARAMS):
            arg = arg.type
            if name in parameter_docs:
                arg.settings['help'] = parameter_docs[name]

            if group_name:
                arg.settings['arg_group'] = group_name

            if name in patches:
                patches[name](arg)

            self.extra(name, arg_type=arg)
            expanded_arguments.append(name)

        dest_option = ['--__{}'.format(dest.upper())]
        self.argument(dest,
                      arg_type=ignore_type,
                      options_list=dest_option,
                      validator=get_complex_argument_processor(expanded_arguments, dest, model_type))

    def ignore(self, *args):
        """ Mark each named argument as ignored for this scope. """
        self._check_stale()
        if not self._applicable():
            return

        for arg in args:
            super(AzArgumentContext, self).ignore(arg)

    def extra(self, dest, arg_type=None, **kwargs):
        """ Register a command-specific argument in ``extra_argument_registry``.

        Nothing is registered when this context is not applicable or the API-version
        constraints are not met.

        :raises ValueError: if this context's scope is a command group rather than a command.
        """
        from knack.arguments import CLICommandArgument

        # check staleness/applicability up front, like argument()/positional()/ignore()/expand();
        # these checks used to be skipped whenever the API-version constraints were not met
        self._check_stale()
        if not self._applicable():
            return

        merged_kwargs = self._flatten_kwargs(kwargs, arg_type)
        if not self._api_version_supported(merged_kwargs):
            return

        # Restore when knack #132 is fixed
        # merged_kwargs.pop('dest', None)
        # super(AzArgumentContext, self).extra(dest, **merged_kwargs)
        if self.command_scope in self.command_loader.command_group_table:
            raise ValueError("command authoring error: extra argument '{}' cannot be registered to a group-level "
                             "scope '{}'. It must be registered to a specific command.".format(
                                 dest, self.command_scope))

        deprecate_action = self._handle_deprecations(dest, **merged_kwargs)
        if deprecate_action:
            merged_kwargs['action'] = deprecate_action
        merged_kwargs.pop('dest', None)
        self.command_loader.extra_argument_registry[self.command_scope][dest] = CLICommandArgument(
            dest, **merged_kwargs)
510