1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5# pylint: disable=line-too-long
6
# Version of azure-cli-core. Returned by AzCli.get_cli_version() and stamped into / compared against
# the persisted command index by CommandIndex (a mismatch invalidates the index).
__version__ = "2.29.2"
8
9import os
10import sys
11import timeit
12
13from knack.cli import CLI
14from knack.commands import CLICommandsLoader
15from knack.completion import ARGCOMPLETE_ENV_NAME
16from knack.introspection import extract_args_from_signature, extract_full_summary_from_signature
17from knack.log import get_logger
18from knack.preview import PreviewItem
19from knack.experimental import ExperimentalItem
20from knack.util import CLIError
21from knack.arguments import ArgumentsContext, CaseInsensitiveList  # pylint: disable=unused-import
22from .local_context import AzCLILocalContext, LocalContextAction
23
logger = get_logger(__name__)

# Presumably parameter names that should never be surfaced as CLI arguments (SDK plumbing such as
# `raw`, `polling`, `custom_headers`). NOTE(review): not referenced in this file chunk — confirm
# against its callers.
EXCLUDED_PARAMS = ['self', 'raw', 'polling', 'custom_headers', 'operation_config',
                   'content_version', 'kwargs', 'client', 'no_wait']
# Event raised (with `extension_name=`) by MainCommandsLoader when loading an extension's commands fails.
EVENT_FAILED_EXTENSION_LOAD = 'MainLoader.OnFailedExtensionLoad'

# [Reserved, in case of future usage]
# Modules that will always be loaded. They don't expose commands but hook into CLI core.
# Appended to any explicit module list (e.g. one taken from the command index) before loading.
ALWAYS_LOADED_MODULES = []
# Extensions that will always be loaded if installed. They don't expose commands but hook into CLI core.
# Appended to any explicit extension list (e.g. one taken from the command index) before loading.
ALWAYS_LOADED_EXTENSIONS = ['azext_ai_examples', 'azext_next']
35
36
def _configure_knack():
    """Adapt knack's module-level defaults so they read as Azure CLI output.

    Rewrites the preview/experimental status-tag messages so they point at the Azure CLI
    reference page, and adds the 'azure' logger to knack's CLI loggers so its records are displayed.
    """
    from knack.util import status_tag_messages

    support_link = "Reference and support levels: https://aka.ms/CLI_refstatus"
    # (tag, message prefix) pairs; the "{}" placeholder is filled in by knack with the item name.
    tag_texts = (
        ('preview', "{} is in preview and under development. "),
        ('experimental', "{} is experimental and under development. "),
    )
    for tag, text in tag_texts:
        status_tag_messages[tag] = text + support_link

    from knack.log import cli_logger_names
    cli_logger_names.append('azure')


