"""
    :codeauthor: Pedro Algarvio (pedro@algarvio.me)
    :codeauthor: Thomas Jackson (jacksontj.89@gmail.com)


    salt.utils.context
    ~~~~~~~~~~~~~~~~~~

    Context managers used throughout Salt's source code.
"""

import copy
import threading
from collections.abc import MutableMapping
from contextlib import contextmanager


@contextmanager
def func_globals_inject(func, **overrides):
    """
    Override specific variables within a function's global context.

    For the duration of the ``with`` block every ``name=value`` pair in
    ``overrides`` is written into ``func.__globals__``. On exit (normal or
    via an exception) names that already existed are restored to their
    previous values and names that did not exist are removed again.

    :param func: A function or bound method whose module globals should be
        patched. Bound methods are unwrapped through ``__func__``.
    :param overrides: Global names and the values to inject.
    """
    # Recognize methods. Python 3 exposes the wrapped function as
    # ``__func__``; the Python 2 ``im_func`` attribute no longer exists.
    func = getattr(func, "__func__", func)

    # Get a reference to the function globals dictionary
    func_globals = func.__globals__
    # Save the current values of the names we are about to overwrite, and
    # remember which names are brand new so they can be removed afterwards.
    injected_func_globals = []
    overridden_func_globals = {}
    for override in overrides:
        if override in func_globals:
            overridden_func_globals[override] = func_globals[override]
        else:
            injected_func_globals.append(override)

    # Override the function globals with what's passed in the above overrides
    func_globals.update(overrides)

    # The context is now ready to be used
    try:
        yield
    finally:
        # Restore the overwritten function globals
        func_globals.update(overridden_func_globals)

        # Remove any entry injected in the function globals. ``pop`` tolerates
        # the body having already deleted the name, so cleanup cannot raise a
        # KeyError that would mask an exception raised inside the block.
        for injected in injected_func_globals:
            func_globals.pop(injected, None)


class ContextDict(MutableMapping):
    """
    A mapping whose contents can be overridden per thread.

    Outside of any override, reads and writes go to ``global_data``.
    :meth:`clone` returns a :class:`ChildContextDict`; while that child is
    entered as a context manager, reads and writes on this object in the
    current thread go to the child's data instead. Intended for use with
    Tornado's StackContext.

    Provide arbitrary data as kwargs upon creation to seed ``global_data``,
    then allow any children to override the values of the parent.
    """

    def __init__(self, threadsafe=False, **data):
        # state should be thread local, so this object can be threadsafe
        self._state = threading.local()
        # variable for the overridden data; only set for the creating thread,
        # other threads have no ``data`` attribute until a child is entered
        self._state.data = None
        # Seed the shared data with the kwargs given at creation time.
        self.global_data = dict(data)
        # Threadsafety indicates whether or not we should protect data stored
        # in child context dicts from being leaked
        self._threadsafe = threadsafe

    @property
    def active(self):
        """Determine if this ContextDict is currently overridden
        Since the ContextDict can be overridden in each thread, we check whether
        the _state.data is set or not.
        """
        try:
            return self._state.data is not None
        except AttributeError:
            # This thread never had ``data`` assigned on the thread-local.
            return False

    def _current_data(self):
        """Return the dict backing this mapping for the calling thread."""
        if self.active:
            return self._state.data
        return self.global_data

    # TODO: rename?
    def clone(self, **kwargs):
        """
        Clone this context, and return the ChildContextDict

        :param kwargs: Values that override the parent's ``global_data`` in
            the child.
        """
        child = ChildContextDict(
            parent=self, threadsafe=self._threadsafe, overrides=kwargs
        )
        return child

    def __setitem__(self, key, val):
        self._current_data()[key] = val

    def __delitem__(self, key):
        del self._current_data()[key]

    def __getitem__(self, key):
        return self._current_data()[key]

    def __len__(self):
        return len(self._current_data())

    def __iter__(self):
        return iter(self._current_data())

    def __copy__(self):
        # The copy's global data is a shallow snapshot of whatever data is
        # visible from the calling thread right now.
        new_obj = type(self)(threadsafe=self._threadsafe)
        new_obj.global_data = copy.copy(self._current_data())
        return new_obj

    def __deepcopy__(self, memo):
        new_obj = type(self)(threadsafe=self._threadsafe)
        new_obj.global_data = copy.deepcopy(self._current_data(), memo)
        return new_obj


