1// Copyright 2016 Keybase Inc. All rights reserved.
2// Use of this source code is governed by a BSD
3// license that can be found in the LICENSE file.
4
5package libkbfs
6
7import (
8	"reflect"
9	"sync"
10	"time"
11
12	"github.com/keybase/client/go/kbfs/kbfscrypto"
13	"github.com/keybase/client/go/kbfs/kbfsmd"
14	"github.com/keybase/client/go/kbfs/tlf"
15	"github.com/keybase/client/go/logger"
16	"github.com/keybase/client/go/protocol/keybase1"
17	"github.com/pkg/errors"
18	"golang.org/x/net/context"
19)
20
// An mdHandleKey is an encoded tlf.Handle. Using the encoded bytes as
// a map key lets handles be looked up by value.
type mdHandleKey string

// mdBlockKey identifies one MD history: a TLF plus a branch. The
// merged (master) history always uses kbfsmd.NullBranchID, and
// unmerged histories always use a non-null branch ID (enforced by
// getMDKey).
type mdBlockKey struct {
	tlfID    tlf.ID
	branchID kbfsmd.BranchID
}

// mdBranchKey identifies the unmerged branch of a single device (by
// its crypt public key) within a TLF.
type mdBranchKey struct {
	tlfID     tlf.ID
	deviceKey kbfscrypto.CryptPublicKey
}

// mdExtraWriterKey identifies a stored writer key bundle within a TLF.
type mdExtraWriterKey struct {
	tlfID          tlf.ID
	writerBundleID kbfsmd.TLFWriterKeyBundleID
}

// mdExtraReaderKey identifies a stored reader key bundle within a TLF.
type mdExtraReaderKey struct {
	tlfID          tlf.ID
	readerBundleID kbfsmd.TLFReaderKeyBundleID
}

// mdBlockMem is a single stored MD revision.
type mdBlockMem struct {
	// An encoded RootMetadataSigned.
	encodedMd []byte
	// Server time at which the revision was put; truncated to
	// keybase1.Time resolution by Put.
	timestamp time.Time
	// Metadata version passed to DecodeRootMetadataSigned when decoding
	// encodedMd.
	version kbfsmd.MetadataVer
}

// mdBlockMemList is a contiguous run of MD revisions for one history:
// blocks[i] holds revision initialRevision + i.
type mdBlockMemList struct {
	initialRevision kbfsmd.Revision
	blocks          []mdBlockMem
}
55
// mdLockTimeout is how long a lock taken by lockLocked stays valid.
// Once it has expired, another holder may take the lock.
const mdLockTimeout = time.Minute

// mdLockMemKey identifies a lock by TLF and lock ID.
type mdLockMemKey struct {
	tlfID  tlf.ID
	lockID keybase1.LockID
}

// mdLockMemVal records the state of a currently held lock.
type mdLockMemVal struct {
	// Expiration time. The lock is only valid while this is strictly
	// after the clock's current time.
	etime time.Time
	// The server instance that holds the lock.
	holder mdServerLocal
	// Closed when the lock is released or taken over after expiring,
	// which wakes anyone waiting to acquire it.
	released chan struct{}
}
68
// mdServerMemShared is the state of one logical in-memory MD server.
// Every MDServerMemory derived from the same NewMDServerMemory call
// via copy() points at the same mdServerMemShared, so they all see
// the same data.
type mdServerMemShared struct {
	// Protects all *db variables and truncateLockManager, as well as
	// lockIDs, implicitTeamsEnabled, iTeamMigrationLocks and
	// merkleRoots. Shutdown() sets handleDb (among other fields) to
	// nil, and checkShutdownRLocked treats a nil handleDb as the
	// shut-down state.
	lock sync.RWMutex // nolint
	// Bare TLF handle -> TLF ID
	handleDb map[mdHandleKey]tlf.ID
	// TLF ID -> latest bare TLF handle
	latestHandleDb map[tlf.ID]tlf.Handle
	// (TLF ID, branch ID) -> list of MDs
	mdDb map[mdBlockKey]mdBlockMemList
	// (TLF ID, writer key bundle ID) -> writer key bundle
	writerKeyBundleDb map[mdExtraWriterKey]kbfsmd.TLFWriterKeyBundleV3
	// (TLF ID, reader key bundle ID) -> reader key bundle
	readerKeyBundleDb map[mdExtraReaderKey]kbfsmd.TLFReaderKeyBundleV3
	// (TLF ID, crypt public key) -> branch ID
	branchDb            map[mdBranchKey]kbfsmd.BranchID
	truncateLockManager *mdServerLocalTruncateLockManager
	// lockIDs: (TLF ID, lock ID) -> expire time and holder.
	// iTeamMigrationLocks: TLFs for which StartImplicitTeamMigration
	// was called.
	// merkleRoots: latest root per tree, as set by setKbfsMerkleRoot.
	lockIDs              map[mdLockMemKey]mdLockMemVal
	implicitTeamsEnabled bool // nolint
	iTeamMigrationLocks  map[tlf.ID]bool
	merkleRoots          map[keybase1.MerkleTreeID]*kbfsmd.MerkleRoot

	// Tracks update registrations. It is used without holding lock
	// (RegisterForUpdate, CancelRegistration), so presumably it does its
	// own synchronization -- TODO confirm.
	updateManager *mdServerLocalUpdateManager
}
95
// MDServerMemory just stores metadata objects in memory.
//
// Each instance has its own config and logger. The data lives in the
// embedded *mdServerMemShared, which instances created via copy()
// share.
type MDServerMemory struct {
	config mdServerLocalConfig
	log    logger.Logger

	*mdServerMemShared
}