# Apply the overrides once, when this module is imported.
_configure_knack()
54
55
class AzCli(CLI):
    """Azure CLI flavour of the knack ``CLI`` object.

    On construction it loads the persisted session files from the config directory, resolves the
    active cloud, registers global arguments/transforms and configures the output style/theme.
    """

    def __init__(self, **kwargs):
        super(AzCli, self).__init__(**kwargs)

        from azure.cli.core.commands import register_cache_arguments
        from azure.cli.core.commands.arm import (
            register_ids_argument, register_global_subscription_argument)
        from azure.cli.core.cloud import get_active_cloud
        from azure.cli.core.commands.transform import register_global_transforms
        from azure.cli.core._session import ACCOUNT, CONFIG, SESSION, INDEX, VERSIONS
        from azure.cli.core.util import handle_version_update
        from azure.cli.core.commands.query_examples import register_global_query_examples_argument

        from knack.util import ensure_dir

        self.data['headers'] = {}
        self.data['command'] = 'unknown'
        self.data['command_extension_name'] = None
        self.data['completer_active'] = ARGCOMPLETE_ENV_NAME in os.environ
        self.data['query_active'] = False

        # Load the persisted state files that live in the CLI config directory.
        azure_folder = self.config.config_dir
        ensure_dir(azure_folder)
        ACCOUNT.load(os.path.join(azure_folder, 'azureProfile.json'))
        CONFIG.load(os.path.join(azure_folder, 'az.json'))
        SESSION.load(os.path.join(azure_folder, 'az.sess'), max_age=3600)
        INDEX.load(os.path.join(azure_folder, 'commandIndex.json'))
        VERSIONS.load(os.path.join(azure_folder, 'versionCheck.json'))
        handle_version_update()

        self.cloud = get_active_cloud(self)
        logger.debug('Current cloud config:\n%s', str(self.cloud.name))
        self.local_context = AzCLILocalContext(self)
        register_global_transforms(self)
        register_global_subscription_argument(self)
        register_global_query_examples_argument(self)
        register_ids_argument(self)  # global subscription must be registered first!
        register_cache_arguments(self)

        # Created lazily by get_progress_controller().
        self.progress_controller = None

        self._configure_style()

    def refresh_request_id(self):
        """Assign a new random GUID as x-ms-client-request-id

        The method must be invoked before each command execution in order to ensure
        unique client-side request ID is generated.

        A version-4 (random) UUID is used on purpose: version-1 UUIDs are derived from a timestamp
        and the host's network (MAC) address, which would expose machine information in every
        request header and is not random.
        """
        import uuid
        self.data['headers']['x-ms-client-request-id'] = str(uuid.uuid4())

    def get_progress_controller(self, det=False, spinner=None):
        """Return the shared progress hook, (re)initialised with a fresh progress view.

        :param det: Passed to ``get_progress_view`` (determinate vs. indeterminate view).
        :param spinner: Optional spinner passed to ``get_progress_view``.
        """
        import azure.cli.core.commands.progress as progress
        if not self.progress_controller:
            self.progress_controller = progress.ProgressHook()

        self.progress_controller.init_progress(progress.get_progress_view(det, spinner=spinner))
        return self.progress_controller

    def get_cli_version(self):
        """Return the azure-cli-core version string."""
        return __version__

    def show_version(self):
        """Print the version information, available updates and (optionally) the survey links."""
        from azure.cli.core.util import get_az_version_string, show_updates
        from azure.cli.core.commands.constants import SURVEY_PROMPT_STYLED, UX_SURVEY_PROMPT_STYLED
        from azure.cli.core.style import print_styled_text

        ver_string, updates_available_components = get_az_version_string()
        print(ver_string)
        show_updates(updates_available_components)

        show_link = self.config.getboolean('output', 'show_survey_link', True)
        if show_link:
            print_styled_text()
            print_styled_text(SURVEY_PROMPT_STYLED)
            print_styled_text(UX_SURVEY_PROMPT_STYLED)

    def exception_handler(self, ex):  # pylint: disable=no-self-use
        """Delegate exception handling to ``azure.cli.core.util.handle_exception``."""
        from azure.cli.core.util import handle_exception
        return handle_exception(ex)

    def save_local_context(self, parsed_args, argument_definitions, specified_arguments):
        """ Local Context Attribute arguments

        Save argument value to local context if it is defined as SET and user specify a value for it.

        :param parsed_args: Parsed args which return by AzCliCommandParser parse_args
        :type parsed_args: Namespace
        :param argument_definitions: All available argument definitions
        :type argument_definitions: dict
        :param specified_arguments: Arguments which user specify in this command
        :type specified_arguments: list
        """
        local_context_args = []
        for argument_name in specified_arguments:
            # make sure SET is defined
            if argument_name not in argument_definitions:
                continue
            argtype = argument_definitions[argument_name].type
            lca = argtype.settings.get('local_context_attribute', None)
            if not lca or not lca.actions or LocalContextAction.SET not in lca.actions:
                continue
            # get the specified value
            value = getattr(parsed_args, argument_name)
            # save when name and scopes have value
            if lca.name and lca.scopes:
                self.local_context.set(lca.scopes, lca.name, value)
            # Remember the first option string (e.g. `--name`) for the user-facing summary below.
            options = argtype.settings.get('options_list', None)
            if options:
                local_context_args.append((options[0], value))

        # print warning if there are values saved to local context
        if local_context_args:
            logger.warning('Parameter persistence is turned on. Its information is saved in working directory %s. '
                           'You can run `az config param-persist off` to turn it off.',
                           self.local_context.effective_working_directory())
            args_str = []
            for name, value in local_context_args:
                args_str.append('{}: {}'.format(name, value))
            logger.warning('Your preference of %s now saved as persistent parameter. To learn more, type in `az '
                           'config param-persist --help`',
                           ', '.join(args_str) + (' is' if len(args_str) == 1 else ' are'))

    def _configure_style(self):
        """Select the output theme and, when a theme applies, mirror its colors into knack's logger."""
        from azure.cli.core.util import in_cloud_console
        from azure.cli.core.style import format_styled_text, get_theme_dict, Style

        # Configure Style
        if self.enable_color:
            theme = self.config.get('core', 'theme',
                                    fallback="cloud-shell" if in_cloud_console() else "dark")

            theme_dict = get_theme_dict(theme)

            if theme_dict:
                # If theme is used, also apply it to knack's logger
                from knack.util import color_map
                color_map['error'] = theme_dict[Style.ERROR]
                color_map['warning'] = theme_dict[Style.WARNING]
        else:
            theme = 'none'
        format_styled_text.theme = theme
