1"""
Alex Martelli's solution for recursive dict update from
http://stackoverflow.com/a/3233356
4"""
5
6import copy
7import logging
8from collections.abc import Mapping
9
10import salt.utils.data
11from salt.defaults import DEFAULT_TARGET_DELIM
12from salt.exceptions import SaltInvocationError
13from salt.utils.decorators.jinja import jinja_filter
14from salt.utils.odict import OrderedDict
15
# Module-level logger (``merge`` uses it to warn about unknown strategies).
log = logging.getLogger(name=__name__)
17
18
def update(dest, upd, recursive_update=True, merge_lists=False):
    """
    Recursive version of the default dict.update

    Merges ``upd`` into ``dest`` in place and returns ``dest``.

    Recursive mode is used when ``recursive_update`` is True AND the two
    mappings share at least one key. In that mode, for each key of ``upd``:

    - if both sides hold a Mapping, they are merged recursively (nested
      merges always run in recursive mode and inherit ``merge_lists``);
    - if both sides hold a list and ``merge_lists`` is True, the new value is
      a deep copy of the ``dest`` list extended with the items of the ``upd``
      list that the ``dest`` list did not already contain;
    - otherwise the value from ``upd`` replaces the one in ``dest``.

    In every other case the keys of ``upd`` are copied over ``dest`` one at
    a time. This is a flat, non-recursive update done by item assignment
    rather than by calling ``dict.update``, which helps with non-dict
    mapping types such as FunctionWrapper.

    .. versionchanged:: 2016.11.6
        When merging lists, values already present in the ``dest`` list are
        not added again from the ``upd`` list.

    :raises TypeError: if ``dest`` or ``upd`` is not a Mapping.
    """
    if not (isinstance(dest, Mapping) and isinstance(upd, Mapping)):
        raise TypeError("Cannot update using non-dict types in dictupdate.update()")

    upd_keys = list(upd.keys())
    # With no overlapping keys there is nothing to merge recursively, so a
    # flat key-by-key copy produces the same result more cheaply.
    shares_keys = bool(set(dest.keys()).intersection(upd_keys))

    if not (recursive_update and shares_keys):
        for name in upd:
            dest[name] = upd[name]
        return dest

    for name in upd_keys:
        new_value = upd[name]
        try:
            current = dest.get(name, None)
        except AttributeError:
            current = None

        if isinstance(current, Mapping) and isinstance(new_value, Mapping):
            dest[name] = update(current, new_value, merge_lists=merge_lists)
        elif merge_lists and isinstance(current, list) and isinstance(new_value, list):
            # Copy first so the original dest list object is never mutated.
            combined = copy.deepcopy(current)
            combined.extend([item for item in new_value if item not in combined])
            dest[name] = combined
        else:
            dest[name] = new_value
    return dest
65
66
def merge_list(obj_a, obj_b):
    """
    Combine two dicts key by key into a new dict, pairing shared values.

    For every key of ``obj_a``: if the key is also in ``obj_b``, the result
    holds the two-element list ``[obj_a[key], obj_b[key]]``. Otherwise it
    holds ``obj_a[key]`` unchanged. Neither argument is modified.

    NOTE(review): keys present only in ``obj_b`` are not included in the
    result. Confirm this is intended for the "list" merge strategy.
    """
    return {
        key: [val, obj_b[key]] if key in obj_b else val
        for key, val in obj_a.items()
    }
75
76
def merge_recurse(obj_a, obj_b, merge_lists=False):
    """
    Return ``obj_b`` recursively merged into a deep copy of ``obj_a``.

    ``obj_a`` itself is left untouched. Values taken from ``obj_b`` are
    inserted by reference, not copied. See ``update`` for the merge rules
    and the meaning of ``merge_lists``.
    """
    return update(copy.deepcopy(obj_a), obj_b, merge_lists=merge_lists)
80
81
def merge_aggregate(obj_a, obj_b):
    """
    Merge ``obj_b`` into ``obj_a`` with the yamlex serializer's
    ``merge_recursive`` helper.

    ``level=1`` is passed so that at least the root level is merged.
    """
    # Imported at call time rather than at module import time, presumably to
    # avoid an import cycle with salt.serializers. Confirm before hoisting.
    from salt.serializers.yamlex import merge_recursive

    return merge_recursive(obj_a, obj_b, level=1)