// Compile-time check that *MDServerMemory implements mdServerLocal.
var _ mdServerLocal = (*MDServerMemory)(nil)
105
106// NewMDServerMemory constructs a new MDServerMemory object that stores
107// all data in-memory.
108func NewMDServerMemory(config mdServerLocalConfig) (*MDServerMemory, error) {
109	handleDb := make(map[mdHandleKey]tlf.ID)
110	latestHandleDb := make(map[tlf.ID]tlf.Handle)
111	mdDb := make(map[mdBlockKey]mdBlockMemList)
112	branchDb := make(map[mdBranchKey]kbfsmd.BranchID)
113	writerKeyBundleDb := make(map[mdExtraWriterKey]kbfsmd.TLFWriterKeyBundleV3)
114	readerKeyBundleDb := make(map[mdExtraReaderKey]kbfsmd.TLFReaderKeyBundleV3)
115	log := config.MakeLogger("MDSM")
116	truncateLockManager := newMDServerLocalTruncatedLockManager()
117	shared := mdServerMemShared{
118		handleDb:            handleDb,
119		latestHandleDb:      latestHandleDb,
120		mdDb:                mdDb,
121		branchDb:            branchDb,
122		writerKeyBundleDb:   writerKeyBundleDb,
123		readerKeyBundleDb:   readerKeyBundleDb,
124		truncateLockManager: &truncateLockManager,
125		lockIDs:             make(map[mdLockMemKey]mdLockMemVal),
126		iTeamMigrationLocks: make(map[tlf.ID]bool),
127		updateManager:       newMDServerLocalUpdateManager(),
128		merkleRoots:         make(map[keybase1.MerkleTreeID]*kbfsmd.MerkleRoot),
129	}
130	mdserv := &MDServerMemory{config, log, &shared}
131	return mdserv, nil
132}
133
// errMDServerMemoryShutdown is the error reported by operations that
// are attempted after the server has been shut down.
type errMDServerMemoryShutdown struct{}