200
201
class MainCommandsLoader(CLICommandsLoader):
    """Root commands loader that merges the command tables of built-in modules and extensions.

    When the command index is enabled and usable, only the modules/extensions recorded for the
    requested top-level command are loaded; otherwise every module and extension is discovered and
    loaded, and the command index is rebuilt from the result.
    """

    # Format string for pretty-print the command module table
    header_mod = "%-20s %10s %9s %9s" % ("Name", "Load Time", "Groups", "Commands")
    item_format_string = "%-20s %10.3f %9d %9d"
    header_ext = header_mod + "  Directory"
    item_ext_format_string = item_format_string + "  %s"

    def __init__(self, cli_ctx=None):
        super(MainCommandsLoader, self).__init__(cli_ctx)
        # command name -> loaders contributing to it. NOTE(review): presumably populated by the
        # module/extension loader helpers, which receive `self` — confirm.
        self.cmd_to_loader_map = {}
        self.loaders = []

    def _update_command_definitions(self):
        """Share the merged command table with every contributing loader and refresh its definitions."""
        for cmd_name in self.command_table:
            loaders = self.cmd_to_loader_map[cmd_name]
            for loader in loaders:
                loader.command_table = self.command_table
                loader._update_command_definitions()  # pylint: disable=protected-access

    # pylint: disable=too-many-statements, too-many-locals
    def load_command_table(self, args):
        """Build and return the merged command table needed to run ``args``.

        :param args: Command arguments, like ['network', 'vnet', 'create', '-h'].
        :return: ``self.command_table`` after (re)loading.
        """
        from importlib import import_module
        import pkgutil
        import traceback
        from azure.cli.core.commands import (
            _load_module_command_loader, _load_extension_command_loader, BLOCKED_MODS, ExtensionCommandSource)
        from azure.cli.core.extension import (
            get_extensions, get_extension_path, get_extension_modname)

        def _update_command_table_from_modules(args, command_modules=None):
            """Loads command tables from modules and merge into the main command table.

            :param args: Arguments of the command.
            :param list command_modules: Command modules to load, in the format like ['resource', 'profile'].
             If None, will do module discovery and load all modules.
             If [], only ALWAYS_LOADED_MODULES will be loaded.
             Otherwise, the list will be extended using ALWAYS_LOADED_MODULES.
            """

            # As command modules are built-in, the existence of modules in ALWAYS_LOADED_MODULES is NOT checked
            if command_modules is not None:
                command_modules.extend(ALWAYS_LOADED_MODULES)
            else:
                # Perform module discovery
                command_modules = []
                try:
                    mods_ns_pkg = import_module('azure.cli.command_modules')
                    command_modules = [modname for _, modname, _ in
                                       pkgutil.iter_modules(mods_ns_pkg.__path__)]
                    logger.debug('Discovered command modules: %s', command_modules)
                except ImportError as e:
                    logger.warning(e)

            count = 0
            cumulative_elapsed_time = 0
            cumulative_group_count = 0
            cumulative_command_count = 0
            logger.debug("Loading command modules:")
            logger.debug(self.header_mod)

            for mod in [m for m in command_modules if m not in BLOCKED_MODS]:
                try:
                    start_time = timeit.default_timer()
                    module_command_table, module_group_table = _load_module_command_loader(self, args, mod)
                    for cmd in module_command_table.values():
                        cmd.command_source = mod
                    self.command_table.update(module_command_table)
                    self.command_group_table.update(module_group_table)

                    elapsed_time = timeit.default_timer() - start_time
                    logger.debug(self.item_format_string, mod, elapsed_time,
                                 len(module_group_table), len(module_command_table))
                    count += 1
                    cumulative_elapsed_time += elapsed_time
                    cumulative_group_count += len(module_group_table)
                    cumulative_command_count += len(module_command_table)
                except Exception as ex:  # pylint: disable=broad-except
                    # Changing this error message requires updating CI script that checks for failed
                    # module loading.
                    import azure.cli.core.telemetry as telemetry
                    logger.error("Error loading command module '%s': %s", mod, ex)
                    telemetry.set_exception(exception=ex, fault_type='module-load-error-' + mod,
                                            summary='Error loading module: {}'.format(mod))
                    logger.debug(traceback.format_exc())
            # Summary line
            logger.debug(self.item_format_string,
                         "Total ({})".format(count), cumulative_elapsed_time,
                         cumulative_group_count, cumulative_command_count)

        def _update_command_table_from_extensions(ext_suppressions, extension_modname=None):
            """Loads command tables from extensions and merge into the main command table.

            :param ext_suppressions: Extension suppression information.
            :param extension_modname: Command modules to load, in the format like ['azext_timeseriesinsights'].
             If None, will do extension discovery and load all extensions.
             If [], only ALWAYS_LOADED_EXTENSIONS will be loaded.
             Otherwise, the list will be extended using ALWAYS_LOADED_EXTENSIONS.
             If the extensions in the list are not installed, it will be skipped.
            """
            def _handle_extension_suppressions(extensions):
                # Keep an extension unless some suppression claims it; `any` stops consulting
                # suppressions at the first one that does.
                return [ext for ext in extensions
                        if not any(suppression.handle_suppress(ext) for suppression in ext_suppressions)]

            def _filter_modname(extensions):
                # Extension's name may not be the same as its modname. eg. name: virtual-wan, modname: azext_vwan
                filtered_extensions = []
                for ext in extensions:
                    ext_mod = get_extension_modname(ext.name, ext.path)
                    # Filter the extensions according to the index
                    if ext_mod in extension_modname:
                        filtered_extensions.append(ext)
                        extension_modname.remove(ext_mod)
                # Whatever is left was requested but is not installed.
                if extension_modname:
                    logger.debug("These extensions are not installed and will be skipped: %s", extension_modname)
                return filtered_extensions

            extensions = get_extensions()
            if extensions:
                if extension_modname is not None:
                    extension_modname.extend(ALWAYS_LOADED_EXTENSIONS)
                    extensions = _filter_modname(extensions)
                allowed_extensions = _handle_extension_suppressions(extensions)
                # Snapshot of built-in commands, used to flag extension commands that override them.
                module_commands = set(self.command_table.keys())

                count = 0
                cumulative_elapsed_time = 0
                cumulative_group_count = 0
                cumulative_command_count = 0
                logger.debug("Loading extensions:")
                logger.debug(self.header_ext)

                for ext in allowed_extensions:
                    try:
                        # Import in the `for` loop because `allowed_extensions` can be []. In such case we
                        # don't need to import `check_version_compatibility` at all.
                        from azure.cli.core.extension.operations import check_version_compatibility
                        check_version_compatibility(ext.get_metadata())
                    except CLIError as ex:
                        # issue warning and skip loading extensions that aren't compatible with the CLI core
                        logger.warning(ex)
                        continue
                    ext_name = ext.name
                    ext_dir = ext.path or get_extension_path(ext_name)
                    # This method can run more than once per process (e.g. the full reload after an
                    # outdated index), so don't pile up duplicate sys.path entries.
                    if ext_dir not in sys.path:
                        sys.path.append(ext_dir)
                    try:
                        ext_mod = get_extension_modname(ext_name, ext_dir=ext_dir)
                        start_time = timeit.default_timer()
                        extension_command_table, extension_group_table = \
                            _load_extension_command_loader(self, args, ext_mod)

                        for cmd_name, cmd in extension_command_table.items():
                            cmd.command_source = ExtensionCommandSource(
                                extension_name=ext_name,
                                overrides_command=cmd_name in module_commands,
                                preview=ext.preview,
                                experimental=ext.experimental)

                        self.command_table.update(extension_command_table)
                        self.command_group_table.update(extension_group_table)

                        elapsed_time = timeit.default_timer() - start_time
                        logger.debug(self.item_ext_format_string, ext_name, elapsed_time,
                                     len(extension_group_table), len(extension_command_table),
                                     ext_dir)
                        count += 1
                        cumulative_elapsed_time += elapsed_time
                        cumulative_group_count += len(extension_group_table)
                        cumulative_command_count += len(extension_command_table)
                    except Exception as ex:  # pylint: disable=broad-except
                        self.cli_ctx.raise_event(EVENT_FAILED_EXTENSION_LOAD, extension_name=ext_name)
                        logger.warning("Unable to load extension '%s: %s'. Use --debug for more information.",
                                       ext_name, ex)
                        logger.debug(traceback.format_exc())
                # Summary line
                logger.debug(self.item_ext_format_string,
                             "Total ({})".format(count), cumulative_elapsed_time,
                             cumulative_group_count, cumulative_command_count, "")

        def _get_extension_suppressions(mod_loaders):
            # Collect the ModExtensionSuppress rules declared by the loaded module loaders
            # (a single rule or a list of them under `suppress_extension`).
            res = []
            for m in mod_loaders:
                suppressions = getattr(m, 'suppress_extension', None)
                if suppressions:
                    suppressions = suppressions if isinstance(suppressions, list) else [suppressions]
                    for sup in suppressions:
                        if isinstance(sup, ModExtensionSuppress):
                            res.append(sup)
            return res

        # Clear the tables to make this method idempotent
        self.command_group_table.clear()
        self.command_table.clear()

        command_index = None
        # Set fallback=False to turn off command index in case of regression
        use_command_index = self.cli_ctx.config.getboolean('core', 'use_command_index', fallback=True)
        if use_command_index:
            command_index = CommandIndex(self.cli_ctx)
            index_result = command_index.get(args)
            if index_result:
                index_modules, index_extensions = index_result
                # Always load modules and extensions, because some of them (like those in
                # ALWAYS_LOADED_EXTENSIONS) don't expose a command, but hooks into handlers in CLI core
                _update_command_table_from_modules(args, index_modules)
                # The index won't contain suppressed extensions
                _update_command_table_from_extensions([], index_extensions)

                logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table))
                from azure.cli.core.util import roughly_parse_command
                # The index may be outdated. Make sure the command appears in the loaded command table
                raw_cmd = roughly_parse_command(args)
                for cmd in self.command_table:
                    # Match whole words only: a raw command `vm list-sizes` must not be accepted as
                    # `vm list` followed by positional arguments.
                    if raw_cmd == cmd or raw_cmd.startswith(cmd + ' '):
                        # For commands with positional arguments, the raw command won't match the one in the
                        # command table. For example, `az find vm create` won't exist in the command table, but the
                        # corresponding command should be `az find`.
                        # raw command  : az find vm create
                        # command table: az find
                        # remaining    :         vm create
                        logger.debug("Found a match in the command table.")
                        logger.debug("Raw command  : %s", raw_cmd)
                        logger.debug("Command table: %s", cmd)
                        remaining = raw_cmd[len(cmd) + 1:]
                        if remaining:
                            logger.debug("remaining    : %s %s", ' ' * len(cmd), remaining)
                        return self.command_table
                # For command group, it must be an exact match, as no positional argument is supported by
                # command group operations.
                if raw_cmd in self.command_group_table:
                    logger.debug("Found a match in the command group table for '%s'.", raw_cmd)
                    return self.command_table

                logger.debug("Could not find a match in the command or command group table for '%s'. "
                             "The index may be outdated.", raw_cmd)
            else:
                logger.debug("No module found from index for '%s'", args)

        # No module found from the index. Load all command modules and extensions
        logger.debug("Loading all modules and extensions")
        _update_command_table_from_modules(args)

        ext_suppressions = _get_extension_suppressions(self.loaders)
        # We always load extensions even if the appropriate module has been loaded
        # as an extension could override the commands already loaded.
        _update_command_table_from_extensions(ext_suppressions)
        logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table))

        if use_command_index:
            command_index.update(self.command_table)

        return self.command_table

    def load_arguments(self, command=None):
        """Register global and command-specific arguments from the relevant loaders.

        :param command: Command name to load arguments for; ``None`` loads arguments for every
            loader known to ``cmd_to_loader_map``.
        """
        from azure.cli.core.commands.parameters import (
            resource_group_name_type, get_location_type, deployment_name_type, vnet_name_type, subnet_name_type)
        from knack.arguments import ignore_type

        # omit specific command to load everything
        if command is None:
            command_loaders = set()
            for loaders in self.cmd_to_loader_map.values():
                command_loaders.update(loaders)
            logger.info('Applying %s command loaders...', len(command_loaders))
        else:
            command_loaders = self.cmd_to_loader_map.get(command, None)

        if command_loaders:
            for loader in command_loaders:

                # register global args
                with loader.argument_context('') as c:
                    c.argument('resource_group_name', resource_group_name_type)
                    c.argument('location', get_location_type(self.cli_ctx))
                    c.argument('vnet_name', vnet_name_type)
                    c.argument('subnet', subnet_name_type)
                    c.argument('deployment_name', deployment_name_type)
                    c.argument('cmd', ignore_type)

                if command is None:
                    # load all arguments via reflection
                    for cmd in loader.command_table.values():
                        cmd.load_arguments()  # this loads the arguments via reflection
                    loader.skip_applicability = True
                    loader.load_arguments('')  # this adds entries to the argument registries
                else:
                    loader.command_name = command
                    self.command_table[command].load_arguments()  # this loads the arguments via reflection
                    loader.load_arguments(command)  # this adds entries to the argument registries
                self.argument_registry.arguments.update(loader.argument_registry.arguments)
                self.extra_argument_registry.update(loader.extra_argument_registry)
                loader._update_command_definitions()  # pylint: disable=protected-access
