"""
Alex Martelli's solution for recursive dict update from
http://stackoverflow.com/a/3233356
"""

import copy
import logging
from collections.abc import Mapping

import salt.utils.data
from salt.defaults import DEFAULT_TARGET_DELIM
from salt.exceptions import SaltInvocationError
from salt.utils.decorators.jinja import jinja_filter
from salt.utils.odict import OrderedDict

log = logging.getLogger(__name__)


def update(dest, upd, recursive_update=True, merge_lists=False):
    """
    Recursive version of the default dict.update

    Merges upd recursively into dest. ``dest`` is modified in place and is
    also returned.

    If recursive_update=False (or if ``dest`` and ``upd`` share no keys), a
    flat key-by-key assignment is performed instead. This is equivalent to
    dict.update but also works for non-dict Mapping types like
    FunctionWrapper.

    If merge_lists=True, will aggregate list object types instead of replace.
    The list in ``upd`` is added to the list in ``dest``, so the resulting list
    is ``dest[key] + upd[key]``. This behavior is only activated when
    recursive_update=True. By default merge_lists=False.

    .. versionchanged:: 2016.11.6
        When merging lists, duplicate values are removed. Values already
        present in the ``dest`` list are not added from the ``upd`` list.

    :raises TypeError: if ``dest`` or ``upd`` is not a Mapping.
    """
    if (not isinstance(dest, Mapping)) or (not isinstance(upd, Mapping)):
        raise TypeError("Cannot update using non-dict types in dictupdate.update()")
    updkeys = list(upd.keys())
    # With no overlapping keys there is nothing to recurse into, so the flat
    # key-by-key copy below produces the same result.
    if not set(dest.keys()) & set(updkeys):
        recursive_update = False
    if recursive_update:
        for key in updkeys:
            val = upd[key]
            try:
                dest_subkey = dest.get(key, None)
            except AttributeError:
                # NOTE(review): presumably guards Mapping types registered
                # without the ``.get()`` mixin method.
                dest_subkey = None
            if isinstance(dest_subkey, Mapping) and isinstance(val, Mapping):
                ret = update(dest_subkey, val, merge_lists=merge_lists)
                dest[key] = ret
            elif isinstance(dest_subkey, list) and isinstance(val, list):
                if merge_lists:
                    # Deep-copy so the list object previously held by dest is
                    # not mutated. Only values absent from the *dest* list are
                    # appended. The filter is evaluated before extending, so
                    # duplicates within upd's own list are kept.
                    merged = copy.deepcopy(dest_subkey)
                    merged.extend([x for x in val if x not in merged])
                    dest[key] = merged
                else:
                    dest[key] = upd[key]
            else:
                dest[key] = upd[key]
        return dest
    for k in upd:
        dest[k] = upd[k]
    return dest


def merge_list(obj_a, obj_b):
    """
    Build a new dict from the keys of ``obj_a``.

    For keys present in both inputs the value becomes the two-item list
    ``[obj_a[key], obj_b[key]]``; other keys of ``obj_a`` keep their value.
    Keys that exist only in ``obj_b`` are not included in the result.
    Neither input is modified.
    """
    ret = {}
    for key, val in obj_a.items():
        if key in obj_b:
            ret[key] = [val, obj_b[key]]
        else:
            ret[key] = val
    return ret


def merge_recurse(obj_a, obj_b, merge_lists=False):
    """
    Recursively merge ``obj_b`` into a deep copy of ``obj_a`` and return it.

    ``obj_a`` itself is left untouched. See :func:`update` for the merge
    rules and the meaning of ``merge_lists``.
    """
    copied = copy.deepcopy(obj_a)
    return update(copied, obj_b, merge_lists=merge_lists)


def merge_aggregate(obj_a, obj_b):
    """
    Merge using the yamlex serializer's ``merge_recursive`` with ``level=1``.
    """
    # NOTE(review): imported lazily, presumably to avoid a circular import
    # at module load time -- confirm before moving to the top of the file.
    from salt.serializers.yamlex import merge_recursive as _yamlex_merge_recursive

    return _yamlex_merge_recursive(obj_a, obj_b, level=1)