86
87
def merge_overwrite(obj_a, obj_b, merge_lists=False):
    """
    Merge ``obj_b`` into a copy of ``obj_a``, letting ``obj_b`` win at the top level.

    Every top-level key present in both arguments takes ``obj_b``'s value
    wholesale instead of having the two values merged recursively. Keys
    found only in ``obj_b`` are added, and keys found only in ``obj_a`` are
    kept. The result is produced by ``merge_recurse``, so ``merge_lists``
    has the same meaning as there.

    Neither argument is modified. Earlier versions replaced the overlapping
    keys directly in the caller's ``obj_a``.

    :return: A new merged mapping.
    """
    # Shallow-copy so the caller's obj_a is not mutated. This matches the
    # "recurse" strategy, which never mutates obj_a. merge_recurse deep-copies
    # ``base`` afterwards, so a shallow copy is sufficient here.
    base = copy.copy(obj_a)
    for key in obj_b:
        if key in base:
            base[key] = obj_b[key]
    return merge_recurse(base, obj_b, merge_lists=merge_lists)
93
94
def merge(obj_a, obj_b, strategy="smart", renderer="yaml", merge_lists=False):
    """
    Merge ``obj_b`` into ``obj_a`` using the named ``strategy``.

    :param obj_a: The base mapping.
    :param obj_b: The mapping to merge on top of ``obj_a``.
    :param str strategy: One of ``smart``, ``list``, ``recurse``,
        ``aggregate``, ``overwrite`` or ``none``. ``smart`` resolves to
        ``aggregate`` when the last ``|``-separated component of
        ``renderer`` is ``yamlex`` (or ``renderer`` starts with
        ``yamlex_``), and to ``recurse`` otherwise. Any other value logs a
        warning and falls back to ``recurse``.
    :param str renderer: Renderer name. Only consulted by ``smart``.
    :param bool merge_lists: Aggregate lists instead of replacing them.
        Honoured by ``recurse``, ``overwrite``, ``none`` and the unknown-
        strategy fallback. Ignored by ``list`` and ``aggregate``.
    :return: The merged result.
    """
    if strategy == "smart":
        is_yamlex = renderer.split("|")[-1] == "yamlex" or renderer.startswith(
            "yamlex_"
        )
        strategy = "aggregate" if is_yamlex else "recurse"

    if strategy == "list":
        merged = merge_list(obj_a, obj_b)
    elif strategy == "recurse":
        merged = merge_recurse(obj_a, obj_b, merge_lists)
    elif strategy == "aggregate":
        #: level = 1 merge at least root data
        merged = merge_aggregate(obj_a, obj_b)
    elif strategy == "overwrite":
        merged = merge_overwrite(obj_a, obj_b, merge_lists)
    elif strategy == "none":
        # If we do not want to merge, there is only one pillar passed, so we can
        # safely use the default recurse; we just do not want to log an error.
        merged = merge_recurse(obj_a, obj_b, merge_lists)
    else:
        log.warning("Unknown merging strategy '%s', fallback to recurse", strategy)
        # Honour merge_lists so the fallback really behaves like "recurse".
        merged = merge_recurse(obj_a, obj_b, merge_lists)

    return merged
120
121
def ensure_dict_key(in_dict, keys, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False):
    """
    Make sure the nested path named by ``keys`` exists in ``in_dict``.

    Each path segment is looked up in turn. A segment that is missing, or
    whose current value is not a dict, is set to a fresh empty dict, which
    discards any non-dict value that was there.

    :param dict in_dict: The dict to work with.
    :param str keys: The delimited string with one or more keys.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    """
    path = keys.split(delimiter) if delimiter in keys else [keys]
    node = in_dict
    for segment in path:
        holds_dict = segment in node and isinstance(node[segment], dict)
        if not holds_dict:
            node[segment] = OrderedDict() if ordered_dict else {}
        node = node[segment]
    return in_dict
147
148
def _dict_rpartition(in_dict, keys, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False):
    """
    Split ``keys`` into a parent path and a final key, creating the parent path.

    - Ensures all but the last key in `keys` exist recursively in `in_dict`
      (missing or non-dict levels become new dicts).
    - Returns the dict at the second-to-last key, together with the last key.

    :param dict in_dict: The dict to work with.
    :param str keys: The delimited string with one or more keys.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.
    :rtype: tuple(dict, str)
    :return: (The dict at the second-to-last key, the last key)
    """
    if delimiter not in keys:
        # A single key lives directly in in_dict.
        return in_dict, keys

    parent_path, _, leaf = keys.rpartition(delimiter)
    ensure_dict_key(in_dict, parent_path, delimiter=delimiter, ordered_dict=ordered_dict)
    parent = salt.utils.data.traverse_dict(
        in_dict, parent_path, default=None, delimiter=delimiter
    )
    return parent, leaf