522
523
class CommandIndex:
    """Persisted mapping from top-level command names to the modules/extensions that provide them.

    The mapping is stored in the ``INDEX`` session together with the CLI version and cloud profile
    it was built for; if either differs from the current ones, the index is invalidated.
    """

    _COMMAND_INDEX = 'commandIndex'
    _COMMAND_INDEX_VERSION = 'version'
    _COMMAND_INDEX_CLOUD_PROFILE = 'cloudProfile'

    def __init__(self, cli_ctx=None):
        """Class to manage command index.

        :param cli_ctx: Only needed when `get` or `update` is called.
        """
        from azure.cli.core._session import INDEX
        self.INDEX = INDEX
        if cli_ctx:
            self.version = __version__
            self.cloud_profile = cli_ctx.cloud.profile

    def _get_entry(self, key, default=None):
        """Return the index value stored under ``key``, or ``default`` when the key is absent.

        A freshly created or partially written index file may lack some keys; a plain
        ``self.INDEX[key]`` lookup would then raise ``KeyError`` instead of falling back.
        """
        try:
            return self.INDEX[key]
        except KeyError:
            return default

    def get(self, args):
        """Get the corresponding module and extension list of a command.

        :param args: command arguments, like ['network', 'vnet', 'create', '-h']
        :return: a tuple containing a list of modules and a list of extensions, or None when the
            index is invalid, no top-level command is given, or the command is not indexed.
        """
        # If the command index version or cloud profile doesn't match those of the current command,
        # invalidate the command index. Missing entries count as a mismatch.
        index_version = self._get_entry(self._COMMAND_INDEX_VERSION)
        cloud_profile = self._get_entry(self._COMMAND_INDEX_CLOUD_PROFILE)
        if not (index_version and index_version == self.version and
                cloud_profile and cloud_profile == self.cloud_profile):
            logger.debug("Command index version or cloud profile is invalid or doesn't match the current command.")
            self.invalidate()
            return None

        # Make sure the top-level command is provided, like `az version`.
        # Skip command index for `az` or `az --help`.
        if not args or args[0].startswith('-'):
            return None

        # Get the top-level command, like `network` in `network vnet create -h`
        top_command = args[0]
        index = self._get_entry(self._COMMAND_INDEX) or {}
        # Check the command index for (command: [module]) mapping, like
        # "network": ["azure.cli.command_modules.natgateway", "azure.cli.command_modules.network", "azext_firewall"]
        index_modules_extensions = index.get(top_command)

        if index_modules_extensions:
            # This list contains both built-in modules and extensions
            index_builtin_modules = []
            index_extensions = []
            # Found modules from index
            logger.debug("Modules found from index for '%s': %s", top_command, index_modules_extensions)
            command_module_prefix = 'azure.cli.command_modules.'
            for m in index_modules_extensions:
                if m.startswith(command_module_prefix):
                    # The top-level command is from a command module
                    index_builtin_modules.append(m[len(command_module_prefix):])
                elif m.startswith('azext_'):
                    # The top-level command is from an extension
                    index_extensions.append(m)
                else:
                    logger.warning("Unrecognized module: %s", m)
            return index_builtin_modules, index_extensions

        return None

    def update(self, command_table):
        """Update the command index according to the given command table.

        :param command_table: The command table built by azure.cli.core.MainCommandsLoader.load_command_table
        """
        start_time = timeit.default_timer()
        self.INDEX[self._COMMAND_INDEX_VERSION] = __version__
        self.INDEX[self._COMMAND_INDEX_CLOUD_PROFILE] = self.cloud_profile
        from collections import defaultdict
        index = defaultdict(list)

        # self.cli_ctx.invocation.commands_loader.command_table doesn't exist in DummyCli due to the lack of invocation
        for command_name, command in command_table.items():
            # Get the top-level name: <vm> create
            top_command = command_name.split()[0]
            # Get module name, like azure.cli.command_modules.vm, azext_webapp
            module_name = command.loader.__module__
            if module_name not in index[top_command]:
                index[top_command].append(module_name)
        self.INDEX[self._COMMAND_INDEX] = index
        # Measured after the write so the reported time includes storing the index.
        elapsed_time = timeit.default_timer() - start_time
        logger.debug("Updated command index in %.3f seconds.", elapsed_time)

    def invalidate(self):
        """Invalidate the command index.

        This function MUST be called when installing or updating extensions. Otherwise, when an extension
            1. overrides a built-in command, or
            2. extends an existing command group,
        the command or command group will only be loaded from the command modules as per the stale command index,
        making the newly installed extension be ignored.

        This function can be called when removing extensions.
        """
        self.INDEX[self._COMMAND_INDEX_VERSION] = ""
        self.INDEX[self._COMMAND_INDEX_CLOUD_PROFILE] = ""
        self.INDEX[self._COMMAND_INDEX] = {}
        logger.debug("Command index has been invalidated.")
