1# coding: utf-8
2
3# Copyright (c) Jupyter Development Team.
4# Distributed under the terms of the Modified BSD License.
5
6import copy
7import nbformat
8from nbformat import NotebookNode
9
10from .diff_format import DiffOp, NBDiffFormatError
11from .diff_utils import flatten_list_of_string_diff
12
13
14
# Public API of this module: only these names are exported by
# `from ... import *`; the other patch_* functions are not listed here.
__all__ = ["patch", "patch_notebook"]
16
17
def patch_list(obj, diff):
    """Apply a list diff to ``obj`` and return the patched list.

    ``diff`` is a sequence of diff entries, each with an ``op`` and an
    integer ``key`` indexing into ``obj``. Entries must be ordered by
    non-decreasing key; an addition may share its key with a removal,
    e.g. ``[op_remove(index), op_add(index, value)]``.

    Items of ``obj`` not consumed by the diff are deep-copied into the
    result, and so are values inserted by ADD, ADDRANGE and REPLACE
    entries, so the result shares no mutable state with ``obj`` or
    ``diff``.

    Raises NBDiffFormatError for an unknown op, a key outside
    ``0..len(obj)``, keys out of order, or an entry that removes, replaces
    or patches items past the end of ``obj`` or already consumed by an
    earlier entry. A non-integer key fails an assertion (when assertions
    are enabled).
    """
    # The patched sequence to build and return
    newobj = []
    # Index into obj, the next item to take unless diff says otherwise
    take = 0
    # Key of the previous entry; keys must never decrease, otherwise the
    # single forward pass below would emit items in the wrong order.
    prev_index = 0
    for e in diff:
        op = e.op
        index = e.key
        assert isinstance(index, int), 'list key must be integer'
        if not 0 <= index <= len(obj):
            raise NBDiffFormatError(
                "List diff key {} is out of range for a list of length {}.".format(
                    index, len(obj)))
        if index < prev_index:
            raise NBDiffFormatError(
                "List diff keys must be sorted, got {} after {}.".format(
                    index, prev_index))
        prev_index = index

        # Number of items of obj this entry consumes, starting at index.
        # Note that the operations ADD, REMOVE, REPLACE are not produced by
        # the diff algorithm anymore, keeping these cases just in case we
        # want them back.
        if op in (DiffOp.ADDRANGE, DiffOp.ADD):
            skip = 0
        elif op == DiffOp.REMOVERANGE:
            skip = e.length
        elif op in (DiffOp.PATCH, DiffOp.REMOVE, DiffOp.REPLACE):
            skip = 1
        else:
            raise NBDiffFormatError("Invalid op {}.".format(op))

        # Consumed items must exist and must not overlap items already
        # consumed by an earlier entry; otherwise the output would be
        # silently wrong (or obj[index] would raise a bare IndexError).
        if skip and (index < take or index + skip > len(obj)):
            raise NBDiffFormatError(
                "List diff entry {} at key {} overlaps an earlier entry or "
                "runs past the end of a list of length {}.".format(
                    op, index, len(obj)))

        # Take values from obj not mentioned in diff, up to not including index
        newobj.extend(copy.deepcopy(value) for value in obj[take:index])

        if op == DiffOp.ADDRANGE:
            # Copy inserted values so the result does not alias the diff
            newobj.extend(copy.deepcopy(value) for value in e.valuelist)
        elif op in (DiffOp.ADD, DiffOp.REPLACE):
            newobj.append(copy.deepcopy(e.value))
        elif op == DiffOp.PATCH:
            newobj.append(patch(obj[index], e.diff))
        # REMOVE and REMOVERANGE insert nothing; they only skip items.

        # Skip the consumed elements, but never decrement take.
        # Note that take can pass index in diffs with repeated +/- on the
        # same index, i.e. [op_remove(index), op_add(index, value)]
        take = max(take, index + skip)

    # Take values at end not mentioned in diff
    newobj.extend(copy.deepcopy(value) for value in obj[take:])

    return newobj
66
67
def patch_string(obj, diff):
    """Apply a line-based ``diff`` to the multiline string ``obj``.

    The line-based diff is first converted into an equivalent
    character-based diff, which is then applied to the string's
    characters with the generic list patcher.
    """
    # Reusing patch_list keeps one patching code path; a str-specific
    # implementation could replace this once patch_list is better tested.
    char_diff = flatten_list_of_string_diff(obj, diff)
    characters = list(obj)
    patched_characters = patch_list(characters, char_diff)
    return "".join(patched_characters)
76
77
def patch_singleline_string(obj, diff):
    """Apply a character-based ``diff`` to the single-line string ``obj``."""
    # The string is treated as a list of characters so the generic list
    # patcher can be reused; a str-specific version could replace this.
    patched_chars = patch_list(list(obj), diff)
    return "".join(patched_chars)
83
84
def patch_dict(obj, diff):
    """Apply a dict diff to ``obj`` and return the result as a NotebookNode.

    Each entry of ``diff`` has an ``op`` and a string ``key``; at most one
    entry may target any given key. Entries produced by the diff come
    first in the result, followed by the keys of ``obj`` the diff did not
    mention. Untouched values and values supplied by ADD or REPLACE
    entries are deep-copied, so the result shares no mutable state with
    ``obj`` or ``diff``.

    Raises NBDiffFormatError for an unknown op. Malformed entries
    (non-string key, several entries for one key, ADD of a key already in
    ``obj``) fail an assertion when assertions are enabled. PATCH of a key
    missing from ``obj`` raises KeyError.
    """
    newobj = {}
    deleted_keys = set()

    for e in diff:
        op = e.op
        key = e.key
        assert isinstance(key, str), 'dict key must be string'
        # REMOVE entries never land in newobj, so deleted_keys must be
        # checked too to catch a key that is removed twice (or removed and
        # then replaced/patched).
        assert key not in newobj and key not in deleted_keys, (
            'multiple diff entries target same key: %r' % key)

        if op == DiffOp.ADD:
            assert key not in obj, 'patch add key already exists: %r' % key
            newobj[key] = copy.deepcopy(e.value)
        elif op == DiffOp.REMOVE:
            deleted_keys.add(key)
        elif op == DiffOp.REPLACE:
            newobj[key] = copy.deepcopy(e.value)
        elif op == DiffOp.PATCH:
            newobj[key] = patch(obj[key], e.diff)
        else:
            raise NBDiffFormatError("Invalid op {}.".format(op))

    # Take items not mentioned in diff
    for key in obj:
        if key not in deleted_keys and key not in newobj:
            newobj[key] = copy.deepcopy(obj[key])

    return NotebookNode(newobj)
115
116
def patch(obj, diff):
    """Return a patched copy of ``obj`` according to the hierarchical ``diff``.

    ``obj`` may be a dict (with string keys), a list, or a string, and
    dicts and lists may nest arbitrarily. Elements that are neither dicts
    nor lists are leaf values as far as patching is concerned; the library
    is intended for json-serializable data.

    Dicts are dispatched to patch_dict, lists to patch_list and strings to
    patch_string (which treats the diff as line based). Any other type
    raises ValueError.
    """
    if isinstance(obj, dict):
        return patch_dict(obj, diff)
    if isinstance(obj, list):
        return patch_list(obj, diff)
    if isinstance(obj, str):
        return patch_string(obj, diff)
    type_name = type(obj).__name__
    raise ValueError("Invalid object type to patch: {}".format(type_name))
137
138
def patch_notebook(nb, diff):
    """Apply ``diff`` to notebook ``nb`` and return the result via nbformat.from_dict."""
    patched = patch(nb, diff)
    return nbformat.from_dict(patched)
141