class ChildContextDict(MutableMapping):
    """
    An overrideable child of ContextDict.

    Holds its own data: the ``overrides`` merged over the parent's
    ``global_data``. Entering the child as a context manager makes that data
    the parent's active data for the current thread; exiting restores
    whatever was active before.
    """

    def __init__(self, parent, overrides=None, threadsafe=False):
        """
        :param parent: The owning :class:`ContextDict`.
        :param overrides: Mapping of values that take precedence over the
            parent's ``global_data``. It is copied, never mutated.
        :param threadsafe: When true, values inherited from the parent are
            deep-copied instead of shared.
        """
        self.parent = parent
        # Copy so that merging the parent's globals below does not mutate a
        # dict owned by the caller.
        self._data = {} if overrides is None else dict(overrides)
        self._old_data = None

        # merge the parent's global_data into self._data, keeping overrides
        for k, v in self.parent.global_data.items():
            if k not in self._data:
                # In threadsafe mode a deepcopy is necessary to avoid using the
                # same objects in globals as we do in thread local storage.
                # Otherwise, changing one would automatically affect the other.
                self._data[k] = copy.deepcopy(v) if threadsafe else v

    def __setitem__(self, key, val):
        self._data[key] = val

    def __delitem__(self, key):
        del self._data[key]

    def __getitem__(self, key):
        return self._data[key]

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __enter__(self):
        # Save whatever data is active in this thread to support nested calls.
        # Always re-read it so a value left over from an earlier entry is
        # never restored on exit.
        self._old_data = getattr(self.parent._state, "data", None)
        self.parent._state.data = self._data
        return self

    def __exit__(self, *exc):
        self.parent._state.data = self._old_data


class NamespacedDictWrapper(MutableMapping, dict):
    """
    Create a dict which wraps another dict with a specific prefix of key(s)

    MUST inherit from dict to serialize through msgpack correctly
    """

    def __init__(self, d, pre_keys):  # pylint: disable=W0231
        """
        :param d: The root dict.
        :param pre_keys: A single key (str) or a sequence of keys walked from
            ``d`` to reach the wrapped sub-dict.
        """
        self.__dict = d
        if isinstance(pre_keys, str):
            self.pre_keys = (pre_keys,)
        else:
            self.pre_keys = pre_keys
        # Populate the builtin dict storage with the wrapped contents,
        # presumably for serializers that read it directly (see class note).
        # This is a snapshot taken now; later writes go to the wrapped dict.
        super().__init__(self._dict())

    def _dict(self):
        """Walk ``pre_keys`` from the root dict and return the wrapped dict."""
        r = self.__dict
        for k in self.pre_keys:
            r = r[k]
        return r

    def __repr__(self):
        return repr(self._dict())

    def __setitem__(self, key, val):
        self._dict()[key] = val

    def __delitem__(self, key):
        del self._dict()[key]

    def __getitem__(self, key):
        return self._dict()[key]

    def __len__(self):
        return len(self._dict())

    def __iter__(self):
        return iter(self._dict())

    def __copy__(self):
        return type(self)(copy.copy(self.__dict), copy.copy(self.pre_keys))

    def __deepcopy__(self, memo):
        return type(self)(
            copy.deepcopy(self.__dict, memo), copy.deepcopy(self.pre_keys, memo)
        )

    def __str__(self):
        return self._dict().__str__()