1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import abc
7import os
8
9from azure.cli.core._help import (HelpExample, CliHelpFile)
10
11from knack.util import CLIError
12from knack.log import get_logger
13
14import yaml
15
# Module-level logger, named after this module via knack's get_logger.
logger = get_logger(__name__)
17
# Python 2.7 ships the abc module without abc.ABC, so build an equivalent empty base class
# from ABCMeta when the attribute is missing.
ABC = getattr(abc, 'ABC', None)
if ABC is None:
    ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()})
22
23
24# BaseHelpLoader defining versioned loader interface. Also contains some helper methods.
class BaseHelpLoader(ABC):
    """Abstract interface for versioned help loaders, plus helpers shared by implementations.

    A loader keeps a dict of help file contents (``_file_content_dict``) and, while a single
    ``versioned_load`` call is running, the entry selected for the help object being loaded
    (``_entry_data``).
    """

    def __init__(self, help_ctx=None):
        """Create a loader.

        :param help_ctx: help context object; stored as-is for use by subclasses.
        """
        self.help_ctx = help_ctx
        self._entry_data = None  # entry for the help object currently being loaded, else None
        self._file_content_dict = {}  # filled through update_file_contents()

    def versioned_load(self, help_obj, parser):
        """Populate ``help_obj`` from the loaded file contents if they apply to this loader.

        Does nothing when no file contents have been registered. Otherwise the entry data is
        looked up and, when its version matches ``self.version``, the body, parameters and
        examples of ``help_obj`` are loaded in that order.

        :param help_obj: help file object to update in place.
        :param parser: parser passed through to ``load_entry_data``.
        """
        if not self._file_content_dict:
            return
        self._entry_data = None
        # Cycle through versioned_load helpers
        self.load_entry_data(help_obj, parser)
        if self._data_is_applicable():
            self.load_help_body(help_obj)
            self.load_help_parameters(help_obj)
            self.load_help_examples(help_obj)
        # Do not let this help object's entry leak into the next call.
        self._entry_data = None

    def update_file_contents(self, file_contents):
        """Merge ``file_contents`` into the stored file content dict (existing keys are overwritten)."""
        self._file_content_dict.update(file_contents)

    @abc.abstractmethod
    def get_noun_help_file_names(self, nouns):
        """Return the help file names relevant to the command or group named by ``nouns``."""

    @property
    @abc.abstractmethod
    def version(self):
        """Version value that entry data must carry for this loader to apply it."""

    def _data_is_applicable(self):
        """Return a truthy value when entry data exists and its ``version`` equals ``self.version``."""
        return self._entry_data and self.version == self._entry_data.get('version')

    @abc.abstractmethod
    def load_entry_data(self, help_obj, parser):
        """Find the data for ``help_obj`` and store it in ``self._entry_data``."""

    @abc.abstractmethod
    def load_help_body(self, help_obj):
        """Update the body of ``help_obj`` from ``self._entry_data``."""

    @abc.abstractmethod
    def load_help_parameters(self, help_obj):
        """Update the parameters of ``help_obj`` from ``self._entry_data``."""

    @abc.abstractmethod
    def load_help_examples(self, help_obj):
        """Update the examples of ``help_obj`` from ``self._entry_data``."""

    # Loader static helper methods

    @staticmethod
    def _update_obj_from_data_dict(obj, data, attr_key_tups):
        """Update a help file object from a data dict using the attribute to key mapping.

        For every ``(attr, key)`` pair, ``obj.attr`` is set to ``data[key]`` when that value is
        present and truthy. Missing or empty values leave the existing attribute untouched.

        :param obj: object whose attributes are updated in place.
        :param data: dict holding the loaded values.
        :param attr_key_tups: iterable of ``(attribute name, data key)`` tuples.
        """
        for attr, key in attr_key_tups:
            try:
                value = data[key]
                # An empty value must not overwrite the attribute. Previously the attribute
                # *name* itself was written in that case (``data[key] or attr``).
                if value:
                    setattr(obj, attr, value)
            except (AttributeError, KeyError):
                # Best effort: skip keys absent from the data and attributes that cannot be set.
                pass

    @staticmethod
    def _update_help_obj_params(help_obj, data_params, params_equal, attr_key_tups):
        """Update relevant help file object parameters from data.

        Each parameter in ``help_obj.parameters`` is updated from the first item of
        ``data_params`` for which ``params_equal(param_obj, item)`` is truthy. Parameters without
        a match are kept unchanged. ``help_obj.parameters`` is then rebound to a new list holding
        the same parameter objects in the same order.

        :param help_obj: help file object exposing a ``parameters`` iterable.
        :param data_params: iterable of dicts describing parameters.
        :param params_equal: callable ``(param_obj, param_dict)`` deciding whether they match.
        :param attr_key_tups: iterable of ``(attribute name, data key)`` tuples.
        """
        updated_params = []
        for param_obj in help_obj.parameters:
            matching_data = next((item for item in data_params if params_equal(param_obj, item)), None)
            if matching_data:
                BaseHelpLoader._update_obj_from_data_dict(param_obj, matching_data, attr_key_tups)
            updated_params.append(param_obj)
        help_obj.parameters = updated_params
