# coding: utf-8

# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import copy

import nbformat
from nbformat import NotebookNode

from .diff_format import DiffOp, NBDiffFormatError
from .diff_utils import flatten_list_of_string_diff


__all__ = ["patch", "patch_notebook"]


def patch_list(obj, diff):
    """Return a new list produced by applying ``diff`` to the list ``obj``.

    ``obj`` itself is not modified: items carried over from ``obj`` are
    deep-copied, while values supplied by the diff entries (``e.value``,
    ``e.valuelist``) are inserted as-is, without copying.

    Each diff entry is read through its ``op`` and ``key`` attributes, plus
    ``valuelist``, ``length``, ``diff`` or ``value`` depending on the op.
    Entries are processed in the order given; presumably they are sorted by
    ``key`` -- confirm against the diff producer.

    Raises:
        AssertionError: if an entry key is not an ``int``.
        NBDiffFormatError: if an entry has an unrecognized op.
    """
    # The patched sequence to build and return
    newobj = []
    # Index into obj, the next item to take unless diff says otherwise
    take = 0
    for e in diff:
        op = e.op
        index = e.key
        # NOTE: assert statements are removed when Python runs with -O,
        # so this validation is not performed in optimized mode.
        assert isinstance(index, int), 'list key must be integer'

        # Take values from obj not mentioned in diff, up to not including index
        newobj.extend(copy.deepcopy(value) for value in obj[take:index])

        if op == DiffOp.ADDRANGE:
            # Extend with new values directly
            newobj.extend(e.valuelist)
            skip = 0
        elif op == DiffOp.REMOVERANGE:
            # Delete a number of values by skipping
            skip = e.length
        elif op == DiffOp.PATCH:
            # Recursively patch the existing item and skip the old one
            newobj.append(patch(obj[index], e.diff))
            skip = 1
        # Note that the operations ADD, REMOVE, REPLACE are not produced by the
        # diff algorithm anymore, keeping these cases just in case we want them back:
        elif op == DiffOp.ADD:
            # Append new value directly
            newobj.append(e.value)
            skip = 0
        elif op == DiffOp.REMOVE:
            # Delete values obj[index] by incrementing take to skip
            skip = 1
        elif op == DiffOp.REPLACE:
            # Add replacement value and skip old
            newobj.append(e.value)
            skip = 1
        else:
            raise NBDiffFormatError("Invalid op {}.".format(op))

        # Skip the specified number of elements, but never decrement take.
        # Note that take can pass index in diffs with repeated +/- on the
        # same index, i.e. [op_remove(index), op_add(index, value)]
        take = max(take, index + skip)

    # Take values at end not mentioned in diff
    newobj.extend(copy.deepcopy(value) for value in obj[take:])

    return newobj


def patch_string(obj, diff):
    """Patch a multiline string, assuming diff is line based.

    The diff is first converted with ``flatten_list_of_string_diff`` and
    then applied to the individual characters of ``obj`` via ``patch_list``.
    Returns the patched string.
    """
    # This can possibly be optimized for str if wanted, but
    # waiting until patch_list has been tested and debugged better

    # Flatten line-based diff to character based first!
    diff = flatten_list_of_string_diff(obj, diff)
    return "".join(patch_list(list(obj), diff))


def patch_singleline_string(obj, diff):
    """Patch a singleline string, assuming diff is character based.

    The diff is applied to the individual characters of ``obj`` via
    ``patch_list``. Returns the patched string.
    """
    # This can possibly be optimized for str if wanted, but
    # waiting until patch_list has been tested and debugged better
    return "".join(patch_list(list(obj), diff))


def patch_dict(obj, diff):
    """Return a new ``NotebookNode`` produced by applying ``diff`` to ``obj``.

    ``obj`` itself is not modified: items not mentioned in the diff are
    deep-copied into the result, while values supplied by the diff entries
    (``e.value``) are inserted as-is, without copying.

    Raises:
        AssertionError: if an entry key is not a ``str``, if a key already
            written to the result is targeted again, if ADD targets a key
            that already exists in ``obj``, or if REPLACE/PATCH targets a
            key removed by an earlier entry.
        NBDiffFormatError: if an entry has an unrecognized op.
    """
    newobj = {}
    # Keys removed by the diff; these are not carried over from obj
    deleted_keys = set()

    for e in diff:
        op = e.op
        key = e.key
        # NOTE: assert statements are removed when Python runs with -O,
        # so none of the validation below is performed in optimized mode.
        assert isinstance(key, str), 'dict key must be string'
        assert key not in newobj, 'multiple diff entries target same key: %r' % key

        if op == DiffOp.ADD:
            # Adding is only valid for keys that obj does not already have
            assert key not in obj, 'cannot add key that already exists: %r' % key
            newobj[key] = e.value
        elif op == DiffOp.REMOVE:
            deleted_keys.add(key)
        elif op == DiffOp.REPLACE:
            assert key not in deleted_keys, 'cannot replace deleted key: %r' % key
            newobj[key] = e.value
        elif op == DiffOp.PATCH:
            assert key not in deleted_keys, 'cannot patch deleted key: %r' % key
            # Recursively patch the existing value
            newobj[key] = patch(obj[key], e.diff)
        else:
            raise NBDiffFormatError("Invalid op {}.".format(op))

    # Take items not mentioned in diff
    for key in obj:
        if key not in deleted_keys and key not in newobj:
            newobj[key] = copy.deepcopy(obj[key])

    return NotebookNode(newobj)


def patch(obj, diff):
    """Produce a patched version of obj with given hierarchical diff.

    A valid input object can be any dict or list of leaf values,
    or arbitrarily nested dict or list of valid input objects.

    Dicts are required to have string keys.

    Leaf values are any non-dict, non-list objects as far as patch
    is concerned, although the intentional use of this library
    is that values are json-serializable.

    Dispatches on the type of ``obj``: dicts go to ``patch_dict`` (which
    returns a ``NotebookNode``), lists to ``patch_list`` and strings to
    ``patch_string`` (which assumes a line based diff).

    Raises:
        ValueError: if ``obj`` is not a dict, list or str.
    """
    if isinstance(obj, dict):
        return patch_dict(obj, diff)
    elif isinstance(obj, list):
        return patch_list(obj, diff)
    elif isinstance(obj, str):
        return patch_string(obj, diff)
    else:
        raise ValueError("Invalid object type to patch: {}".format(type(obj).__name__))


def patch_notebook(nb, diff):
    """Patch ``nb`` with ``diff`` and pass the result to ``nbformat.from_dict``."""
    return nbformat.from_dict(patch(nb, diff))