def merge_overwrite(obj_a, obj_b, merge_lists=False):
    """
    Merge where top-level keys of ``obj_b`` fully replace those of ``obj_a``.

    Top-level keys present in both inputs take ``obj_b``'s value wholesale
    (no recursive merge of that value). The result is then passed through
    :func:`merge_recurse`. ``obj_a`` is no longer modified in place.
    """
    # Overwrite on a shallow copy rather than on the caller's obj_a.
    # merge_recurse deep-copies its first argument, so the returned data is
    # the same as when obj_a itself was overwritten.
    overwritten = copy.copy(obj_a)
    for key in obj_b:
        if key in overwritten:
            overwritten[key] = obj_b[key]
    return merge_recurse(overwritten, obj_b, merge_lists=merge_lists)


def merge(obj_a, obj_b, strategy="smart", renderer="yaml", merge_lists=False):
    """
    Merge ``obj_b`` into ``obj_a`` using the named strategy.

    :param obj_a: The base data.
    :param obj_b: The data to merge on top of ``obj_a``.
    :param str strategy: One of ``smart``, ``list``, ``recurse``,
        ``aggregate``, ``overwrite`` or ``none``. ``smart`` resolves to
        ``aggregate`` when the last entry of the ``|``-separated ``renderer``
        is ``yamlex`` (or ``renderer`` starts with ``yamlex_``), and to
        ``recurse`` otherwise. Unknown strategies log a warning and fall back
        to ``recurse``.
    :param str renderer: Renderer pipeline; only consulted for ``smart``.
    :param bool merge_lists: Passed to the ``recurse`` and ``overwrite``
        strategies.
    :return: The merged data.
    """
    if strategy == "smart":
        if renderer.split("|")[-1] == "yamlex" or renderer.startswith("yamlex_"):
            strategy = "aggregate"
        else:
            strategy = "recurse"

    if strategy == "list":
        merged = merge_list(obj_a, obj_b)
    elif strategy == "recurse":
        merged = merge_recurse(obj_a, obj_b, merge_lists)
    elif strategy == "aggregate":
        #: level = 1 merge at least root data
        merged = merge_aggregate(obj_a, obj_b)
    elif strategy == "overwrite":
        merged = merge_overwrite(obj_a, obj_b, merge_lists)
    elif strategy == "none":
        # If we do not want to merge, there is only one pillar passed, so we can safely use the default recurse,
        # we just do not want to log an error
        merged = merge_recurse(obj_a, obj_b)
    else:
        # NOTE(review): this fallback does not forward merge_lists, unlike the
        # explicit "recurse" strategy -- confirm this is intended.
        log.warning("Unknown merging strategy '%s', fallback to recurse", strategy)
        merged = merge_recurse(obj_a, obj_b)

    return merged


def ensure_dict_key(in_dict, keys, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False):
    """
    Ensures that in_dict contains the series of recursive keys defined in keys.

    Any intermediate key whose current value is not a dict is replaced by a
    new, empty dict (its previous value is discarded).

    :param dict in_dict: The dict to work with.
    :param str keys: The delimited string with one or more keys.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
        Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    """
    if delimiter in keys:
        a_keys = keys.split(delimiter)
    else:
        a_keys = [keys]
    dict_pointer = in_dict
    for current_key in a_keys:
        if current_key not in dict_pointer or not isinstance(
            dict_pointer[current_key], dict
        ):
            dict_pointer[current_key] = OrderedDict() if ordered_dict else {}
        dict_pointer = dict_pointer[current_key]
    return in_dict


def _dict_rpartition(in_dict, keys, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False):
    """
    Helper function to:
        - Ensure all but the last key in `keys` exist recursively in `in_dict`.
        - Return the dict at the one-to-last key, and the last key

    :param dict in_dict: The dict to work with.
    :param str keys: The delimited string with one or more keys.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
        Default: create regular dicts.
    :rtype: tuple(dict, str)
    :return: (The dict at the one-to-last key, the last key)
    """
    if delimiter in keys:
        all_but_last_keys, _, last_key = keys.rpartition(delimiter)
        # Create the parent path first so the traversal below finds it.
        ensure_dict_key(
            in_dict, all_but_last_keys, delimiter=delimiter, ordered_dict=ordered_dict
        )
        dict_pointer = salt.utils.data.traverse_dict(
            in_dict, all_but_last_keys, default=None, delimiter=delimiter
        )
    else:
        dict_pointer = in_dict
        last_key = keys
    return dict_pointer, last_key


