1"""
2    :codeauthor: Pedro Algarvio (pedro@algarvio.me)
3    :codeauthor: Thomas Jackson (jacksontj.89@gmail.com)
4
5
6    salt.utils.context
7    ~~~~~~~~~~~~~~~~~~
8
9    Context managers used throughout Salt's source code.
10"""
11
12import copy
13import threading
14from collections.abc import MutableMapping
15from contextlib import contextmanager
16
17
@contextmanager
def func_globals_inject(func, **overrides):
    """
    Override specific variables within a function's global context.

    While the context is active, every keyword in ``overrides`` is written
    into the globals dictionary of ``func``. On exit, names that already
    existed get their previous value back and names that were newly
    introduced are removed again.

    :param func: A function, or a method wrapping one (the underlying
        function is taken from its ``__func__`` attribute).
    :param overrides: Mapping of global name to the temporary value.
    """
    # Methods expose the wrapped function as ``__func__``; plain functions
    # have no such attribute and are used as they are.
    func = getattr(func, "__func__", func)

    # Get a reference to the function globals dictionary
    func_globals = func.__globals__

    # Remember which names are replaced (with their current value) and which
    # names did not exist before and are therefore purely injected.
    overridden_func_globals = {
        name: func_globals[name] for name in overrides if name in func_globals
    }
    injected_func_globals = [
        name for name in overrides if name not in func_globals
    ]

    # Override the function globals with what's passed in the above overrides
    func_globals.update(overrides)

    # The context is now ready to be used
    try:
        yield
    finally:
        # Restore the overwritten function globals
        func_globals.update(overridden_func_globals)

        # Remove any entry injected in the function globals. ``pop`` is used
        # so cleanup does not fail when the code run inside the context has
        # already deleted the injected name itself.
        for name in injected_func_globals:
            func_globals.pop(name, None)
54
55
class ContextDict(MutableMapping):
    """
    A mapping whose contents can be overridden on a per-thread basis.
    Intended for use with Tornado's StackContext.

    Provide arbitrary data as kwargs upon creation,
    then allow any children to override the values of the parent.

    All mapping operations act on the thread-local override (``_state.data``)
    when one is set for the current thread, and on ``global_data`` otherwise.
    """

    def __init__(self, threadsafe=False, **data):
        """
        :param threadsafe: Stored and handed to the children created by
            :meth:`clone`.
        :param data: Initial content of ``global_data``.
        """
        # state should be thread local, so this object can be threadsafe
        self._state = threading.local()
        # variable for the overridden data
        self._state.data = None
        # Seed the global data with the values supplied at creation time
        self.global_data = dict(data)
        # Threadsafety indicates whether or not we should protect data stored
        # in child context dicts from being leaked
        self._threadsafe = threadsafe

    @property
    def active(self):
        """Determine if this ContextDict is currently overridden
        Since the ContextDict can be overridden in each thread, we check whether
        the _state.data is set or not.
        """
        try:
            return self._state.data is not None
        except AttributeError:
            # ``data`` was only assigned in the thread that ran ``__init__``;
            # other threads see no attribute until an override is set.
            return False

    def _current(self):
        """
        Return the dictionary that mapping operations should act on: the
        thread-local override when active, the global data otherwise.
        """
        return self._state.data if self.active else self.global_data

    # TODO: rename?
    def clone(self, **kwargs):
        """
        Clone this context, and return the ChildContextDict
        """
        return ChildContextDict(
            parent=self, threadsafe=self._threadsafe, overrides=kwargs
        )

    def __setitem__(self, key, val):
        self._current()[key] = val

    def __delitem__(self, key):
        del self._current()[key]

    def __getitem__(self, key):
        return self._current()[key]

    def __len__(self):
        return len(self._current())

    def __iter__(self):
        return iter(self._current())

    def __copy__(self):
        # The copy starts out inactive; whatever is currently visible
        # becomes its global data.
        new_obj = type(self)(threadsafe=self._threadsafe)
        new_obj.global_data = copy.copy(self._current())
        return new_obj

    def __deepcopy__(self, memo):
        # Same as ``__copy__`` but the visible data is copied recursively.
        new_obj = type(self)(threadsafe=self._threadsafe)
        new_obj.global_data = copy.deepcopy(self._current(), memo)
        return new_obj
