1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import os
7import shutil
8import configparser
9import enum
10
11from knack.log import get_logger
12from knack.config import _ConfigFile
13
# Effective level of a persisted value: ALL means every command can share this parameter value.
ALL = 'all'
# Per-user file name template; each user gets a separate file, e.g. ".local_context_username".
LOCAL_CONTEXT_FILE = '.local_context_{}'
# Section of the CLI config that stores the per-user on/off switch for parameter persistence.
LOCAL_CONTEXT_ON_OFF_CONFIG_SECTION = 'local_context'
# Comment text handed to each persistence _ConfigFile when it is created.
LOCAL_CONTEXT_NOTICE = ('; This file is used to store persistent parameters.\n'
                        '; DO NOT modify it manually unless you know it well.\n')
logger = get_logger(__name__)
20
21
class LocalContextAction(enum.Enum):
    """What a command argument does with the local context (parameter persistence)."""

    # Store the argument's value into the local context.
    SET = 1
    # Read the argument's value from the local context.
    GET = 2
25
26
27def _get_current_system_username():
28    try:
29        import getpass
30        return getpass.getuser()
31    except Exception:  # pylint: disable=broad-except
32        pass
33    return None
34
35
class AzCLILocalContext:  # pylint: disable=too-many-instance-attributes
    """Per-user parameter persistence ("local context") tied to the working directory tree.

    Persisted values live in a file named ``LOCAL_CONTEXT_FILE.format(username)`` inside a
    sub-directory whose name equals the base name of the CLI config directory. That sub-directory
    is searched for in the current working directory and its ancestors. Whether the feature is
    enabled is stored per user in the ``LOCAL_CONTEXT_ON_OFF_CONFIG_SECTION`` section of the CLI config.
    """

    def __init__(self, cli_ctx):
        self.cli_ctx = cli_ctx
        self.config = cli_ctx.config
        # Folder holding the persistence file; it has the same base name as the CLI config dir.
        self.dir_name = os.path.basename(self.config.config_dir)
        self.username = None
        self.is_on = False
        self.current_dir = None

        # only used in get/set/effective_working_directory functions, to avoid loading files too many times.
        self._local_context_file = None

        self.initialize()

    def initialize(self):
        """Compute the username, on/off state, working directory and cached persistence file."""
        self.username = _get_current_system_username()
        # Without a username there is no per-user switch to read, so the feature is off.
        self.is_on = self.config.getboolean(LOCAL_CONTEXT_ON_OFF_CONFIG_SECTION, self.username, False) \
            if self.username else False

        try:
            self.current_dir = os.getcwd()
        except FileNotFoundError:
            # os.getcwd() raises this when the working directory no longer exists.
            if self.is_on:
                logger.warning('The working directory has been deleted or recreated. Parameter persistence is ignored.')

        if self.is_on:
            self._local_context_file = self._get_local_context_file()

    def _get_local_context_file_name(self):
        """Return the per-user persistence file name."""
        return LOCAL_CONTEXT_FILE.format(self.username)

    def _load_local_context_files(self, recursive=False):
        """Return _ConfigFile objects for persistence files found from the current directory upwards.

        Only files that are both readable and writable are considered. With ``recursive=False`` the
        search stops at the nearest match; otherwise every match up to the filesystem root is returned,
        nearest first. Returns an empty list when the username or current directory is unknown.
        """
        local_context_files = []
        if self.username and self.current_dir:
            current_dir = self.current_dir
            while current_dir:
                dir_path = os.path.join(current_dir, self.dir_name)
                file_path = os.path.join(dir_path, self._get_local_context_file_name())
                if os.path.isfile(file_path) and os.access(file_path, os.R_OK) and os.access(file_path, os.W_OK):
                    local_context_files.append(_ConfigFile(dir_path, file_path, LOCAL_CONTEXT_NOTICE))
                    if not recursive:
                        break   # load only one local context
                # Stop if already in root drive (dirname of a root is the root itself)
                if current_dir == os.path.dirname(current_dir):
                    break
                current_dir = os.path.dirname(current_dir)
        return local_context_files

    def _get_local_context_file(self):
        """Return the nearest persistence file, or None if there is none."""
        local_context_files = self._load_local_context_files(recursive=False)
        if len(local_context_files) == 1:
            return local_context_files[0]
        return None

    def effective_working_directory(self):
        """Return the directory that owns the loaded persistence file, or '' if none is loaded."""
        return os.path.dirname(self._local_context_file.config_dir) if self._local_context_file else ''

    def get(self, command, argument):
        """Return the persisted value of ``argument`` for ``command``, or None.

        Lookup starts at the full (lower-cased) command name, then falls back to each shorter
        command-group prefix, and finally to the ``ALL`` section. Returns None when the feature is
        off, no persistence file is loaded, or no section holds the argument.
        """
        if self.is_on and self._local_context_file:
            command_parts = command.split()
            while True:
                section = ' '.join(command_parts) if command_parts else ALL
                try:
                    return self._local_context_file.get(section.lower(), argument)
                except (configparser.NoSectionError, configparser.NoOptionError):
                    pass
                if not command_parts:
                    break
                command_parts = command_parts[:-1]
        return None

    def set(self, scopes, argument, value):
        """Persist ``value`` for ``argument`` under every scope in ``scopes`` (lower-cased).

        If no persistence file is loaded yet, a _ConfigFile rooted at the current directory is used.
        Does nothing when the feature is off or the username/current directory is unknown.
        """
        if self.is_on and self.username and self.current_dir:
            if self._local_context_file is None:
                file_path = os.path.join(self.current_dir, self.dir_name, self._get_local_context_file_name())
                dir_path = os.path.join(self.current_dir, self.dir_name)
                self._local_context_file = _ConfigFile(dir_path, file_path, LOCAL_CONTEXT_NOTICE)

            for scope in scopes:
                self._local_context_file.set_value(scope.lower(), argument, value)

    def turn_on(self):
        """Enable persistence for the current user and load the nearest persistence file."""
        # NOTE(review): assumes self.username is not None; behavior with a None option is untested here.
        self.config.set_value(LOCAL_CONTEXT_ON_OFF_CONFIG_SECTION, self.username, 'on')
        self.is_on = self.config.getboolean(LOCAL_CONTEXT_ON_OFF_CONFIG_SECTION, self.username, False)
        self._local_context_file = self._get_local_context_file()

    def turn_off(self):
        """Disable persistence for the current user and drop the cached persistence file."""
        self.config.remove_option(LOCAL_CONTEXT_ON_OFF_CONFIG_SECTION, self.username)
        self.is_on = self.config.getboolean(LOCAL_CONTEXT_ON_OFF_CONFIG_SECTION, self.username, False)
        self._local_context_file = None

    def delete_file(self, recursive=False):
        """Delete the nearest persistence file (or all of them up the tree when ``recursive``).

        The containing folder is removed too once it is empty. Failures are logged, not raised.
        """
        local_context_files = self._load_local_context_files(recursive=recursive)
        for local_context_file in local_context_files:
            try:
                os.remove(local_context_file.config_path)
                parent_dir = os.path.dirname(local_context_file.config_path)
                if not os.listdir(parent_dir):
                    shutil.rmtree(parent_dir)
                logger.warning('Parameter persistence file in working directory %s is deleted.',
                               os.path.dirname(local_context_file.config_dir))
            except Exception:  # pylint: disable=broad-except
                logger.warning('Fail to delete parameter persistence file in working directory %s',
                               os.path.dirname(local_context_file.config_dir))

    def clear(self, recursive=False):
        """Clear the contents of the nearest persistence file (or all of them when ``recursive``)."""
        local_context_files = self._load_local_context_files(recursive=recursive)
        for local_context_file in local_context_files:
            local_context_file.clear()
            logger.warning('Parameter persistence information in working directory %s is cleared.',
                           os.path.dirname(local_context_file.config_dir))

    def delete(self, names=None):
        """Remove the given option ``names`` from every section of the nearest persistence file.

        :param names: Option names to remove. A single string is treated as one name; None (the
            default) or an empty collection removes nothing instead of raising TypeError.
        """
        if isinstance(names, str):
            names = [names]
        local_context_file = self._get_local_context_file()
        if local_context_file and names:
            for scope in local_context_file.sections():
                for name in names:
                    local_context_file.remove_option(scope, name)
        logger.warning('Parameter persistence value is deleted. You can run `az config param-persist show` to show all '
                       'available values.')

    def get_value(self, names=None):
        """Return persisted values from the nearest persistence file as ``{scope: {name: value}}``.

        :param names: Option names to report. None returns every option in every section; a single
            string is treated as one name. Names missing from a section are skipped individually, so
            the other requested names in that section are still reported.
        :return: Nested dict; scopes without any matching option are omitted. Empty if there is no file.
        """
        result = {}

        local_context_file = self._get_local_context_file()
        if not local_context_file:
            return result

        if isinstance(names, str):
            names = [names]

        for scope in local_context_file.sections():
            if names is None:
                try:
                    items = local_context_file.items(scope)
                except configparser.NoSectionError:
                    continue
                for name, value in items:
                    result.setdefault(scope, {})[name] = value
            else:
                for name in names:
                    try:
                        value = local_context_file.get(scope, name)
                    except (configparser.NoSectionError, configparser.NoOptionError):
                        # Skip only this name; keep looking up the remaining ones in this scope.
                        continue
                    result.setdefault(scope, {})[name] = value
        return result
181
182
class LocalContextAttribute:
    # pylint: disable=too-few-public-methods
    def __init__(self, name, actions, scopes=None):
        """ Local Context Attribute arguments

        :param name: Argument name in local context. Make sure it is consistent for SET and GET.
        :type name: str
        :param actions: Which action(s) should be taken for local context. Allowed values:
            LocalContextAction.SET, LocalContextAction.GET. A single action (or a string) is wrapped
            into a one-element list; None is treated as no actions.
        :type actions: list
        :param scopes: The effective commands or command groups of this argument when saved to local context.
            A single string is wrapped into a list. Defaults to [ALL] when SET is among the actions.
        :type scopes: list
        """
        self.name = name

        if actions is None:
            actions = []
        elif isinstance(actions, (str, LocalContextAction)):
            # A bare enum member is not iterable, so the membership test below would raise TypeError.
            actions = [actions]
        self.actions = actions

        if isinstance(scopes, str):
            scopes = [scopes]
        if scopes is None and LocalContextAction.SET in actions:
            scopes = [ALL]
        self.scopes = scopes
206