1// Copyright 2016 Keybase Inc. All rights reserved.
2// Use of this source code is governed by a BSD
3// license that can be found in the LICENSE file.
4
5package libkbfs
6
7import (
8	"fmt"
9	"sync"
10
11	"github.com/keybase/client/go/kbfs/data"
12)
13
// nodeCacheEntry is the value stored in nodeCacheStandard.nodes. It
// pairs a node core with refCount, a counter that is incremented
// each time a nodeStandard is created for the entry and decremented
// when a core is forgotten; the entry is removed from the cache once
// the counter drops to zero or below.
type nodeCacheEntry struct {
	core     *nodeCore
	refCount int
}
18
// nodeCacheStandard implements the NodeCache interface by tracking
// the reference counts of nodeStandard Nodes, and using their member
// fields to construct paths.
//
// Cached entries are keyed by block reference. rootWrappers holds the
// wrapper functions applied to nodes that have no parent, and
// makeObfuscator, when non-nil, is used to create an obfuscator for
// each newly-created directory node.
type nodeCacheStandard struct {
	folderBranch data.FolderBranch

	// lock is taken by the methods below before they touch nodes,
	// rootWrappers, or makeObfuscator.
	lock           sync.RWMutex
	nodes          map[data.BlockRef]*nodeCacheEntry
	rootWrappers   []func(Node) Node
	makeObfuscator func() data.Obfuscator
}
30
// Compile-time assertion that *nodeCacheStandard satisfies NodeCache.
var _ NodeCache = (*nodeCacheStandard)(nil)
32
33func newNodeCacheStandard(fb data.FolderBranch) *nodeCacheStandard {
34	return &nodeCacheStandard{
35		folderBranch: fb,
36		nodes:        make(map[data.BlockRef]*nodeCacheEntry),
37	}
38}
39
40// lock must be locked for writing by the caller
41func (ncs *nodeCacheStandard) forgetLocked(core *nodeCore) {
42	ref := core.pathNode.Ref()
43
44	entry, ok := ncs.nodes[ref]
45	if !ok {
46		return
47	}
48	if entry.core != core {
49		return
50	}
51
52	entry.refCount--
53	if entry.refCount <= 0 {
54		delete(ncs.nodes, ref)
55	}
56}
57
// forget takes the cache's write lock and releases one reference to
// the entry holding `core` (see forgetLocked).  It should be called
// only by nodeStandardFinalizer().
func (ncs *nodeCacheStandard) forget(core *nodeCore) {
	ncs.lock.Lock()
	defer ncs.lock.Unlock()
	ncs.forgetLocked(core)
}
64
65// lock must be held for writing by the caller
66func (ncs *nodeCacheStandard) newChildForParentLocked(parent Node) (*nodeStandard, error) {
67	nodeStandard, ok := parent.Unwrap().(*nodeStandard)
68	if !ok {
69		return nil, ParentNodeNotFoundError{data.BlockRef{}}
70	}
71
72	ref := nodeStandard.core.pathNode.Ref()
73	entry, ok := ncs.nodes[ref]
74	if !ok {
75		return nil, ParentNodeNotFoundError{ref}
76	}
77	if nodeStandard.core != entry.core {
78		return nil, ParentNodeNotFoundError{ref}
79	}
80	return nodeStandard, nil
81}
82func (ncs *nodeCacheStandard) wrapNodeStandard(
83	n Node, rootWrappers []func(Node) Node, parent Node) Node {
84	if parent != nil {
85		return parent.WrapChild(n)
86	}
87	for _, f := range rootWrappers {
88		n = f(n)
89	}
90	return n
91}
92
// makeNodeStandardForEntryLocked takes a new reference on `entry`
// and returns a nodeStandard built from the entry's core.
//
// This mutates entry.refCount, so the caller should hold ncs.lock
// for writing.
func (ncs *nodeCacheStandard) makeNodeStandardForEntryLocked(
	entry *nodeCacheEntry) *nodeStandard {
	entry.refCount++
	return makeNodeStandard(entry.core)
}
98
99// GetOrCreate implements the NodeCache interface for nodeCacheStandard.
100func (ncs *nodeCacheStandard) GetOrCreate(
101	ptr data.BlockPointer, name data.PathPartString, parent Node,
102	et data.EntryType) (n Node, err error) {
103	var rootWrappers []func(Node) Node
104	defer func() {
105		if n != nil {
106			n = ncs.wrapNodeStandard(n, rootWrappers, parent)
107		}
108	}()
109
110	if !ptr.IsValid() {
111		// Temporary code to track down bad block
112		// pointers. Remove when not needed anymore.
113		panic(InvalidBlockRefError{ptr.Ref()})
114	}
115
116	if name.Plaintext() == "" {
117		return nil, EmptyNameError{ptr.Ref()}
118	}
119
120	ncs.lock.Lock()
121	defer ncs.lock.Unlock()
122	rootWrappers = ncs.rootWrappers
123	entry, ok := ncs.nodes[ptr.Ref()]
124	if ok {
125		// If the entry happens to be unlinked, we may be in a
126		// situation where a node got unlinked and then recreated, but
127		// someone held onto a node the whole time and so it never got
128		// removed from the cache.  In that case, forcibly remove it
129		// from the cache to make room for the new node.
130		if parent != nil && entry.core.parent == nil {
131			delete(ncs.nodes, ptr.Ref())
132		} else {
133			return ncs.makeNodeStandardForEntryLocked(entry), nil
134		}
135	}
136
137	if parent != nil {
138		// Make sure a child can be made for this parent.
139		_, err := ncs.newChildForParentLocked(parent)
140		if err != nil {
141			return nil, err
142		}
143	}
144
145	entry = &nodeCacheEntry{}
146	if et == data.Dir && ncs.makeObfuscator != nil {
147		entry.core = newNodeCoreForDir(
148			ptr, name, parent, ncs, ncs.makeObfuscator())
149	} else {
150		entry.core = newNodeCore(ptr, name, parent, ncs, et)
151	}
152	ncs.nodes[ptr.Ref()] = entry
153	return ncs.makeNodeStandardForEntryLocked(entry), nil
154}
155
156// Get implements the NodeCache interface for nodeCacheStandard.
157func (ncs *nodeCacheStandard) Get(ref data.BlockRef) (n Node) {
158	if ref == (data.BlockRef{}) {
159		return nil
160	}
161
162	// Temporary code to track down bad block pointers. Remove (or
163	// return an error) when not needed anymore.
164	if !ref.IsValid() {
165		panic(InvalidBlockRefError{ref})
166	}
167
168	var rootWrappers []func(Node) Node
169	var parent Node
170	defer func() {
171		if n != nil {
172			n = ncs.wrapNodeStandard(n, rootWrappers, parent)
173		}
174	}()
175
176	ncs.lock.Lock()
177	defer ncs.lock.Unlock()
178	rootWrappers = ncs.rootWrappers
179	entry, ok := ncs.nodes[ref]
180	if !ok {
181		return nil
182	}
183	ns := ncs.makeNodeStandardForEntryLocked(entry)
184	parent = ns.core.parent // get while under lock
185	return ns
186}
187
188// UpdatePointer implements the NodeCache interface for nodeCacheStandard.
189func (ncs *nodeCacheStandard) UpdatePointer(
190	oldRef data.BlockRef, newPtr data.BlockPointer) (updatedNode NodeID) {
191	if oldRef == (data.BlockRef{}) && newPtr == (data.BlockPointer{}) {
192		return nil
193	}
194
195	if !oldRef.IsValid() {
196		panic(fmt.Sprintf("invalid oldRef %s with newPtr %s", oldRef, newPtr))
197	}
198
199	if !newPtr.IsValid() {
200		panic(fmt.Sprintf("invalid newPtr %s with oldRef %s", newPtr, oldRef))
201	}
202
203	ncs.lock.Lock()
204	defer ncs.lock.Unlock()
205	entry, ok := ncs.nodes[oldRef]
206	if !ok {
207		return nil
208	}
209
210	// Cannot update the pointer for an unlinked node.
211	if entry.core.cachedPath.IsValid() {
212		return nil
213	}
214
215	entry.core.pathNode.BlockPointer = newPtr
216	delete(ncs.nodes, oldRef)
217	ncs.nodes[newPtr.Ref()] = entry
218	return entry.core
219}
220
221// Move implements the NodeCache interface for nodeCacheStandard.
222func (ncs *nodeCacheStandard) Move(
223	ref data.BlockRef, newParent Node, newName data.PathPartString) (
224	undoFn func(), err error) {
225	if ref == (data.BlockRef{}) {
226		return nil, nil
227	}
228
229	// Temporary code to track down bad block pointers. Remove (or
230	// return an error) when not needed anymore.
231	if !ref.IsValid() {
232		panic(InvalidBlockRefError{ref})
233	}
234
235	if newName.Plaintext() == "" {
236		return nil, EmptyNameError{ref}
237	}
238
239	ncs.lock.Lock()
240	defer ncs.lock.Unlock()
241	entry, ok := ncs.nodes[ref]
242	if !ok {
243		return nil, nil
244	}
245
246	newParentNS, err := ncs.newChildForParentLocked(newParent)
247	if err != nil {
248		return nil, err
249	}
250
251	oldParent := entry.core.parent
252	oldName := entry.core.pathNode.Name
253
254	entry.core.parent = newParentNS
255	entry.core.pathNode.Name = newName
256
257	return func() {
258		entry.core.parent = oldParent
259		entry.core.pathNode.Name = oldName
260	}, nil
261}
262
263// Unlink implements the NodeCache interface for nodeCacheStandard.
264func (ncs *nodeCacheStandard) Unlink(
265	ref data.BlockRef, oldPath data.Path, oldDe data.DirEntry) (undoFn func()) {
266	if ref == (data.BlockRef{}) {
267		return nil
268	}
269
270	// Temporary code to track down bad block pointers. Remove (or
271	// return an error) when not needed anymore.
272	if !ref.IsValid() {
273		panic(InvalidBlockRefError{ref})
274	}
275
276	ncs.lock.Lock()
277	defer ncs.lock.Unlock()
278	entry, ok := ncs.nodes[ref]
279	if !ok {
280		return nil
281	}
282
283	if entry.core.cachedPath.IsValid() {
284		// Already unlinked!
285		return nil
286	}
287
288	oldParent := entry.core.parent
289	oldName := entry.core.pathNode.Name
290
291	entry.core.cachedPath = oldPath
292	entry.core.cachedDe = oldDe
293	entry.core.parent = nil
294	entry.core.pathNode.Name = data.PathPartString{}
295
296	return func() {
297		entry.core.cachedPath = data.Path{}
298		entry.core.cachedDe = data.DirEntry{}
299		entry.core.parent = oldParent
300		entry.core.pathNode.Name = oldName
301	}
302}
303
304// IsUnlinked implements the NodeCache interface for
305// nodeCacheStandard.
306func (ncs *nodeCacheStandard) IsUnlinked(node Node) bool {
307	ncs.lock.RLock()
308	defer ncs.lock.RUnlock()
309
310	ns, ok := node.Unwrap().(*nodeStandard)
311	if !ok {
312		return false
313	}
314
315	return ns.core.cachedPath.IsValid()
316}
317
318// UnlinkedDirEntry implements the NodeCache interface for
319// nodeCacheStandard.
320func (ncs *nodeCacheStandard) UnlinkedDirEntry(node Node) data.DirEntry {
321	ncs.lock.RLock()
322	defer ncs.lock.RUnlock()
323
324	ns, ok := node.Unwrap().(*nodeStandard)
325	if !ok {
326		return data.DirEntry{}
327	}
328
329	return ns.core.cachedDe
330}
331
332// UpdateUnlinkedDirEntry implements the NodeCache interface for
333// nodeCacheStandard.
334func (ncs *nodeCacheStandard) UpdateUnlinkedDirEntry(
335	node Node, newDe data.DirEntry) {
336	ncs.lock.Lock()
337	defer ncs.lock.Unlock()
338
339	ns, ok := node.Unwrap().(*nodeStandard)
340	if !ok {
341		return
342	}
343
344	ns.core.cachedDe = newDe
345}
346
347// PathFromNode implements the NodeCache interface for nodeCacheStandard.
348func (ncs *nodeCacheStandard) PathFromNode(node Node) (p data.Path) {
349	ncs.lock.RLock()
350	defer ncs.lock.RUnlock()
351
352	ns, ok := node.Unwrap().(*nodeStandard)
353	if !ok {
354		p.Path = nil
355		return
356	}
357
358	p.ChildObfuscator = ns.core.obfuscator
359
360	for ns != nil {
361		core := ns.core
362		if core.parent == nil && len(core.cachedPath.Path) > 0 {
363			// The node was unlinked, but is still in use, so use its
364			// cached path.  The path is already reversed, so append
365			// it backwards one-by-one to the existing path.  If this
366			// is the first node, we can just optimize by returning
367			// the complete cached path.
368			if len(p.Path) == 0 {
369				return core.cachedPath
370			}
371			for i := len(core.cachedPath.Path) - 1; i >= 0; i-- {
372				p.Path = append(p.Path, core.cachedPath.Path[i])
373			}
374			break
375		}
376
377		p.Path = append(p.Path, *core.pathNode)
378		if core.parent != nil {
379			ns = core.parent.Unwrap().(*nodeStandard)
380		} else {
381			break
382		}
383	}
384
385	// need to reverse the path nodes
386	for i := len(p.Path)/2 - 1; i >= 0; i-- {
387		opp := len(p.Path) - 1 - i
388		p.Path[i], p.Path[opp] = p.Path[opp], p.Path[i]
389	}
390
391	// TODO: would it make any sense to cache the constructed path?
392	p.FolderBranch = ncs.folderBranch
393	return
394}
395
396// AllNodes implements the NodeCache interface for nodeCacheStandard.
397func (ncs *nodeCacheStandard) AllNodes() (nodes []Node) {
398	ncs.lock.RLock()
399	defer ncs.lock.RUnlock()
400	nodes = make([]Node, 0, len(ncs.nodes))
401	for _, entry := range ncs.nodes {
402		nodes = append(nodes, ncs.makeNodeStandardForEntryLocked(entry))
403	}
404	return nodes
405}
406
407// AllNodeChildren implements the NodeCache interface for nodeCacheStandard.
408func (ncs *nodeCacheStandard) AllNodeChildren(n Node) (nodes []Node) {
409	ncs.lock.RLock()
410	defer ncs.lock.RUnlock()
411	nodes = make([]Node, 0, len(ncs.nodes))
412	entryIDs := make(map[NodeID]bool)
413	for _, entry := range ncs.nodes {
414		var pathIDs []NodeID
415		parent := entry.core.parent
416		for parent != nil {
417			// If the node's parent is what we're looking for (or on
418			// the path to what we're looking for), include it in the
419			// list.
420			parentID := parent.GetID()
421			if parentID == n.GetID() || entryIDs[parentID] {
422				nodes = append(nodes, ncs.makeNodeStandardForEntryLocked(entry))
423				for _, id := range pathIDs {
424					entryIDs[id] = true
425				}
426				entryIDs[entry.core] = true
427				break
428			}
429
430			// Otherwise, remember this parent and continue back
431			// toward the root.
432			pathIDs = append(pathIDs, parentID)
433			ns, ok := parent.Unwrap().(*nodeStandard)
434			if !ok {
435				break
436			}
437			parent = ns.core.parent
438		}
439	}
440	return nodes
441}
442
// AddRootWrapper appends `f` to the cache's list of root wrappers.
// Root wrappers are applied, in the order they were added, to nodes
// returned by Get and GetOrCreate that have no parent.
func (ncs *nodeCacheStandard) AddRootWrapper(f func(Node) Node) {
	ncs.lock.Lock()
	defer ncs.lock.Unlock()
	ncs.rootWrappers = append(ncs.rootWrappers, f)
}
448
// SetObfuscatorMaker sets the function that GetOrCreate calls to
// make an obfuscator for each newly-created directory node.  When
// the maker is nil, new directory nodes are created the same way as
// other nodes.
func (ncs *nodeCacheStandard) SetObfuscatorMaker(
	makeOb func() data.Obfuscator) {
	ncs.lock.Lock()
	defer ncs.lock.Unlock()
	ncs.makeObfuscator = makeOb
}
455
// ObfuscatorMaker returns the obfuscator-making function currently
// configured on the cache, which is nil if none has been set.
func (ncs *nodeCacheStandard) ObfuscatorMaker() func() data.Obfuscator {
	ncs.lock.RLock()
	defer ncs.lock.RUnlock()
	return ncs.makeObfuscator
}
461