1// Copyright 2018 Keybase Inc. All rights reserved.
2// Use of this source code is governed by a BSD
3// license that can be found in the LICENSE file.
4
5package libkbfs
6
7import (
8	"context"
9
10	"github.com/keybase/client/go/kbfs/data"
11	"github.com/pkg/errors"
12)
13
14type blockState struct {
15	block          data.Block
16	readyBlockData data.ReadyBlockData
17	syncedCb       func() error
18	oldPtr         data.BlockPointer
19}
20
21// blockPutStateMemory is an internal structure to track data in
22// memory when putting blocks.
23type blockPutStateMemory struct {
24	blockStates map[data.BlockPointer]blockState
25	lastBlock   data.BlockPointer
26}
27
// Compile-time assertion that *blockPutStateMemory implements
// blockPutStateCopiable.
var _ blockPutStateCopiable = (*blockPutStateMemory)(nil)
29
30func newBlockPutStateMemory(length int) *blockPutStateMemory {
31	bps := &blockPutStateMemory{}
32	bps.blockStates = make(map[data.BlockPointer]blockState, length)
33	return bps
34}
35
36// AddNewBlock tracks a new block that will be put.  If syncedCb is
37// non-nil, it will be called whenever the put for that block is
38// complete (whether or not the put resulted in an error).  Currently
39// it will not be called if the block is never put (due to an earlier
40// error).
41func (bps *blockPutStateMemory) AddNewBlock(
42	_ context.Context, blockPtr data.BlockPointer, block data.Block,
43	readyBlockData data.ReadyBlockData, syncedCb func() error) error {
44	bps.blockStates[blockPtr] = blockState{
45		block, readyBlockData, syncedCb, data.ZeroPtr}
46	bps.lastBlock = blockPtr
47	return nil
48}
49
50// SaveOldPtr stores the given BlockPointer as the old (pre-readied)
51// pointer for the most recent blockState.
52func (bps *blockPutStateMemory) SaveOldPtr(
53	_ context.Context, oldPtr data.BlockPointer) error {
54	if bps.lastBlock == data.ZeroPtr {
55		return errors.New("No blocks have been added")
56	}
57	bs, ok := bps.blockStates[bps.lastBlock]
58	if !ok {
59		return errors.Errorf("Last block %v doesn't exist", bps.lastBlock)
60	}
61	bs.oldPtr = oldPtr
62	bps.blockStates[bps.lastBlock] = bs
63	return nil
64}
65
66func (bps *blockPutStateMemory) oldPtr(
67	_ context.Context, blockPtr data.BlockPointer) (data.BlockPointer, error) {
68	bs, ok := bps.blockStates[blockPtr]
69	if ok {
70		return bs.oldPtr, nil
71	}
72	return data.BlockPointer{}, errors.WithStack(
73		data.NoSuchBlockError{ID: blockPtr.ID})
74}
75
76func (bps *blockPutStateMemory) mergeOtherBps(
77	_ context.Context, other blockPutStateCopiable) error {
78	otherMem, ok := other.(*blockPutStateMemory)
79	if !ok {
80		return errors.Errorf("Cannot remove other bps of type %T", other)
81	}
82
83	for ptr, bs := range otherMem.blockStates {
84		bps.blockStates[ptr] = bs
85	}
86	return nil
87}
88
89func (bps *blockPutStateMemory) removeOtherBps(
90	ctx context.Context, other blockPutStateCopiable) error {
91	otherMem, ok := other.(*blockPutStateMemory)
92	if !ok {
93		return errors.Errorf("Cannot remove other bps of type %T", other)
94	}
95	if len(otherMem.blockStates) == 0 {
96		return nil
97	}
98
99	otherMemPtrs := make(map[data.BlockPointer]bool, len(otherMem.blockStates))
100	for ptr := range otherMem.blockStates {
101		otherMemPtrs[ptr] = true
102	}
103
104	newBps, err := bps.deepCopyWithBlacklist(ctx, otherMemPtrs)
105	if err != nil {
106		return err
107	}
108	newBpsMem, ok := newBps.(*blockPutStateMemory)
109	if !ok {
110		return errors.Errorf(
111			"Bad deep copy type when removing blocks: %T", newBps)
112	}
113
114	bps.blockStates = newBpsMem.blockStates
115	return nil
116}
117
118func (bps *blockPutStateMemory) Ptrs() []data.BlockPointer {
119	ret := make([]data.BlockPointer, len(bps.blockStates))
120	i := 0
121	for ptr := range bps.blockStates {
122		ret[i] = ptr
123		i++
124	}
125	return ret
126}
127
128func (bps *blockPutStateMemory) GetBlock(
129	_ context.Context, blockPtr data.BlockPointer) (data.Block, error) {
130	bs, ok := bps.blockStates[blockPtr]
131	if ok {
132		return bs.block, nil
133	}
134	return nil, errors.WithStack(data.NoSuchBlockError{ID: blockPtr.ID})
135}
136
137func (bps *blockPutStateMemory) getReadyBlockData(
138	_ context.Context, blockPtr data.BlockPointer) (data.ReadyBlockData, error) {
139	bs, ok := bps.blockStates[blockPtr]
140	if ok {
141		return bs.readyBlockData, nil
142	}
143	return data.ReadyBlockData{}, errors.WithStack(
144		data.NoSuchBlockError{ID: blockPtr.ID})
145}
146
147func (bps *blockPutStateMemory) synced(blockPtr data.BlockPointer) error {
148	bs, ok := bps.blockStates[blockPtr]
149	if ok && bs.syncedCb != nil {
150		return bs.syncedCb()
151	}
152	return nil
153}
154
155func (bps *blockPutStateMemory) numBlocks() int {
156	return len(bps.blockStates)
157}
158
159func (bps *blockPutStateMemory) deepCopy(
160	_ context.Context) (blockPutStateCopiable, error) {
161	newBps := &blockPutStateMemory{}
162	newBps.blockStates = make(map[data.BlockPointer]blockState, len(bps.blockStates))
163	for ptr, bs := range bps.blockStates {
164		newBps.blockStates[ptr] = bs
165	}
166	return newBps, nil
167}
168
169func (bps *blockPutStateMemory) deepCopyWithBlacklist(
170	_ context.Context, blacklist map[data.BlockPointer]bool) (
171	blockPutStateCopiable, error) {
172	newBps := &blockPutStateMemory{}
173	newLen := len(bps.blockStates) - len(blacklist)
174	if newLen < 0 {
175		newLen = 0
176	}
177	newBps.blockStates = make(map[data.BlockPointer]blockState, newLen)
178	for ptr, bs := range bps.blockStates {
179		// Only save the good pointers
180		if !blacklist[ptr] {
181			newBps.blockStates[ptr] = bs
182		}
183	}
184	return newBps, nil
185}
186