@jinja_filter("set_dict_key_value")
def set_dict_key_value(
    in_dict, keys, value, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False
):
    """
    Ensures that in_dict contains the series of recursive keys defined in keys.
    Also sets whatever is at the end of `in_dict` traversed with `keys` to `value`.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to assign to the nested dict-key.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
        Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    """
    dict_pointer, last_key = _dict_rpartition(
        in_dict, keys, delimiter=delimiter, ordered_dict=ordered_dict
    )
    dict_pointer[last_key] = value
    return in_dict


@jinja_filter("update_dict_key_value")
def update_dict_key_value(
    in_dict, keys, value, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False
):
    """
    Ensures that in_dict contains the series of recursive keys defined in keys.
    Also updates the dict, that is at the end of `in_dict` traversed with `keys`,
    with `value`.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to update the nested dict-key with.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
        Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    :raises SaltInvocationError: if the value at the last key has no
        ``update`` method, or cannot be updated with `value`.
    """
    dict_pointer, last_key = _dict_rpartition(
        in_dict, keys, delimiter=delimiter, ordered_dict=ordered_dict
    )
    # A missing or None target is initialised to an empty dict.
    if last_key not in dict_pointer or dict_pointer[last_key] is None:
        dict_pointer[last_key] = OrderedDict() if ordered_dict else {}
    try:
        dict_pointer[last_key].update(value)
    except AttributeError as exc:
        raise SaltInvocationError(
            "The last key contains a {}, which cannot update.".format(
                type(dict_pointer[last_key])
            )
        ) from exc
    except (ValueError, TypeError) as exc:
        raise SaltInvocationError(
            "Cannot update {} with a {}.".format(
                type(dict_pointer[last_key]), type(value)
            )
        ) from exc
    return in_dict


@jinja_filter("append_dict_key_value")
def append_dict_key_value(
    in_dict, keys, value, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False
):
    """
    Ensures that in_dict contains the series of recursive keys defined in keys.
    Also appends `value` to the list that is at the end of `in_dict` traversed
    with `keys`.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to append to the nested dict-key.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
        Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    :raises SaltInvocationError: if the value at the last key has no
        ``append`` method.
    """
    dict_pointer, last_key = _dict_rpartition(
        in_dict, keys, delimiter=delimiter, ordered_dict=ordered_dict
    )
    # A missing or None target is initialised to an empty list.
    if last_key not in dict_pointer or dict_pointer[last_key] is None:
        dict_pointer[last_key] = []
    try:
        dict_pointer[last_key].append(value)
    except AttributeError as exc:
        raise SaltInvocationError(
            "The last key contains a {}, which cannot append.".format(
                type(dict_pointer[last_key])
            )
        ) from exc
    return in_dict


@jinja_filter("extend_dict_key_value")
def extend_dict_key_value(
    in_dict, keys, value, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False
):
    """
    Ensures that in_dict contains the series of recursive keys defined in keys.
    Also extends the list, that is at the end of `in_dict` traversed with `keys`,
    with `value`.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to extend the nested dict-key with.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
        Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    :raises SaltInvocationError: if the value at the last key has no
        ``extend`` method, or `value` is not iterable.
    """
    dict_pointer, last_key = _dict_rpartition(
        in_dict, keys, delimiter=delimiter, ordered_dict=ordered_dict
    )
    # A missing or None target is initialised to an empty list.
    if last_key not in dict_pointer or dict_pointer[last_key] is None:
        dict_pointer[last_key] = []
    try:
        dict_pointer[last_key].extend(value)
    except AttributeError as exc:
        raise SaltInvocationError(
            "The last key contains a {}, which cannot extend.".format(
                type(dict_pointer[last_key])
            )
        ) from exc
    except TypeError as exc:
        raise SaltInvocationError(
            "Cannot extend {} with a {}.".format(
                type(dict_pointer[last_key]), type(value)
            )
        ) from exc
    return in_dict