141
142
class ChildContextDict(MutableMapping):
    """
    An overrideable child of ContextDict

    The child holds its own dictionary, made of the supplied overrides plus
    every entry of the parent's ``global_data`` that is not overridden.
    Used as a context manager, it installs that dictionary as the parent's
    thread-local ``_state.data`` and puts the previous value back on exit.
    """

    def __init__(self, parent, overrides=None, threadsafe=False):
        """
        :param parent: Object providing ``global_data`` and ``_state``.
        :param overrides: Dictionary of values taking precedence over the
            parent's global data. It is used directly (not copied).
        :param threadsafe: When true, values taken from the parent's global
            data are deep-copied instead of shared.
        """
        self.parent = parent
        self._data = {} if overrides is None else overrides
        self._old_data = None

        # merge the parent's global_data into self._data
        for k, v in self.parent.global_data.items():
            if k in self._data:
                continue
            # A deepcopy is necessary in threadsafe mode to avoid using the
            # same objects in globals as we do in thread local storage.
            # Otherwise, changing one would automatically affect the other.
            self._data[k] = copy.deepcopy(v) if threadsafe else v

    def __setitem__(self, key, val):
        self._data[key] = val

    def __delitem__(self, key):
        del self._data[key]

    def __getitem__(self, key):
        return self._data[key]

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __enter__(self):
        # Save old data to support nested calls. The attribute may be missing
        # in the current thread, in which case there is nothing to restore
        # and any value saved by an earlier entry must not be reused.
        self._old_data = getattr(self.parent._state, "data", None)
        self.parent._state.data = self._data
        # Returned so that ``with parent.clone() as child:`` binds the child
        return self

    def __exit__(self, *exc):
        # Put back whatever was installed before ``__enter__`` ran
        self.parent._state.data = self._old_data
188
189
class NamespacedDictWrapper(MutableMapping, dict):
    """
    Mapping view onto a dictionary nested inside another dict.

    Each operation walks ``pre_keys`` from the wrapped dict down to the
    nested mapping and then reads from or writes to that mapping.

    MUST inherit from dict to serialize through msgpack correctly
    """

    def __init__(self, d, pre_keys):  # pylint: disable=W0231
        self.__dict = d
        # A single string is treated as a one-element key path
        self.pre_keys = (pre_keys,) if isinstance(pre_keys, str) else pre_keys
        # Initialise the dict base from the nested mapping
        super().__init__(self._dict())

    def _dict(self):
        """
        Follow ``pre_keys`` through the wrapped dict and return the
        mapping found at the end of the path.
        """
        node = self.__dict
        for step in self.pre_keys:
            node = node[step]
        return node

    def __repr__(self):
        target = self._dict()
        return repr(target)

    def __setitem__(self, key, val):
        target = self._dict()
        target[key] = val

    def __delitem__(self, key):
        target = self._dict()
        del target[key]

    def __getitem__(self, key):
        target = self._dict()
        return target[key]

    def __len__(self):
        target = self._dict()
        return len(target)

    def __iter__(self):
        target = self._dict()
        return iter(target)

    def __copy__(self):
        # Shallow copies of both the wrapped dict and the key path
        wrapped = copy.copy(self.__dict)
        path = copy.copy(self.pre_keys)
        return type(self)(wrapped, path)

    def __deepcopy__(self, memo):
        # Recursive copies of both the wrapped dict and the key path
        wrapped = copy.deepcopy(self.__dict, memo)
        path = copy.deepcopy(self.pre_keys, memo)
        return type(self)(wrapped, path)

    def __str__(self):
        target = self._dict()
        return target.__str__()
239