// Error implements the error interface.
func (errMDServerMemoryShutdown) Error() string {
	const msg = "MDServerMemory is shutdown"
	return msg
}
139
140func (md *MDServerMemory) checkShutdownRLocked() error {
141	if md.handleDb == nil {
142		return errors.WithStack(errMDServerMemoryShutdown{})
143	}
144	return nil
145}
146
147func (md *MDServerMemory) enableImplicitTeams() {
148	md.lock.Lock()
149	defer md.lock.Unlock()
150	md.implicitTeamsEnabled = true
151}
152
153func (md *MDServerMemory) setKbfsMerkleRoot(
154	treeID keybase1.MerkleTreeID, root *kbfsmd.MerkleRoot) {
155	md.lock.Lock()
156	defer md.lock.Unlock()
157	md.merkleRoots[treeID] = root
158}
159
160func (md *MDServerMemory) getHandleID(ctx context.Context, handle tlf.Handle,
161	mStatus kbfsmd.MergeStatus) (tlfID tlf.ID, created bool, err error) {
162	handleBytes, err := md.config.Codec().Encode(handle)
163	if err != nil {
164		return tlf.NullID, false, kbfsmd.ServerError{Err: err}
165	}
166
167	md.lock.RLock()
168	defer md.lock.RUnlock()
169	err = md.checkShutdownRLocked()
170	if err != nil {
171		return tlf.NullID, false, err
172	}
173
174	id, ok := md.handleDb[mdHandleKey(handleBytes)]
175	if ok {
176		return id, false, nil
177	}
178
179	// Non-readers shouldn't be able to create the dir.
180	session, err := md.config.currentSessionGetter().GetCurrentSession(ctx)
181	if err != nil {
182		return tlf.NullID, false, kbfsmd.ServerError{Err: err}
183	}
184	if handle.Type() == tlf.SingleTeam {
185		isReader, err := md.config.teamMembershipChecker().IsTeamReader(
186			ctx, handle.Writers[0].AsTeamOrBust(), session.UID,
187			keybase1.OfflineAvailability_NONE)
188		if err != nil {
189			return tlf.NullID, false, kbfsmd.ServerError{Err: err}
190		}
191		if !isReader {
192			return tlf.NullID, false, errors.WithStack(
193				kbfsmd.ServerErrorUnauthorized{})
194		}
195	} else if !handle.IsReader(session.UID.AsUserOrTeam()) {
196		return tlf.NullID, false, errors.WithStack(
197			kbfsmd.ServerErrorUnauthorized{})
198	}
199
200	if md.implicitTeamsEnabled {
201		return tlf.NullID, false, kbfsmd.ServerErrorClassicTLFDoesNotExist{}
202	}
203
204	// Allocate a new random ID.
205	id, err = md.config.cryptoPure().MakeRandomTlfID(handle.Type())
206	if err != nil {
207		return tlf.NullID, false, kbfsmd.ServerError{Err: err}
208	}
209
210	md.handleDb[mdHandleKey(handleBytes)] = id
211	md.latestHandleDb[id] = handle
212	return id, true, nil
213}
214
215// GetForHandle implements the MDServer interface for MDServerMemory.
216func (md *MDServerMemory) GetForHandle(ctx context.Context, handle tlf.Handle,
217	mStatus kbfsmd.MergeStatus, _ *keybase1.LockID) (
218	tlf.ID, *RootMetadataSigned, error) {
219	if err := checkContext(ctx); err != nil {
220		return tlf.NullID, nil, err
221	}
222
223	id, created, err := md.getHandleID(ctx, handle, mStatus)
224	if err != nil {
225		return tlf.NullID, nil, err
226	}
227
228	if created {
229		return id, nil, nil
230	}
231
232	rmds, err := md.GetForTLF(ctx, id, kbfsmd.NullBranchID, mStatus, nil)
233	if err != nil {
234		return tlf.NullID, nil, err
235	}
236	return id, rmds, nil
237}
238
239func (md *MDServerMemory) checkGetParamsRLocked(
240	ctx context.Context, id tlf.ID, bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus) (
241	newBid kbfsmd.BranchID, err error) {
242	if mStatus == kbfsmd.Merged && bid != kbfsmd.NullBranchID {
243		return kbfsmd.NullBranchID, kbfsmd.ServerErrorBadRequest{Reason: "Invalid branch ID"}
244	}
245
246	// Check permissions
247
248	mergedMasterHead, err :=
249		md.getHeadForTLFRLocked(ctx, id, kbfsmd.NullBranchID, kbfsmd.Merged)
250	if err != nil {
251		return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
252	}
253
254	session, err := md.config.currentSessionGetter().GetCurrentSession(ctx)
255	if err != nil {
256		return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
257	}
258
259	// TODO: Figure out nil case.
260	if mergedMasterHead != nil {
261		extra, err := getExtraMetadata(
262			md.getKeyBundlesRLocked, mergedMasterHead.MD)
263		if err != nil {
264			return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
265		}
266		ok, err := isReader(ctx, md.config.teamMembershipChecker(), session.UID,
267			mergedMasterHead.MD, extra)
268		if err != nil {
269			return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
270		}
271		if !ok {
272			return kbfsmd.NullBranchID, errors.WithStack(
273				kbfsmd.ServerErrorUnauthorized{})
274		}
275	}
276
277	// Lookup the branch ID if not supplied
278	if mStatus == kbfsmd.Unmerged && bid == kbfsmd.NullBranchID {
279		return md.getBranchIDRLocked(ctx, id)
280	}
281
282	return bid, nil
283}
284
285// GetForTLF implements the MDServer interface for MDServerMemory.
286func (md *MDServerMemory) GetForTLF(ctx context.Context, id tlf.ID,
287	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus, _ *keybase1.LockID) (
288	*RootMetadataSigned, error) {
289	if err := checkContext(ctx); err != nil {
290		return nil, err
291	}
292
293	md.lock.RLock()
294	defer md.lock.RUnlock()
295
296	bid, err := md.checkGetParamsRLocked(ctx, id, bid, mStatus)
297	if err != nil {
298		return nil, err
299	}
300	if mStatus == kbfsmd.Unmerged && bid == kbfsmd.NullBranchID {
301		return nil, nil
302	}
303
304	rmds, err := md.getHeadForTLFRLocked(ctx, id, bid, mStatus)
305	if err != nil {
306		return nil, kbfsmd.ServerError{Err: err}
307	}
308	return rmds, nil
309}
310
311// GetForTLFByTime implements the MDServer interface for MDServerMemory.
312func (md *MDServerMemory) GetForTLFByTime(
313	ctx context.Context, id tlf.ID, serverTime time.Time) (
314	*RootMetadataSigned, error) {
315	if err := checkContext(ctx); err != nil {
316		return nil, err
317	}
318
319	md.lock.RLock()
320	defer md.lock.RUnlock()
321
322	key, err := md.getMDKey(id, kbfsmd.NullBranchID, kbfsmd.Merged)
323	if err != nil {
324		return nil, err
325	}
326	err = md.checkShutdownRLocked()
327	if err != nil {
328		return nil, err
329	}
330
331	blockList, ok := md.mdDb[key]
332	if !ok {
333		return nil, nil
334	}
335	blocks := blockList.blocks
336
337	// Iterate backward until we find a timestamp less than `serverTime`.
338	for i := len(blocks) - 1; i >= 0; i-- {
339		t := blocks[i].timestamp
340		if t.After(serverTime) {
341			continue
342		}
343
344		max := md.config.MetadataVersion()
345		ver := blocks[i].version
346		buf := blocks[i].encodedMd
347		rmds, err := DecodeRootMetadataSigned(
348			md.config.Codec(), id, ver, max, buf, t)
349		if err != nil {
350			return nil, err
351		}
352		return rmds, nil
353	}
354
355	return nil, errors.Errorf(
356		"No MD found for TLF %s and serverTime %s", id, serverTime)
357}
358
359func (md *MDServerMemory) getHeadForTLFRLocked(ctx context.Context, id tlf.ID,
360	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus) (*RootMetadataSigned, error) {
361	key, err := md.getMDKey(id, bid, mStatus)
362	if err != nil {
363		return nil, err
364	}
365	err = md.checkShutdownRLocked()
366	if err != nil {
367		return nil, err
368	}
369
370	blockList, ok := md.mdDb[key]
371	if !ok {
372		return nil, nil
373	}
374	blocks := blockList.blocks
375	max := md.config.MetadataVersion()
376	ver := blocks[len(blocks)-1].version
377	buf := blocks[len(blocks)-1].encodedMd
378	timestamp := blocks[len(blocks)-1].timestamp
379	rmds, err := DecodeRootMetadataSigned(
380		md.config.Codec(), id, ver, max, buf, timestamp)
381	if err != nil {
382		return nil, err
383	}
384	return rmds, nil
385}
386
387func (md *MDServerMemory) getMDKey(
388	id tlf.ID, bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus) (mdBlockKey, error) {
389	if (mStatus == kbfsmd.Merged) != (bid == kbfsmd.NullBranchID) {
390		return mdBlockKey{},
391			errors.Errorf("mstatus=%v is inconsistent with bid=%v",
392				mStatus, bid)
393	}
394	return mdBlockKey{id, bid}, nil
395}
396
397func (md *MDServerMemory) getBranchKey(ctx context.Context, id tlf.ID) (
398	mdBranchKey, error) {
399	// add device key
400	deviceKey, err := md.getCurrentDeviceKey(ctx)
401	if err != nil {
402		return mdBranchKey{}, err
403	}
404	return mdBranchKey{id, deviceKey}, nil
405}
406
407func (md *MDServerMemory) getCurrentDeviceKey(ctx context.Context) (
408	kbfscrypto.CryptPublicKey, error) {
409	session, err := md.config.currentSessionGetter().GetCurrentSession(ctx)
410	if err != nil {
411		return kbfscrypto.CryptPublicKey{}, err
412	}
413	return session.CryptPublicKey, nil
414}
415
// getRangeLocked is the locked helper behind GetRange (via doGetRange).
// It returns the stored MDs of the selected history whose revisions
// fall in [start, stop], clamped to the revisions actually stored.
//
// The caller must hold md.lock. When lockBeforeGet is non-nil it must
// be the write lock, because taking that lock mutates lockIDs. If
// another instance holds lockBeforeGet, nothing is fetched and a
// non-nil lockWaitCh is returned. The caller should then drop md.lock,
// wait on the channel, and retry. If lockBeforeGet is held by md after
// the lock step and a later step fails, the lock is released again.
// For an unmerged request where the device has no branch, nil is
// returned and any lock just taken stays held.
func (md *MDServerMemory) getRangeLocked(ctx context.Context, id tlf.ID,
	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus, start, stop kbfsmd.Revision,
	lockBeforeGet *keybase1.LockID) (
	rmdses []*RootMetadataSigned, lockWaitCh <-chan struct{}, err error) {
	md.log.CDebugf(ctx, "GetRange %d %d (%s)", start, stop, mStatus)
	bid, err = md.checkGetParamsRLocked(ctx, id, bid, mStatus)
	if err != nil {
		return nil, nil, err
	}

	if lockBeforeGet != nil {
		lockWaitCh = md.lockLocked(ctx, id, *lockBeforeGet)
		if lockWaitCh != nil {
			// Someone else holds the lock; let the caller wait.
			return nil, lockWaitCh, nil
		}
		// Undo the lock if anything below fails (err is the named
		// result, so this sees the final error).
		defer func() {
			if err != nil {
				md.releaseLockLocked(ctx, id, *lockBeforeGet)
			}
		}()
	}

	if mStatus == kbfsmd.Unmerged && bid == kbfsmd.NullBranchID {
		return nil, nil, nil
	}

	key, err := md.getMDKey(id, bid, mStatus)
	if err != nil {
		return nil, nil, kbfsmd.ServerError{Err: err}
	}

	err = md.checkShutdownRLocked()
	if err != nil {
		return nil, nil, err
	}

	blockList, ok := md.mdDb[key]
	if !ok {
		return nil, nil, nil
	}

	// Convert revisions to slice indices (blocks[i] is revision
	// initialRevision + i) and clamp them to the stored range. An empty
	// or inverted range simply yields no iterations.
	startI := int(start - blockList.initialRevision)
	if startI < 0 {
		startI = 0
	}
	endI := int(stop - blockList.initialRevision + 1)
	blocks := blockList.blocks
	if endI > len(blocks) {
		endI = len(blocks)
	}

	// NOTE(review): `max` shadows the Go 1.21+ builtin within this
	// function.
	max := md.config.MetadataVersion()

	for i := startI; i < endI; i++ {
		ver := blocks[i].version
		buf := blocks[i].encodedMd
		rmds, err := DecodeRootMetadataSigned(
			md.config.Codec(), id, ver, max, buf,
			blocks[i].timestamp)
		if err != nil {
			return nil, nil, kbfsmd.ServerError{Err: err}
		}
		// Stored revisions must be contiguous; a mismatch means the
		// in-memory history is corrupt.
		expectedRevision := blockList.initialRevision + kbfsmd.Revision(i)
		if expectedRevision != rmds.MD.RevisionNumber() {
			panic(errors.Errorf("expected revision %v, got %v",
				expectedRevision, rmds.MD.RevisionNumber()))
		}
		rmdses = append(rmdses, rmds)
	}

	return rmdses, nil, nil
}
489
490func (md *MDServerMemory) doGetRange(ctx context.Context, id tlf.ID,
491	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus, start, stop kbfsmd.Revision,
492	lockBeforeGet *keybase1.LockID) (
493	[]*RootMetadataSigned, <-chan struct{}, error) {
494	md.lock.Lock()
495	defer md.lock.Unlock()
496	return md.getRangeLocked(ctx, id, bid, mStatus, start, stop, lockBeforeGet)
497}
498
499// GetRange implements the MDServer interface for MDServerMemory.
500func (md *MDServerMemory) GetRange(ctx context.Context, id tlf.ID,
501	bid kbfsmd.BranchID, mStatus kbfsmd.MergeStatus, start, stop kbfsmd.Revision,
502	lockBeforeGet *keybase1.LockID) ([]*RootMetadataSigned, error) {
503	if err := checkContext(ctx); err != nil {
504		return nil, err
505	}
506
507	// An RPC-based client would receive a throttle message from the
508	// server and retry with backoff, but here we need to implement
509	// the retry logic explicitly.
510	for {
511		rmds, ch, err := md.doGetRange(
512			ctx, id, bid, mStatus, start, stop, lockBeforeGet)
513		if err != nil {
514			return nil, err
515		}
516		if ch == nil {
517			return rmds, err
518		}
519		select {
520		// TODO: wait for the clock to pass the expired time.  We'd
521		// need a new method in the `Clock` interface to support this.
522		case <-ch:
523			continue
524		case <-ctx.Done():
525			return nil, ctx.Err()
526		}
527	}
528}
529
// Put implements the MDServer interface for MDServerMemory.
//
// The steps are:
//  1. Check that rmds is valid and signed, and last modified by the
//     current session.
//  2. Under md.lock, enforce the lock context lc, if one is given.
//  3. When a merged head exists, require the user to be a writer or a
//     valid rekeyer.
//  4. Check that rmds is a valid successor of the current head of its
//     history.
//  5. Append rmds to that history and store any new key bundles from
//     extra.
//  6. Release lc's lock if lc.ReleaseAfterSuccess is set, and notify
//     the update manager for merged, non-rekey puts.
//
// The MD priority argument is ignored.
func (md *MDServerMemory) Put(ctx context.Context, rmds *RootMetadataSigned,
	extra kbfsmd.ExtraMetadata, lc *keybase1.LockContext, _ keybase1.MDPriority) error {
	if err := checkContext(ctx); err != nil {
		return err
	}

	session, err := md.config.currentSessionGetter().GetCurrentSession(ctx)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	// These checks only inspect rmds itself, so they run before
	// md.lock is taken.
	err = rmds.IsValidAndSigned(
		ctx, md.config.Codec(), md.config.teamMembershipChecker(), extra,
		keybase1.OfflineAvailability_NONE)
	if err != nil {
		return kbfsmd.ServerErrorBadRequest{Reason: err.Error()}
	}

	err = rmds.IsLastModifiedBy(session.UID, session.VerifyingKey)
	if err != nil {
		return kbfsmd.ServerErrorBadRequest{Reason: err.Error()}
	}

	id := rmds.MD.TlfID()

	// Check permissions
	md.lock.Lock()
	defer md.lock.Unlock()

	// A put under a lock context must come from the lock's holder, and
	// the lock must not have expired.
	if lc != nil && !md.isLockedLocked(ctx, id, lc.RequireLockID) {
		return kbfsmd.ServerErrorLockConflict{}
	}

	mergedMasterHead, err :=
		md.getHeadForTLFRLocked(ctx, id, kbfsmd.NullBranchID, kbfsmd.Merged)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	// TODO: Figure out nil case.
	// (Currently, with no merged head yet, no writer check is done.)
	if mergedMasterHead != nil {
		prevExtra, err := getExtraMetadata(
			md.getKeyBundlesRLocked, mergedMasterHead.MD)
		if err != nil {
			return kbfsmd.ServerError{Err: err}
		}
		ok, err := isWriterOrValidRekey(
			ctx, md.config.teamMembershipChecker(), md.config.Codec(),
			session.UID, session.VerifyingKey, mergedMasterHead.MD,
			rmds.MD, prevExtra, extra)
		if err != nil {
			return kbfsmd.ServerError{Err: err}
		}
		if !ok {
			return errors.WithStack(kbfsmd.ServerErrorUnauthorized{})
		}
	}

	bid := rmds.MD.BID()
	mStatus := rmds.MD.MergedStatus()

	head, err := md.getHeadForTLFRLocked(ctx, id, bid, mStatus)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	var recordBranchID bool

	if mStatus == kbfsmd.Unmerged && head == nil {
		// currHead for unmerged history might be on the main branch
		// This is the first MD of a new branch. Its predecessor is the
		// merged revision just before it, and the branch must be
		// recorded for this device below.
		prevRev := rmds.MD.RevisionNumber() - 1
		rmdses, ch, err := md.getRangeLocked(
			ctx, id, kbfsmd.NullBranchID, kbfsmd.Merged, prevRev, prevRev, nil)
		if err != nil {
			return kbfsmd.ServerError{Err: err}
		}
		if ch != nil {
			panic("Got non-nil lock channel with a nil lock context")
		}
		if len(rmdses) != 1 {
			return kbfsmd.ServerError{
				Err: errors.Errorf("Expected 1 MD block got %d", len(rmdses)),
			}
		}
		head = rmdses[0]
		recordBranchID = true
	}

	// Consistency checks
	// NOTE(review): inside this block `id` shadows the TLF ID with the
	// MD ID of head.
	if head != nil {
		id, err := kbfsmd.MakeID(md.config.Codec(), head.MD)
		if err != nil {
			return err
		}
		err = head.MD.CheckValidSuccessorForServer(id, rmds.MD)
		if err != nil {
			return err
		}
	}

	// Record branch ID
	if recordBranchID {
		branchKey, err := md.getBranchKey(ctx, id)
		if err != nil {
			return kbfsmd.ServerError{Err: err}
		}
		err = md.checkShutdownRLocked()
		if err != nil {
			return err
		}
		md.branchDb[branchKey] = bid
	}

	encodedMd, err := kbfsmd.EncodeRootMetadataSigned(md.config.Codec(), &rmds.RootMetadataSigned)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	// Pretend the timestamp went over RPC, so we get the same
	// resolution level as a real server.
	t := keybase1.FromTime(keybase1.ToTime(md.config.Clock().Now()))
	block := mdBlockMem{encodedMd, t, rmds.MD.Version()}

	// Add an entry with the revision key.
	revKey, err := md.getMDKey(id, bid, mStatus)
	if err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	err = md.checkShutdownRLocked()
	if err != nil {
		return err
	}

	// Append to the existing history, or start a new one whose first
	// revision is this MD's.
	blockList, ok := md.mdDb[revKey]
	if ok {
		blockList.blocks = append(blockList.blocks, block)
		md.mdDb[revKey] = blockList
	} else {
		md.mdDb[revKey] = mdBlockMemList{
			initialRevision: rmds.MD.RevisionNumber(),
			blocks:          []mdBlockMem{block},
		}
	}

	if err := md.putExtraMetadataLocked(rmds, extra); err != nil {
		return kbfsmd.ServerError{Err: err}
	}

	if lc != nil && lc.ReleaseAfterSuccess {
		md.releaseLockLocked(ctx, id, lc.RequireLockID)
	}

	if mStatus == kbfsmd.Merged &&
		// Don't send notifies if it's just a rekey (the real mdserver
		// sends a "folder needs rekey" notification in this case).
		!(rmds.MD.IsRekeySet() && rmds.MD.IsWriterMetadataCopiedSet()) {
		md.updateManager.setHead(id, md)
	}

	return nil
}
693
694func (md *MDServerMemory) isLockedLocked(ctx context.Context,
695	tlfID tlf.ID, lockID keybase1.LockID) bool {
696	val, ok := md.lockIDs[mdLockMemKey{
697		tlfID:  tlfID,
698		lockID: lockID,
699	}]
700	if !ok {
701		return false
702	}
703	return val.etime.After(md.config.Clock().Now()) && md == val.holder
704}
705
706func (md *MDServerMemory) lockLocked(ctx context.Context,
707	tlfID tlf.ID, lockID keybase1.LockID) <-chan struct{} {
708	lockKey := mdLockMemKey{
709		tlfID:  tlfID,
710		lockID: lockID,
711	}
712	val, ok := md.lockIDs[lockKey]
713	if !ok || !val.etime.After(md.config.Clock().Now()) {
714		// The lock doesn't exist or has expired.
715		md.lockIDs[lockKey] = mdLockMemVal{
716			etime:    md.config.Clock().Now().Add(mdLockTimeout),
717			holder:   md,
718			released: make(chan struct{}),
719		}
720		if ok {
721			close(val.released)
722		}
723		return nil
724	} else if val.holder == md {
725		// The lock is already held by this instance; just return
726		// without refreshing timestamp.
727		return nil
728	}
729	// Someone else holds the lock; the caller needs to release
730	// md.lock and wait for this channel to close.
731	return val.released
732}
733
734func (md *MDServerMemory) releaseLockLocked(ctx context.Context,
735	tlfID tlf.ID, lockID keybase1.LockID) {
736	lockKey := mdLockMemKey{
737		tlfID:  tlfID,
738		lockID: lockID,
739	}
740	val, ok := md.lockIDs[lockKey]
741	if !ok || val.holder != md {
742		return
743	}
744	delete(md.lockIDs, lockKey)
745	close(val.released)
746}
747
748func (md *MDServerMemory) doLock(ctx context.Context,
749	tlfID tlf.ID, lockID keybase1.LockID) <-chan struct{} {
750	md.lock.Lock()
751	defer md.lock.Unlock()
752	return md.lockLocked(ctx, tlfID, lockID)
753}
754
755// Lock implements the MDServer interface for MDServerMemory.
756func (md *MDServerMemory) Lock(ctx context.Context,
757	tlfID tlf.ID, lockID keybase1.LockID) error {
758	// An RPC-based client would receive a throttle message from the
759	// server and retry with backoff, but here we need to implement
760	// the retry logic explicitly.
761	for {
762		ch := md.doLock(ctx, tlfID, lockID)
763		if ch == nil {
764			return nil
765		}
766		select {
767		// TODO: wait for the clock to pass the expired time.  We'd
768		// need a new method in the `Clock` interface to support this.
769		case <-ch:
770			continue
771		case <-ctx.Done():
772			return ctx.Err()
773		}
774	}
775}
776
777// ReleaseLock implements the MDServer interface for MDServerMemory.
778func (md *MDServerMemory) ReleaseLock(ctx context.Context,
779	tlfID tlf.ID, lockID keybase1.LockID) error {
780	md.lock.Lock()
781	defer md.lock.Unlock()
782	md.releaseLockLocked(ctx, tlfID, lockID)
783	return nil
784}
785
786// StartImplicitTeamMigration implements the MDServer interface.
787func (md *MDServerMemory) StartImplicitTeamMigration(
788	ctx context.Context, id tlf.ID) (err error) {
789	md.lock.Lock()
790	defer md.lock.Unlock()
791	md.iTeamMigrationLocks[id] = true
792	return nil
793}
794
795// PruneBranch implements the MDServer interface for MDServerMemory.
796func (md *MDServerMemory) PruneBranch(ctx context.Context, id tlf.ID, bid kbfsmd.BranchID) error {
797	if err := checkContext(ctx); err != nil {
798		return err
799	}
800
801	if bid == kbfsmd.NullBranchID {
802		return kbfsmd.ServerErrorBadRequest{Reason: "Invalid branch ID"}
803	}
804
805	md.lock.Lock()
806	defer md.lock.Unlock()
807
808	currBID, err := md.getBranchIDRLocked(ctx, id)
809	if err != nil {
810		return err
811	}
812	if currBID == kbfsmd.NullBranchID || bid != currBID {
813		return kbfsmd.ServerErrorBadRequest{Reason: "Invalid branch ID"}
814	}
815
816	// Don't actually delete unmerged history. This is intentional to be consistent
817	// with the mdserver behavior-- it garbage collects discarded branches in the
818	// background.
819	branchKey, err := md.getBranchKey(ctx, id)
820	if err != nil {
821		return kbfsmd.ServerError{Err: err}
822	}
823	err = md.checkShutdownRLocked()
824	if err != nil {
825		return err
826	}
827
828	delete(md.branchDb, branchKey)
829	return nil
830}
831
832func (md *MDServerMemory) getBranchIDRLocked(ctx context.Context, id tlf.ID) (kbfsmd.BranchID, error) {
833	branchKey, err := md.getBranchKey(ctx, id)
834	if err != nil {
835		return kbfsmd.NullBranchID, kbfsmd.ServerError{Err: err}
836	}
837	err = md.checkShutdownRLocked()
838	if err != nil {
839		return kbfsmd.NullBranchID, err
840	}
841
842	bid, ok := md.branchDb[branchKey]
843	if !ok {
844		return kbfsmd.NullBranchID, nil
845	}
846	return bid, nil
847}
848
849// RegisterForUpdate implements the MDServer interface for MDServerMemory.
850func (md *MDServerMemory) RegisterForUpdate(ctx context.Context, id tlf.ID,
851	currHead kbfsmd.Revision) (<-chan error, error) {
852	if err := checkContext(ctx); err != nil {
853		return nil, err
854	}
855
856	// are we already past this revision?  If so, fire observer
857	// immediately
858	currMergedHeadRev, err := md.getCurrentMergedHeadRevision(ctx, id)
859	if err != nil {
860		return nil, err
861	}
862
863	c := md.updateManager.registerForUpdate(id, currHead, currMergedHeadRev, md)
864	return c, nil
865}
866
867// CancelRegistration implements the MDServer interface for MDServerMemory.
868func (md *MDServerMemory) CancelRegistration(_ context.Context, id tlf.ID) {
869	md.updateManager.cancel(id, md)
870}
871
872// TruncateLock implements the MDServer interface for MDServerMemory.
873func (md *MDServerMemory) TruncateLock(ctx context.Context, id tlf.ID) (
874	bool, error) {
875	if err := checkContext(ctx); err != nil {
876		return false, err
877	}
878
879	md.lock.Lock()
880	defer md.lock.Unlock()
881	err := md.checkShutdownRLocked()
882	if err != nil {
883		return false, err
884	}
885
886	myKey, err := md.getCurrentDeviceKey(ctx)
887	if err != nil {
888		return false, err
889	}
890
891	return md.truncateLockManager.truncateLock(myKey, id)
892}
893
894// TruncateUnlock implements the MDServer interface for MDServerMemory.
895func (md *MDServerMemory) TruncateUnlock(ctx context.Context, id tlf.ID) (
896	bool, error) {
897	if err := checkContext(ctx); err != nil {
898		return false, err
899	}
900
901	md.lock.Lock()
902	defer md.lock.Unlock()
903	err := md.checkShutdownRLocked()
904	if err != nil {
905		return false, err
906	}
907
908	myKey, err := md.getCurrentDeviceKey(ctx)
909	if err != nil {
910		return false, err
911	}
912
913	return md.truncateLockManager.truncateUnlock(myKey, id)
914}
915
916// Shutdown implements the MDServer interface for MDServerMemory.
917func (md *MDServerMemory) Shutdown() {
918	md.lock.Lock()
919	defer md.lock.Unlock()
920	md.handleDb = nil
921	md.latestHandleDb = nil
922	md.branchDb = nil
923	md.truncateLockManager = nil
924}
925
926// IsConnected implements the MDServer interface for MDServerMemory.
927func (md *MDServerMemory) IsConnected() bool {
928	return !md.isShutdown()
929}
930
// RefreshAuthToken implements the MDServer interface for MDServerMemory.
// The in-memory server does no authentication, so this is a no-op.
func (md *MDServerMemory) RefreshAuthToken(ctx context.Context) {}
933
934// This should only be used for testing with an in-memory server.
935func (md *MDServerMemory) copy(config mdServerLocalConfig) mdServerLocal {
936	// NOTE: observers and sessionHeads are copied shallowly on
937	// purpose, so that the MD server that gets a Put will notify all
938	// observers correctly no matter where they got on the list.
939	log := config.MakeLogger("")
940	return &MDServerMemory{config, log, md.mdServerMemShared}
941}
942
943// isShutdown returns whether the logical, shared MDServer instance
944// has been shut down.
945func (md *MDServerMemory) isShutdown() bool {
946	md.lock.RLock()
947	defer md.lock.RUnlock()
948	return md.checkShutdownRLocked() != nil
949}
950
// DisableRekeyUpdatesForTesting implements the MDServer interface.
// The in-memory server never sends rekey updates, so there is nothing
// to disable.
func (md *MDServerMemory) DisableRekeyUpdatesForTesting() {
	// Nothing to do.
}
955
956// CheckForRekeys implements the MDServer interface.
957func (md *MDServerMemory) CheckForRekeys(ctx context.Context) <-chan error {
958	// Nothing to do
959	c := make(chan error, 1)
960	c <- nil
961	return c
962}
963
964func (md *MDServerMemory) addNewAssertionForTest(uid keybase1.UID,
965	newAssertion keybase1.SocialAssertion) error {
966	md.lock.Lock()
967	defer md.lock.Unlock()
968	err := md.checkShutdownRLocked()
969	if err != nil {
970		return err
971	}
972
973	// Iterate through all the handles, and add handles for ones
974	// containing newAssertion to now include the uid.
975	for hBytes, id := range md.handleDb {
976		var h tlf.Handle
977		err := md.config.Codec().Decode([]byte(hBytes), &h)
978		if err != nil {
979			return err
980		}
981		assertions := map[keybase1.SocialAssertion]keybase1.UID{
982			newAssertion: uid,
983		}
984		newH := h.ResolveAssertions(assertions)
985		if reflect.DeepEqual(h, newH) {
986			continue
987		}
988		newHBytes, err := md.config.Codec().Encode(newH)
989		if err != nil {
990			return err
991		}
992		md.handleDb[mdHandleKey(newHBytes)] = id
993	}
994	return nil
995}
996
997func (md *MDServerMemory) getCurrentMergedHeadRevision(
998	ctx context.Context, id tlf.ID) (rev kbfsmd.Revision, err error) {
999	head, err := md.GetForTLF(ctx, id, kbfsmd.NullBranchID, kbfsmd.Merged, nil)
1000	if err != nil {
1001		return 0, err
1002	}
1003	if head != nil {
1004		rev = head.MD.RevisionNumber()
1005	}
1006	return
1007}
1008
1009// GetLatestHandleForTLF implements the MDServer interface for MDServerMemory.
1010func (md *MDServerMemory) GetLatestHandleForTLF(ctx context.Context,
1011	id tlf.ID) (tlf.Handle, error) {
1012	if err := checkContext(ctx); err != nil {
1013		return tlf.Handle{}, err
1014	}
1015
1016	md.lock.RLock()
1017	defer md.lock.RUnlock()
1018	err := md.checkShutdownRLocked()
1019	if err != nil {
1020		return tlf.Handle{}, err
1021	}
1022
1023	return md.latestHandleDb[id], nil
1024}
1025
// OffsetFromServerTime implements the MDServer interface for
// MDServerMemory. The in-memory server uses the local clock, so the
// offset is always zero and is always reported as known (true).
func (md *MDServerMemory) OffsetFromServerTime() (time.Duration, bool) {
	return 0, true
}
1031
1032func (md *MDServerMemory) putExtraMetadataLocked(rmds *RootMetadataSigned,
1033	extra kbfsmd.ExtraMetadata) error {
1034	if extra == nil {
1035		return nil
1036	}
1037
1038	extraV3, ok := extra.(*kbfsmd.ExtraMetadataV3)
1039	if !ok {
1040		return errors.New("Invalid extra metadata")
1041	}
1042
1043	tlfID := rmds.MD.TlfID()
1044
1045	if extraV3.IsWriterKeyBundleNew() {
1046		wkbID := rmds.MD.GetTLFWriterKeyBundleID()
1047		if wkbID == (kbfsmd.TLFWriterKeyBundleID{}) {
1048			panic("writer key bundle ID is empty")
1049		}
1050		md.writerKeyBundleDb[mdExtraWriterKey{tlfID, wkbID}] =
1051			extraV3.GetWriterKeyBundle()
1052	}
1053
1054	if extraV3.IsReaderKeyBundleNew() {
1055		rkbID := rmds.MD.GetTLFReaderKeyBundleID()
1056		if rkbID == (kbfsmd.TLFReaderKeyBundleID{}) {
1057			panic("reader key bundle ID is empty")
1058		}
1059		md.readerKeyBundleDb[mdExtraReaderKey{tlfID, rkbID}] =
1060			extraV3.GetReaderKeyBundle()
1061	}
1062	return nil
1063}
1064
1065func (md *MDServerMemory) getKeyBundlesRLocked(tlfID tlf.ID,
1066	wkbID kbfsmd.TLFWriterKeyBundleID, rkbID kbfsmd.TLFReaderKeyBundleID) (
1067	*kbfsmd.TLFWriterKeyBundleV3, *kbfsmd.TLFReaderKeyBundleV3, error) {
1068	err := md.checkShutdownRLocked()
1069	if err != nil {
1070		return nil, nil, err
1071	}
1072
1073	var wkb *kbfsmd.TLFWriterKeyBundleV3
1074	if wkbID != (kbfsmd.TLFWriterKeyBundleID{}) {
1075		foundWKB, ok := md.writerKeyBundleDb[mdExtraWriterKey{tlfID, wkbID}]
1076		if !ok {
1077			return nil, nil, errors.Errorf(
1078				"Could not find WKB for ID %s", wkbID)
1079		}
1080
1081		err := kbfsmd.CheckWKBID(md.config.Codec(), wkbID, foundWKB)
1082		if err != nil {
1083			return nil, nil, err
1084		}
1085
1086		wkb = &foundWKB
1087	}
1088
1089	var rkb *kbfsmd.TLFReaderKeyBundleV3
1090	if rkbID != (kbfsmd.TLFReaderKeyBundleID{}) {
1091		foundRKB, ok := md.readerKeyBundleDb[mdExtraReaderKey{tlfID, rkbID}]
1092		if !ok {
1093			return nil, nil, errors.Errorf(
1094				"Could not find RKB for ID %s", rkbID)
1095		}
1096
1097		err := kbfsmd.CheckRKBID(md.config.Codec(), rkbID, foundRKB)
1098		if err != nil {
1099			return nil, nil, err
1100		}
1101
1102		rkb = &foundRKB
1103	}
1104
1105	return wkb, rkb, nil
1106}
1107
1108// GetKeyBundles implements the MDServer interface for MDServerMemory.
1109func (md *MDServerMemory) GetKeyBundles(ctx context.Context,
1110	tlfID tlf.ID, wkbID kbfsmd.TLFWriterKeyBundleID, rkbID kbfsmd.TLFReaderKeyBundleID) (
1111	*kbfsmd.TLFWriterKeyBundleV3, *kbfsmd.TLFReaderKeyBundleV3, error) {
1112	if err := checkContext(ctx); err != nil {
1113		return nil, nil, err
1114	}
1115
1116	md.lock.RLock()
1117	defer md.lock.RUnlock()
1118
1119	wkb, rkb, err := md.getKeyBundlesRLocked(tlfID, wkbID, rkbID)
1120	if err != nil {
1121		return nil, nil, kbfsmd.ServerError{Err: err}
1122	}
1123	return wkb, rkb, nil
1124}
1125
// CheckReachability implements the MDServer interface for MDServerMemory.
// An in-memory server is always reachable, so this is a no-op.
func (md *MDServerMemory) CheckReachability(ctx context.Context) {}

// FastForwardBackoff implements the MDServer interface for MDServerMemory.
// There is no connection backoff to fast-forward, so this is a no-op.
func (md *MDServerMemory) FastForwardBackoff() {}

// FindNextMD implements the MDServer interface for MDServerMemory.
// It is not supported in memory: it always returns zero values (nil
// root, no merkle nodes, seqno 0) and a nil error.
func (md *MDServerMemory) FindNextMD(
	ctx context.Context, tlfID tlf.ID, rootSeqno keybase1.Seqno) (
	nextKbfsRoot *kbfsmd.MerkleRoot, nextMerkleNodes [][]byte,
	nextRootSeqno keybase1.Seqno, err error) {
	return nil, nil, 0, nil
}
1139
1140// GetMerkleRootLatest implements the MDServer interface for MDServerMemory.
1141func (md *MDServerMemory) GetMerkleRootLatest(
1142	ctx context.Context, treeID keybase1.MerkleTreeID) (
1143	root *kbfsmd.MerkleRoot, err error) {
1144	md.lock.RLock()
1145	defer md.lock.RUnlock()
1146	return md.merkleRoots[treeID], nil
1147}
1148