627
628
class ModExtensionSuppress:  # pylint: disable=too-few-public-methods
    """Rule that suppresses (prevents loading of) an extension up to a given version.

    An extension matches the rule when its name equals ``suppress_extension_name`` and its
    version is less than or equal to ``suppress_up_to_version``.
    """

    def __init__(self, mod_name, suppress_extension_name, suppress_up_to_version, reason=None, recommend_remove=False,
                 recommend_update=False):
        """
        :param mod_name: Name of the module that requested the suppression (reported in the debug log).
        :param suppress_extension_name: Name of the extension to suppress.
        :param suppress_up_to_version: Highest extension version (inclusive) that is suppressed.
        :param reason: Optional message shown in the suppression warning.
        :param recommend_remove: If True, also warn with the command to remove the extension.
        :param recommend_update: If True, also warn with the command to update the extension.
        """
        self.mod_name = mod_name
        self.suppress_extension_name = suppress_extension_name
        self.suppress_up_to_version = suppress_up_to_version
        self.reason = reason
        self.recommend_remove = recommend_remove
        self.recommend_update = recommend_update

    def handle_suppress(self, ext):
        """Decide whether ``ext`` should be suppressed and, if so, log why.

        :param ext: Extension object exposing ``name`` and ``version`` attributes.
        :return: True if the extension matches this rule and must not be loaded, otherwise False.
        """
        # An extension without a version can never be matched against the version bound.
        if ext.name != self.suppress_extension_name or not ext.version:
            return False

        # Imported only when a comparison is actually needed.
        from packaging.version import parse
        if parse(ext.version) > parse(self.suppress_up_to_version):
            return False

        reason = self.reason or "Use --debug for more information."
        logger.warning("Extension %s (%s) has been suppressed. %s",
                       ext.name, ext.version, reason)
        logger.debug("Extension %s (%s) suppressed from being loaded due "
                     "to %s", ext.name, ext.version, self.mod_name)
        if self.recommend_remove:
            logger.warning("Remove this extension with 'az extension remove --name %s'", ext.name)
        if self.recommend_update:
            logger.warning("Update this extension with 'az extension update --name %s'", ext.name)
        return True