95
96
class YamlLoaderMixin:  # pylint:disable=too-few-public-methods
    """Helper methods shared by loaders that read YAML help files."""

    @staticmethod
    def _get_yaml_help_files_list(nouns, cmd_loader_map_ref):
        """Return the paths of the YAML help files for the command or group named by ``nouns``.

        Help files are looked up in the directory containing the source file of each relevant
        command loader's class. A file qualifies when its name ends with ``help.yaml`` or
        ``help.yml``.

        :param nouns: list of words forming the command or group name.
        :param cmd_loader_map_ref: dict mapping a command name to a sequence whose first item is
            that command's loader.
        :return: list of help file paths; empty when no loader is found.
        """
        import inspect

        target_name = " ".join(nouns)

        # An exact hit in the map means this is a command; the loader's path is the help file's path.
        exact_loader = cmd_loader_map_ref.get(target_name, [None])[0]
        candidate_loaders = {exact_loader} if exact_loader else set()

        # Otherwise this is likely a group: its help could be defined next to any command loader
        # under it, so collect the loader of every command whose name starts with "<group> ".
        if not candidate_loaders:
            group_prefix = target_name + " "
            for command_name, loader_entry in cmd_loader_map_ref.items():
                if command_name.startswith(group_prefix):
                    candidate_loaders.add(loader_entry[0])

        help_file_paths = []
        for candidate in candidate_loaders:
            loader_dir = os.path.dirname(inspect.getfile(candidate.__class__))
            for entry_name in os.listdir(loader_dir):
                if entry_name.endswith(("help.yaml", "help.yml")):
                    help_file_paths.append(os.path.join(loader_dir, entry_name))
        return help_file_paths

    @staticmethod
    def _parse_yaml_from_string(text, help_file_path):
        """Parse ``text`` as YAML and return the resulting object.

        :param text: YAML content of the help file.
        :param help_file_path: path of the help file; only its parent directory name and file
            name are used, to build error messages.
        :raises CLIError: if ``text`` is empty or is not valid YAML.
        """
        parent_dir, file_name = os.path.split(help_file_path)
        display_path = os.path.join(os.path.basename(parent_dir), file_name)

        if text:
            try:
                return yaml.safe_load(text)
            except yaml.YAMLError as e:
                raise CLIError("Error parsing {}:\n\n{}".format(display_path, e))

        raise CLIError("No content passed for {}.".format(display_path))
145
146
class HelpLoaderV0(BaseHelpLoader):
    """Version 0 loader: overrides ``versioned_load`` and leaves every other loader hook a no-op."""

    @property
    def version(self):
        """Help version handled by this loader (0)."""
        return 0

    def versioned_load(self, help_obj, parser):
        """Load ``help_obj`` through the ``load`` method that follows ``CliHelpFile`` in its MRO.

        :param help_obj: help file object to load.
        :param parser: parser forwarded to that ``load`` method.
        """
        # super(CliHelpFile, help_obj) starts the lookup after CliHelpFile on purpose, so
        # CliHelpFile's own ``load`` is bypassed.
        super(CliHelpFile, help_obj).load(parser)  # pylint:disable=bad-super-call

    def get_noun_help_file_names(self, nouns):
        """No-op for version 0; returns None."""

    def load_entry_data(self, help_obj, parser):
        """No-op for version 0."""

    def load_help_body(self, help_obj):
        """No-op for version 0."""

    def load_help_parameters(self, help_obj):
        """No-op for version 0."""

    def load_help_examples(self, help_obj):
        """No-op for version 0."""