175
176
@jinja_filter("set_dict_key_value")
def set_dict_key_value(
    in_dict, keys, value, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False
):
    """
    Assign ``value`` at the nested location named by ``keys``.

    Intermediate levels that are missing, or that do not hold a dict, are
    replaced with new (Ordered)dicts first. Any existing value at the final
    key is overwritten.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to assign to the nested dict-key.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    """
    parent, leaf = _dict_rpartition(
        in_dict, keys, delimiter=delimiter, ordered_dict=ordered_dict
    )
    parent[leaf] = value
    return in_dict
199
200
@jinja_filter("update_dict_key_value")
def update_dict_key_value(
    in_dict, keys, value, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False
):
    """
    Ensures that in_dict contains the series of recursive keys defined in keys.
    Also updates the dict, that is at the end of `in_dict` traversed with `keys`,
    with `value`.

    A missing final key, or one whose value is ``None``, is first set to an
    empty (Ordered)dict. Intermediate levels that are missing or do not hold
    a dict are replaced with new (Ordered)dicts.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to update the nested dict-key with.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    :raises SaltInvocationError: if the existing value at `keys` has no
        ``update`` method, or if `value` cannot be used to update it.
    """
    dict_pointer, last_key = _dict_rpartition(
        in_dict, keys, delimiter=delimiter, ordered_dict=ordered_dict
    )
    if last_key not in dict_pointer or dict_pointer[last_key] is None:
        dict_pointer[last_key] = OrderedDict() if ordered_dict else {}
    target = dict_pointer[last_key]
    try:
        target.update(value)
    except AttributeError as exc:
        # The existing value (e.g. a list or str) has no .update method.
        raise SaltInvocationError(
            "The last key contains a {}, which cannot update.".format(type(target))
        ) from exc
    except (ValueError, TypeError) as exc:
        # ``value`` is not a mapping / iterable of key-value pairs.
        raise SaltInvocationError(
            "Cannot update {} with a {}.".format(type(target), type(value))
        ) from exc
    return in_dict
239
240
@jinja_filter("append_dict_key_value")
def append_dict_key_value(
    in_dict, keys, value, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False
):
    """
    Ensures that in_dict contains the series of recursive keys defined in keys.
    Also appends `value` to the list that is at the end of `in_dict` traversed
    with `keys`.

    A missing final key, or one whose value is ``None``, is first set to an
    empty list. Intermediate levels that are missing or do not hold a dict
    are replaced with new (Ordered)dicts.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to append to the nested dict-key.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    :raises SaltInvocationError: if the existing value at `keys` has no
        ``append`` method.
    """
    dict_pointer, last_key = _dict_rpartition(
        in_dict, keys, delimiter=delimiter, ordered_dict=ordered_dict
    )
    if last_key not in dict_pointer or dict_pointer[last_key] is None:
        dict_pointer[last_key] = []
    target = dict_pointer[last_key]
    try:
        target.append(value)
    except AttributeError as exc:
        # The existing value (e.g. a dict or str) has no .append method.
        raise SaltInvocationError(
            "The last key contains a {}, which cannot append.".format(type(target))
        ) from exc
    return in_dict
273
274
@jinja_filter("extend_dict_key_value")
def extend_dict_key_value(
    in_dict, keys, value, delimiter=DEFAULT_TARGET_DELIM, ordered_dict=False
):
    """
    Ensures that in_dict contains the series of recursive keys defined in keys.
    Also extends the list, that is at the end of `in_dict` traversed with `keys`,
    with `value`.

    A missing final key, or one whose value is ``None``, is first set to an
    empty list. Intermediate levels that are missing or do not hold a dict
    are replaced with new (Ordered)dicts.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The iterable to extend the nested list with.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.
    :rtype: dict
    :return: Returns the modified in-place `in_dict`.
    :raises SaltInvocationError: if the existing value at `keys` has no
        ``extend`` method, or if `value` cannot be used to extend it
        (e.g. it is not iterable).
    """
    dict_pointer, last_key = _dict_rpartition(
        in_dict, keys, delimiter=delimiter, ordered_dict=ordered_dict
    )
    if last_key not in dict_pointer or dict_pointer[last_key] is None:
        dict_pointer[last_key] = []
    target = dict_pointer[last_key]
    try:
        target.extend(value)
    except AttributeError as exc:
        # The existing value (e.g. a dict or str) has no .extend method.
        raise SaltInvocationError(
            "The last key contains a {}, which cannot extend.".format(type(target))
        ) from exc
    except TypeError as exc:
        # ``value`` is not iterable (e.g. None or an int).
        raise SaltInvocationError(
            "Cannot extend {} with a {}.".format(type(target), type(value))
        ) from exc
    return in_dict
313