655
656
class AzCommandsLoader(CLICommandsLoader):  # pylint: disable=too-many-instance-attributes
    """Base commands loader for Azure CLI command modules and extensions.

    Extends knack's ``CLICommandsLoader`` with:
    - API-profile-aware SDK helpers (``get_api_version``, ``supported_api_version``, ``get_sdk``, ``get_models``),
    - factories for command groups and argument contexts,
    - command registration (``_cli_command`` and ``add_cli_command``).
    """

    def __init__(self, cli_ctx=None, command_group_cls=None, argument_context_cls=None,
                 suppress_extension=None, **kwargs):
        """
        :param cli_ctx: The CLI context object.
        :param command_group_cls: Class instantiated by ``command_group``; defaults to ``AzCommandGroup``.
        :param argument_context_cls: Class instantiated by ``argument_context``; defaults to ``AzArgumentContext``.
        :param suppress_extension: Stored on ``self.suppress_extension``. NOTE(review): presumably
            ``ModExtensionSuppress`` rule(s) consulted during extension loading -- confirm with callers.
        :param kwargs: Loader-wide defaults stored on ``self.module_kwargs``; this class reads
            ``resource_type``, ``command_type`` and ``operation_group`` from it.
        """
        from azure.cli.core.commands import AzCliCommand, AzCommandGroup, AzArgumentContext

        super().__init__(cli_ctx=cli_ctx,
                         command_cls=AzCliCommand,
                         excluded_command_handler_args=EXCLUDED_PARAMS)
        self.suppress_extension = suppress_extension
        self.module_kwargs = kwargs
        self.command_name = None
        self.skip_applicability = False
        self._command_group_cls = command_group_cls or AzCommandGroup
        self._argument_context_cls = argument_context_cls or AzArgumentContext

    def _update_command_definitions(self):
        """Apply argument registrations from the invocation's main commands loader to this loader's commands.

        For every command in ``self.command_table``:
        - add the arguments explicitly registered for it in the main loader's ``extra_argument_registry``;
        - apply the overrides that the main loader's ``argument_registry`` resolves for each of its arguments.
        """
        master_arg_registry = self.cli_ctx.invocation.commands_loader.argument_registry
        master_extra_arg_registry = self.cli_ctx.invocation.commands_loader.extra_argument_registry

        for command_name, command in self.command_table.items():
            # Add any arguments explicitly registered for this command
            for argument_name, argument_definition in master_extra_arg_registry[command_name].items():
                command.arguments[argument_name] = argument_definition

            for argument_name in command.arguments:
                overrides = master_arg_registry.get_cli_argument(command_name, argument_name)
                command.update_argument(argument_name, overrides)

    def _apply_doc_string(self, dest, command_kwargs):
        """Copy the docstring named by ``command_kwargs['doc_string_source']`` onto ``dest``.

        ``doc_string_source`` must be a string. It is first looked up as an SDK model via
        ``get_models``. If that yields nothing (or raises ``APIVersionException``), it is parsed as
        ``'<module path>#<name>'`` or ``'<module path>#<class>.<method>'`` and imported directly.
        Does nothing when ``doc_string_source`` is absent or empty.

        :raises CLIError: if ``doc_string_source`` is not a string, or resolves to nothing.
        """
        from azure.cli.core.profiles._shared import APIVersionException
        doc_string_source = command_kwargs.get('doc_string_source', None)
        if not doc_string_source:
            return
        if not isinstance(doc_string_source, str):
            raise CLIError("command authoring error: applying doc_string_source '{}' directly will cause slowdown. "
                           'Import by string name instead.'.format(doc_string_source.__name__))

        try:
            model = self.get_models(doc_string_source)
        except APIVersionException:
            model = None
        if not model:
            # Fall back to an explicit '<module>#<name>[.<method>]' reference.
            from importlib import import_module
            (path, model_name) = doc_string_source.split('#', 1)
            method_name = None
            if '.' in model_name:
                (model_name, method_name) = model_name.split('.', 1)
            module = import_module(path)
            model = getattr(module, model_name)
            if method_name:
                model = getattr(model, method_name, None)
        if not model:
            raise CLIError("command authoring error: source '{}' not found.".format(doc_string_source))
        dest.__doc__ = model.__doc__

    def _get_resource_type(self):
        """Return the loader's default resource type.

        Uses ``module_kwargs['resource_type']`` if set. Otherwise it uses the ``resource_type``
        setting of ``module_kwargs['command_type']``, or returns ``None`` if neither is available.
        """
        resource_type = self.module_kwargs.get('resource_type', None)
        if not resource_type:
            command_type = self.module_kwargs.get('command_type', None)
            resource_type = command_type.settings.get('resource_type', None) if command_type else None
        return resource_type

    def get_api_version(self, resource_type=None, operation_group=None):
        """Return the API version string for ``resource_type`` in the active cloud profile.

        :param resource_type: Defaults to the loader's resource type.
        :param operation_group: Selects the version when the profile entry is not a plain string.
        :raises APIVersionException: if no version string can be determined.
        """
        from azure.cli.core.profiles import get_api_version
        resource_type = resource_type or self._get_resource_type()
        version = get_api_version(self.cli_ctx, resource_type)
        if isinstance(version, str):
            return version
        # getattr() rejects a None attribute name with TypeError. Without an operation group there
        # is no way to pick a version, so report it as an APIVersionException instead.
        version = getattr(version, operation_group, None) if operation_group else None
        if version:
            return version
        from azure.cli.core.profiles._shared import APIVersionException
        raise APIVersionException(operation_group, self.cli_ctx.cloud.profile)

    def supported_api_version(self, resource_type=None, min_api=None, max_api=None, operation_group=None):
        """Return whether the active profile's API version satisfies the ``min_api``/``max_api`` bounds.

        Returns ``True`` without consulting the profile when neither bound is given. The check
        itself is delegated to ``azure.cli.core.profiles.supported_api_version``. When that returns
        a per-operation-group result and ``operation_group`` is given, that group's value is returned.
        """
        from azure.cli.core.profiles import supported_api_version
        if not min_api and not max_api:
            # optimistically assume that fully supported if no api restriction listed
            return True
        api_support = supported_api_version(
            cli_ctx=self.cli_ctx,
            resource_type=resource_type or self._get_resource_type(),
            min_api=min_api,
            max_api=max_api,
            operation_group=operation_group)
        if isinstance(api_support, bool):
            return api_support
        if operation_group:
            return getattr(api_support, operation_group)
        return api_support

    def supported_resource_type(self, resource_type=None):
        """Return whether ``resource_type`` (default: the loader's) is supported by the active profile."""
        from azure.cli.core.profiles import supported_resource_type
        return supported_resource_type(
            cli_ctx=self.cli_ctx,
            resource_type=resource_type or self._get_resource_type())

    def get_sdk(self, *attr_args, **kwargs):
        """Look up SDK attributes for a resource type (keyword ``resource_type``; default: the loader's).

        All remaining positional and keyword arguments are forwarded to ``azure.cli.core.profiles.get_sdk``.
        """
        from azure.cli.core.profiles import get_sdk
        return get_sdk(self.cli_ctx, kwargs.pop('resource_type', self._get_resource_type()),
                       *attr_args, **kwargs)

    def get_models(self, *attr_args, **kwargs):
        """Look up SDK model classes (``mod='models'``) for the loader's resource type.

        Only the ``resource_type`` and ``operation_group`` keyword arguments are used. They default
        to the loader's resource type and ``module_kwargs['operation_group']``. Other keyword
        arguments are ignored.
        """
        from azure.cli.core.profiles import get_sdk
        resource_type = kwargs.get('resource_type', self._get_resource_type())
        operation_group = kwargs.get('operation_group', self.module_kwargs.get('operation_group', None))
        return get_sdk(self.cli_ctx, resource_type, *attr_args, mod='models', operation_group=operation_group)

    def command_group(self, group_name, command_type=None, **kwargs):
        """Create a command group context using ``self._command_group_cls``.

        - Sets ``target`` on a supplied ``deprecate_info``.
        - Creates preview/experimental status items when ``is_preview``/``is_experimental`` is truthy.
        """
        if command_type:
            kwargs['command_type'] = command_type
        deprecate_info = kwargs.get('deprecate_info')
        # Tolerate an explicit deprecate_info=None instead of failing with AttributeError.
        if deprecate_info is not None:
            deprecate_info.target = group_name
        if kwargs.get('is_preview', False):
            kwargs['preview_info'] = PreviewItem(
                cli_ctx=self.cli_ctx,
                target=group_name,
                object_type='command group'
            )
        if kwargs.get('is_experimental', False):
            kwargs['experimental_info'] = ExperimentalItem(
                cli_ctx=self.cli_ctx,
                target=group_name,
                object_type='command group'
            )
        return self._command_group_cls(self, group_name, **kwargs)

    def argument_context(self, scope, **kwargs):
        """Create an argument context for ``scope`` using ``self._argument_context_cls``."""
        return self._argument_context_cls(self, scope, **kwargs)

    # Please use add_cli_command instead of _cli_command.
    # Currently the "keyvault" and "batch" modules still rely on this function, so it cannot be removed yet.
    def _cli_command(self, name, operation=None, handler=None, argument_loader=None, description_loader=None, **kwargs):
        """Register command ``name`` backed by either an ``operation`` string or a ``handler`` callable.

        Exactly one of ``operation`` (``'<module>#<attr path>'``) or ``handler`` must be given. The
        command is only added when ``supported_api_version`` accepts the ``resource_type``,
        ``min_api``, ``max_api`` and ``operation_group`` in ``kwargs``.

        :raises TypeError: on a non-string operation, a non-callable handler, or when not exactly
            one of them is given.
        """
        from knack.deprecation import Deprecated

        kwargs['deprecate_info'] = Deprecated.ensure_new_style_deprecation(self.cli_ctx, kwargs, 'command')

        if operation and not isinstance(operation, str):
            raise TypeError("Operation must be a string. Got '{}'".format(operation))
        if handler and not callable(handler):
            raise TypeError("Handler must be a callable. Got '{}'".format(handler))
        if bool(operation) == bool(handler):
            raise TypeError("Must specify exactly one of either 'operation' or 'handler'")

        # Collapse repeated whitespace in the command name.
        name = ' '.join(name.split())

        client_factory = kwargs.get('client_factory', None)

        def default_command_handler(command_args):
            from azure.cli.core.util import get_arg_list, augment_no_wait_handler_args
            from azure.cli.core.commands.client_factory import resolve_client_arg_name

            op = handler or self.get_op_handler(operation, operation_group=kwargs.get('operation_group'))
            op_args = get_arg_list(op)
            # Only pass 'cmd' through when the operation accepts it; otherwise remove it from the args.
            cmd = command_args.get('cmd') if 'cmd' in op_args else command_args.pop('cmd')

            client = client_factory(cmd.cli_ctx, command_args) if client_factory else None
            supports_no_wait = kwargs.get('supports_no_wait', None)
            if supports_no_wait:
                no_wait_enabled = command_args.pop('no_wait', False)
                augment_no_wait_handler_args(no_wait_enabled, op, command_args)
            if client:
                client_arg_name = resolve_client_arg_name(operation, kwargs)
                if client_arg_name in op_args:
                    command_args[client_arg_name] = client
            return op(**command_args)

        def default_arguments_loader():
            op = handler or self.get_op_handler(operation, operation_group=kwargs.get('operation_group'))
            self._apply_doc_string(op, kwargs)
            cmd_args = list(extract_args_from_signature(op, excluded_params=self.excluded_command_handler_args))
            return cmd_args

        def default_description_loader():
            op = handler or self.get_op_handler(operation, operation_group=kwargs.get('operation_group'))
            self._apply_doc_string(op, kwargs)
            return extract_full_summary_from_signature(op)

        kwargs['arguments_loader'] = argument_loader or default_arguments_loader
        kwargs['description_loader'] = description_loader or default_description_loader

        if self.supported_api_version(resource_type=kwargs.get('resource_type'),
                                      min_api=kwargs.get('min_api'),
                                      max_api=kwargs.get('max_api'),
                                      operation_group=kwargs.get('operation_group')):
            self._populate_command_group_table_with_subgroups(' '.join(name.split()[:-1]))
            self.command_table[name] = self.command_cls(self, name,
                                                        handler or default_command_handler,
                                                        **kwargs)

    def add_cli_command(self, name, command_operation, **kwargs):
        """Register a command in command_table with command operation provided.

        The command's handler, arguments loader and description loader come from
        ``command_operation``. The command is only added when ``supported_api_version`` accepts the
        ``resource_type``, ``min_api``, ``max_api`` and ``operation_group`` in ``kwargs``.

        :raises TypeError: if ``command_operation`` is not a ``BaseCommandOperation`` instance.
        """
        from knack.deprecation import Deprecated
        from .commands.command_operation import BaseCommandOperation
        if not issubclass(type(command_operation), BaseCommandOperation):
            raise TypeError("CommandOperation must be an instance of subclass of BaseCommandOperation."
                            " Got instance of '{}'".format(type(command_operation)))

        kwargs['deprecate_info'] = Deprecated.ensure_new_style_deprecation(self.cli_ctx, kwargs, 'command')

        name = ' '.join(name.split())

        if self.supported_api_version(resource_type=kwargs.get('resource_type'),
                                      min_api=kwargs.get('min_api'),
                                      max_api=kwargs.get('max_api'),
                                      operation_group=kwargs.get('operation_group')):
            self._populate_command_group_table_with_subgroups(' '.join(name.split()[:-1]))
            self.command_table[name] = self.command_cls(loader=self,
                                                        name=name,
                                                        handler=command_operation.handler,
                                                        arguments_loader=command_operation.arguments_loader,
                                                        description_loader=command_operation.description_loader,
                                                        command_operation=command_operation,
                                                        **kwargs)

    def get_op_handler(self, operation, operation_group=None):
        """Import and return the function referenced by ``operation`` (``'<module>#<attr>[.<attr>...]'``).

        If ``operation`` starts with the import prefix of a resource type in the active cloud profile,
        that prefix is replaced with the versioned SDK path. Plain functions are returned as-is. For
        other objects (e.g. methods), ``__func__`` is returned.

        :raises ValueError: if ``operation`` is malformed or its attribute path cannot be resolved.
        """
        # Patch the unversioned sdk path to include the appropriate API version for the
        # resource type in question.
        from importlib import import_module
        import types

        from azure.cli.core.profiles import AZURE_API_PROFILES
        from azure.cli.core.profiles._shared import get_versioned_sdk_path

        for rt in AZURE_API_PROFILES[self.cli_ctx.cloud.profile]:
            if operation.startswith(rt.import_prefix + '.'):
                operation = operation.replace(rt.import_prefix,
                                              get_versioned_sdk_path(self.cli_ctx.cloud.profile, rt,
                                                                     operation_group=operation_group))
                break

        try:
            mod_to_import, attr_path = operation.split('#')
            op = import_module(mod_to_import)
            for part in attr_path.split('.'):
                op = getattr(op, part)
            if isinstance(op, types.FunctionType):
                return op
            return op.__func__
        except (ValueError, AttributeError) as ex:
            # Chain the original error so the failing step remains visible in tracebacks.
            raise ValueError("The operation '{}' is invalid.".format(operation)) from ex