170
171
class HelpLoaderV1(BaseHelpLoader, YamlLoaderMixin):
    """Version 1 loader: fills help objects from parsed YAML help files.

    As read by this class, a parsed help file is a dict with a top-level ``version`` and a
    ``content`` list. Each ``content`` element is a dict whose values are entries carrying a
    ``name`` plus optional ``summary``, ``description``, ``links``, ``arguments`` and
    ``examples`` keys.
    """

    # (help object attribute, YAML entry key) pairs used to copy values onto help objects.
    core_attrs_to_keys = [("short_summary", "summary"), ("long_summary", "description")]
    body_attrs_to_keys = core_attrs_to_keys + [("links", "links")]
    param_attrs_to_keys = core_attrs_to_keys + [("value_sources", "value-sources")]

    @property
    def version(self):
        """Help version handled by this loader (1)."""
        return 1

    def get_noun_help_file_names(self, nouns):
        """Return the YAML help file paths for the command or group named by ``nouns``."""
        cmd_loader_map_ref = self.help_ctx.cli_ctx.invocation.commands_loader.cmd_to_loader_map
        return self._get_yaml_help_files_list(nouns, cmd_loader_map_ref)

    def update_file_contents(self, file_contents):
        """Parse and store the YAML text of every file that has not been stored yet.

        :param file_contents: dict mapping a help file name to its raw YAML text.
        :raises CLIError: if a new file is empty or cannot be parsed.
        """
        for file_name, text in file_contents.items():
            # Files already stored are neither re-parsed nor overwritten.
            if file_name not in self._file_content_dict:
                self._file_content_dict[file_name] = self._parse_yaml_from_string(text, file_name)

    def load_entry_data(self, help_obj, parser):
        """Set ``self._entry_data`` to the entry named ``help_obj.command``, or None if not found.

        The command nouns are taken from the parser's program string with its first word dropped.
        """
        prog = parser.prog if hasattr(parser, "prog") else parser._prog_prefix  # pylint: disable=protected-access
        command_nouns = prog.split()[1:]
        cmd_loader_map_ref = self.help_ctx.cli_ctx.invocation.commands_loader.cmd_to_loader_map

        files_list = self._get_yaml_help_files_list(command_nouns, cmd_loader_map_ref)
        data_list = [self._file_content_dict[name] for name in files_list]

        self._entry_data = self._get_entry_data(help_obj.command, data_list)

    def load_help_body(self, help_obj):
        """Reset ``long_summary`` and copy the summary, description and links onto ``help_obj``."""
        help_obj.long_summary = ""  # similar to knack...
        self._update_obj_from_data_dict(help_obj, self._entry_data, self.body_attrs_to_keys)

    def load_help_parameters(self, help_obj):
        """Update the parameters of a command help object from the entry's ``arguments`` list."""
        def params_equal(param, param_dict):
            if param_dict['name'].startswith("--"):  # for optionals, help file name must be one of the  long options
                return param_dict['name'] in param.name.split()
            # for positionals, help file must name must match param name shown when -h is run
            return param_dict['name'] == param.name

        if help_obj.type == "command" and hasattr(help_obj, "parameters") and self._entry_data.get("arguments"):
            # Reuse the shared base-class helper rather than duplicating its matching loop here.
            self._update_help_obj_params(help_obj, self._entry_data["arguments"], params_equal,
                                         self.param_attrs_to_keys)

    def load_help_examples(self, help_obj):
        """Replace the examples of a command help object with the entry's applicable examples."""
        if help_obj.type == "command" and self._entry_data.get("examples"):
            help_obj.examples = [HelpExample(**ex) for ex in self._entry_data["examples"]
                                 if help_obj._should_include_example(ex)]  # pylint: disable=protected-access

    @staticmethod
    def _get_entry_data(cmd_name, data_list):
        """Return the first entry named ``cmd_name`` found in ``data_list``, or None.

        The matching entry dict is tagged in place with the ``version`` of the file it came from.

        :param cmd_name: command or group name to look for.
        :param data_list: list of parsed help file dicts; falsy items and items without
            ``content`` are skipped.
        :raises KeyError: if the file holding the matching entry has no ``version`` key.
        """
        for data in data_list:
            if not data or not data.get("content"):
                continue
            for elem in data["content"]:
                for entry in elem.values():
                    if entry.get("name") == cmd_name:
                        entry["version"] = data['version']
                        return entry
        return None
237