902
903
def get_default_cli():
    """Build the ``AzCli`` instance used by the ``az`` executable.

    The CLI is named ``az`` and reads its configuration from ``GLOBAL_CONFIG_DIR``, with
    environment overrides prefixed by ``ENV_VAR_PREFIX``. It is wired with the Azure-specific
    command loader, invoker, parser, logging, output and help implementations.

    :return: A new ``AzCli`` instance.
    """
    from azure.cli.core.azlogging import AzCliLogging
    from azure.cli.core.commands import AzCliCommandInvoker
    from azure.cli.core.parser import AzCliCommandParser
    from azure.cli.core._config import GLOBAL_CONFIG_DIR, ENV_VAR_PREFIX
    from azure.cli.core._help import AzCliHelp
    from azure.cli.core._output import AzOutputProducer

    # Pluggable component classes that AzCli instantiates.
    component_classes = {
        'commands_loader_cls': MainCommandsLoader,
        'invocation_cls': AzCliCommandInvoker,
        'parser_cls': AzCliCommandParser,
        'logging_cls': AzCliLogging,
        'output_cls': AzOutputProducer,
        'help_cls': AzCliHelp,
    }
    return AzCli(cli_name='az',
                 config_dir=GLOBAL_CONFIG_DIR,
                 config_env_var_prefix=ENV_VAR_PREFIX,
                 **component_classes)
921