1// Copyright 2016 Keybase Inc. All rights reserved.
2// Use of this source code is governed by a BSD
3// license that can be found in the LICENSE file.
4
5package libkbfs
6
7import (
8	"math/rand"
9	"sync"
10	"time"
11
12	"github.com/keybase/client/go/kbfs/data"
13	"github.com/keybase/client/go/kbfs/kbfsblock"
14	"github.com/keybase/client/go/kbfs/kbfscrypto"
15	"github.com/keybase/client/go/kbfs/kbfsmd"
16	"github.com/keybase/client/go/kbfs/libcontext"
17	"github.com/keybase/client/go/kbfs/tlf"
18	"github.com/keybase/client/go/kbfs/tlfhandle"
19	"github.com/keybase/client/go/protocol/keybase1"
20	"golang.org/x/net/context"
21)
22
// stallableOp is the shared underlying type for the names of all
// operations that can be stalled.
type stallableOp string

type (
	// StallableBlockOp defines an Op that is stallable using StallBlockOp.
	StallableBlockOp stallableOp

	// StallableMDOp defines an Op that is stallable using StallMDOp.
	StallableMDOp stallableOp
)

// Names of the stallable block server ops.
const (
	StallableBlockGet StallableBlockOp = "Get"
	StallableBlockPut StallableBlockOp = "Put"
)

// Names of the stallable MD ops. The "After" variants name stall points
// that are reached once the corresponding delegate call has returned.
const (
	StallableMDGetForTLF                    StallableMDOp = "GetForTLF"
	StallableMDGetForTLFByTime              StallableMDOp = "GetForTLFByTime"
	StallableMDGetLatestHandleForTLF        StallableMDOp = "GetLatestHandleForTLF"
	StallableMDValidateLatestHandleNotFinal StallableMDOp = "ValidateLatestHandleNotFinal"
	StallableMDGetUnmergedForTLF            StallableMDOp = "GetUnmergedForTLF"
	StallableMDGetRange                     StallableMDOp = "GetRange"
	StallableMDAfterGetRange                StallableMDOp = "AfterGetRange"
	StallableMDGetUnmergedRange             StallableMDOp = "GetUnmergedRange"
	StallableMDPut                          StallableMDOp = "Put"
	StallableMDAfterPut                     StallableMDOp = "AfterPut"
	StallableMDPutUnmerged                  StallableMDOp = "PutUnmerged"
	StallableMDAfterPutUnmerged             StallableMDOp = "AfterPutUnmerged"
	StallableMDPruneBranch                  StallableMDOp = "PruneBranch"
	StallableMDResolveBranch                StallableMDOp = "ResolveBranch"
)
51
// stallKeyType is the type of the context key that switches stalling on
// for operations issued with a particular context.
type stallKeyType uint64

// stallKeyStallEverything is the reserved key meaning "stall every
// matching op, whatever its context carries".
const stallKeyStallEverything = stallKeyType(0)
55
56type naïveStallInfo struct {
57	onStalled               <-chan struct{}
58	unstall                 chan<- struct{}
59	oldBlockServer          BlockServer
60	oldMDOps                MDOps
61	oldJournalDelegateMDOps MDOps
62}
63
64// NaïveStaller is used to stall certain ops in BlockServer or
65// MDOps. Unlike StallBlockOp and StallMDOp which provides a way to
66// precisely control which particular op is stalled by passing in ctx
67// with corresponding stallKey, NaïveStaller simply stalls all
68// instances of specified op.
69type NaïveStaller struct {
70	config Config
71
72	mu             sync.RWMutex
73	blockOpsStalls map[StallableBlockOp]*naïveStallInfo
74	mdOpsStalls    map[StallableMDOp]*naïveStallInfo
75
76	// We are only supporting stalling one Op per kind at a time for now. If in
77	// the future a dsl test needs to stall different Ops, please see
78	// https://github.com/keybase/client/go/kbfs/pull/163 for an implementation.
79	blockStalled bool
80	mdStalled    bool
81}
82
83// NewNaïveStaller returns a new NaïveStaller
84func NewNaïveStaller(config Config) *NaïveStaller {
85	return &NaïveStaller{
86		config:         config,
87		blockOpsStalls: make(map[StallableBlockOp]*naïveStallInfo),
88		mdOpsStalls:    make(map[StallableMDOp]*naïveStallInfo),
89	}
90}
91
92func (s *NaïveStaller) getNaïveStallInfoForBlockOpOrBust(
93	stalledOp StallableBlockOp) *naïveStallInfo {
94	s.mu.RLock()
95	defer s.mu.RUnlock()
96	info, ok := s.blockOpsStalls[stalledOp]
97	if !ok {
98		panic("naïveStallInfo is not found." +
99			"This indicates incorrect use of NaïveStaller")
100	}
101	return info
102}
103
104func (s *NaïveStaller) getNaïveStallInfoForMDOpOrBust(
105	stalledOp StallableMDOp) *naïveStallInfo {
106	s.mu.RLock()
107	defer s.mu.RUnlock()
108	info, ok := s.mdOpsStalls[stalledOp]
109	if !ok {
110		panic("naïveStallInfo is not found." +
111			"This indicates incorrect use of NaïveStaller")
112	}
113	return info
114}
115
116// StallBlockOp wraps the internal BlockServer so that all subsequent stalledOp
117// will be stalled. This can be undone by calling UndoStallBlockOp.
118func (s *NaïveStaller) StallBlockOp(stalledOp StallableBlockOp, maxStalls int) {
119	s.mu.Lock()
120	defer s.mu.Unlock()
121	if s.blockStalled {
122		panic("incorrect use of NaïveStaller;" +
123			" only one stalled Op at a time is supported")
124	}
125	onStalledCh := make(chan struct{}, maxStalls)
126	unstallCh := make(chan struct{})
127	oldBlockServer := s.config.BlockServer()
128	s.config.SetBlockServer(&stallingBlockServer{
129		BlockServer: oldBlockServer,
130		stallOpName: stalledOp,
131		stallKey:    stallKeyStallEverything,
132		staller: staller{
133			stalled: onStalledCh,
134			unstall: unstallCh,
135		},
136	})
137	s.blockStalled = true
138	s.blockOpsStalls[stalledOp] = &naïveStallInfo{
139		onStalled:      onStalledCh,
140		unstall:        unstallCh,
141		oldBlockServer: oldBlockServer,
142	}
143}
144
145// StallMDOp wraps the internal MDOps so that all subsequent stalledOp
146// will be stalled. This can be undone by calling UndoStallMDOp.
147func (s *NaïveStaller) StallMDOp(stalledOp StallableMDOp, maxStalls int,
148	stallDelegate bool) {
149	s.mu.Lock()
150	defer s.mu.Unlock()
151	if s.mdStalled {
152		panic("incorrect use of NaïveStaller;" +
153			" only one stalled Op at a time is supported")
154	}
155	onStalledCh := make(chan struct{}, maxStalls)
156	unstallCh := make(chan struct{})
157	oldMDOps := s.config.MDOps()
158	var oldJDelegate MDOps
159	if jManager, err := GetJournalManager(s.config); err == nil && stallDelegate {
160		oldJDelegate = jManager.delegateMDOps
161		// Stall the delegate server as well
162		jManager.delegateMDOps = &stallingMDOps{
163			stallOpName: stalledOp,
164			stallKey:    stallKeyStallEverything,
165			staller: staller{
166				stalled: onStalledCh,
167				unstall: unstallCh,
168			},
169			delegate: jManager.delegateMDOps,
170		}
171		s.config.SetMDOps(jManager.mdOps())
172	} else {
173		s.config.SetMDOps(&stallingMDOps{
174			stallOpName: stalledOp,
175			stallKey:    stallKeyStallEverything,
176			staller: staller{
177				stalled: onStalledCh,
178				unstall: unstallCh,
179			},
180			delegate: oldMDOps,
181		})
182	}
183	s.mdStalled = true
184	s.mdOpsStalls[stalledOp] = &naïveStallInfo{
185		onStalled:               onStalledCh,
186		unstall:                 unstallCh,
187		oldMDOps:                oldMDOps,
188		oldJournalDelegateMDOps: oldJDelegate,
189	}
190}
191
192// WaitForStallBlockOp blocks until stalledOp is stalled. StallBlockOp should
193// have been called upon stalledOp, otherwise this would panic.
194func (s *NaïveStaller) WaitForStallBlockOp(stalledOp StallableBlockOp) {
195	<-s.getNaïveStallInfoForBlockOpOrBust(stalledOp).onStalled
196}
197
198// WaitForStallMDOp blocks until stalledOp is stalled. StallMDOp should
199// have been called upon stalledOp, otherwise this would panic.
200func (s *NaïveStaller) WaitForStallMDOp(stalledOp StallableMDOp) {
201	<-s.getNaïveStallInfoForMDOpOrBust(stalledOp).onStalled
202}
203
204// UnstallOneBlockOp unstalls exactly one stalled stalledOp. StallBlockOp
205// should have been called upon stalledOp, otherwise this would panic.
206func (s *NaïveStaller) UnstallOneBlockOp(stalledOp StallableBlockOp) {
207	s.getNaïveStallInfoForBlockOpOrBust(stalledOp).unstall <- struct{}{}
208}
209
210// UnstallOneMDOp unstalls exactly one stalled stalledOp. StallMDOp
211// should have been called upon stalledOp, otherwise this would panic.
212func (s *NaïveStaller) UnstallOneMDOp(stalledOp StallableMDOp) {
213	s.getNaïveStallInfoForMDOpOrBust(stalledOp).unstall <- struct{}{}
214}
215
216// UndoStallBlockOp reverts StallBlockOp so that future stalledOp are not
217// stalled anymore. It also unstalls any stalled stalledOp. StallBlockOp
218// should have been called upon stalledOp, otherwise this would panic.
219func (s *NaïveStaller) UndoStallBlockOp(stalledOp StallableBlockOp) {
220	ns := s.getNaïveStallInfoForBlockOpOrBust(stalledOp)
221	s.config.SetBlockServer(ns.oldBlockServer)
222	close(ns.unstall)
223	s.mu.Lock()
224	defer s.mu.Unlock()
225	s.blockStalled = false
226	delete(s.blockOpsStalls, stalledOp)
227}
228
229// UndoStallMDOp reverts StallMDOp so that future stalledOp are not
230// stalled anymore. It also unstalls any stalled stalledOp. StallMDOp
231// should have been called upon stalledOp, otherwise this would panic.
232func (s *NaïveStaller) UndoStallMDOp(stalledOp StallableMDOp) {
233	ns := s.getNaïveStallInfoForMDOpOrBust(stalledOp)
234	if jManager, err := GetJournalManager(s.config); err == nil &&
235		ns.oldJournalDelegateMDOps != nil {
236		jManager.delegateMDOps = ns.oldJournalDelegateMDOps
237	}
238	s.config.SetMDOps(ns.oldMDOps)
239	close(ns.unstall)
240	s.mu.Lock()
241	defer s.mu.Unlock()
242	s.mdStalled = false
243	delete(s.mdOpsStalls, stalledOp)
244}
245
246// StallBlockOp sets a wrapped BlockOps in config so that the specified Op, stalledOp,
247// is stalled. Caller should use the returned newCtx for subsequent operations
248// for the stall to be effective. onStalled is a channel to notify the caller
249// when the stall has happened. unstall is a channel for caller to unstall an
250// Op.
251func StallBlockOp(ctx context.Context, config Config,
252	stalledOp StallableBlockOp, maxStalls int) (
253	onStalled <-chan struct{}, unstall chan<- struct{}, newCtx context.Context) {
254	onStalledCh := make(chan struct{}, maxStalls)
255	unstallCh := make(chan struct{})
256	stallKey := newStallKey()
257	config.SetBlockServer(&stallingBlockServer{
258		BlockServer: config.BlockServer(),
259		stallOpName: stalledOp,
260		stallKey:    stallKey,
261		staller: staller{
262			stalled: onStalledCh,
263			unstall: unstallCh,
264		},
265	})
266	newCtx = libcontext.NewContextReplayable(ctx, func(ctx context.Context) context.Context {
267		return context.WithValue(ctx, stallKey, true)
268	})
269	return onStalledCh, unstallCh, newCtx
270}
271
272// StallMDOp sets a wrapped MDOps in config so that the specified Op,
273// stalledOp, is stalled. Caller should use the returned newCtx for subsequent
274// operations for the stall to be effective. onStalled is a channel to notify
275// the caller when the stall has happened. unstall is a channel for caller to
276// unstall an Op.
277func StallMDOp(ctx context.Context, config Config, stalledOp StallableMDOp,
278	maxStalls int) (
279	onStalled <-chan struct{}, unstall chan<- struct{}, newCtx context.Context) {
280	onStalledCh := make(chan struct{}, maxStalls)
281	unstallCh := make(chan struct{})
282	stallKey := newStallKey()
283	config.SetMDOps(&stallingMDOps{
284		stallOpName: stalledOp,
285		stallKey:    stallKey,
286		staller: staller{
287			stalled: onStalledCh,
288			unstall: unstallCh,
289		},
290		delegate: config.MDOps(),
291	})
292	newCtx = libcontext.NewContextReplayable(ctx, func(ctx context.Context) context.Context {
293		return context.WithValue(ctx, stallKey, true)
294	})
295	return onStalledCh, unstallCh, newCtx
296}
297
298func newStallKey() stallKeyType {
299	stallKey := stallKeyStallEverything
300	for stallKey == stallKeyStallEverything {
301		stallKey = stallKeyType(rand.Int63())
302	}
303	return stallKey
304}
305
306// staller is a pair of channels. Whenever something is to be
307// stalled, a value is sent on stalled (if not blocked), and then
308// unstall is waited on.
309type staller struct {
310	stalled chan<- struct{}
311	unstall <-chan struct{}
312}
313
314func maybeStall(ctx context.Context, opName stallableOp,
315	stallOpName stallableOp, stallKey stallKeyType,
316	staller staller) {
317	if opName != stallOpName {
318		return
319	}
320
321	if stallKey != stallKeyStallEverything {
322		if v, ok := ctx.Value(stallKey).(bool); !ok || !v {
323			return
324		}
325	}
326
327	select {
328	case staller.stalled <- struct{}{}:
329	default:
330	}
331	<-staller.unstall
332}
333
// runWithContextCheck runs action, polling ctx.Done() once before and
// once after it. If ctx is already done beforehand, action is not run
// and ctx.Err() is returned. If ctx is done afterwards, ctx.Err() is
// returned in place of action's result. Otherwise action's error is
// returned.
func runWithContextCheck(ctx context.Context, action func(ctx context.Context) error) error {
	// done reports, without blocking, whether ctx has been cancelled.
	done := func() bool {
		select {
		case <-ctx.Done():
			return true
		default:
			return false
		}
	}

	if done() {
		return ctx.Err()
	}
	actionErr := action(ctx)
	if done() {
		return ctx.Err()
	}
	return actionErr
}
351
352// stallingBlockServer is an implementation of BlockServer whose
353// operations sometimes stall. In particular, if the operation name
354// matches stallOpName, and ctx.Value(stallKey) is a key in the
355// corresponding staller is used to stall the operation.
356type stallingBlockServer struct {
357	BlockServer
358	stallOpName StallableBlockOp
359	// stallKey is a key for switching on/off stalling. If it's present in ctx,
360	// and equal to `true`, the operation is stalled. This allows us to use the
361	// ctx to control stallings
362	stallKey stallKeyType
363	staller  staller
364}
365
366var _ BlockServer = (*stallingBlockServer)(nil)
367
368func (f *stallingBlockServer) maybeStall(ctx context.Context, opName StallableBlockOp) {
369	maybeStall(ctx, stallableOp(opName), stallableOp(f.stallOpName),
370		f.stallKey, f.staller)
371}
372
373func (f *stallingBlockServer) Get(
374	ctx context.Context, tlfID tlf.ID, id kbfsblock.ID,
375	bctx kbfsblock.Context, cacheType DiskBlockCacheType) (
376	buf []byte, serverHalf kbfscrypto.BlockCryptKeyServerHalf, err error) {
377	f.maybeStall(ctx, StallableBlockGet)
378	err = runWithContextCheck(ctx, func(ctx context.Context) error {
379		var errGet error
380		buf, serverHalf, errGet = f.BlockServer.Get(
381			ctx, tlfID, id, bctx, cacheType)
382		return errGet
383	})
384	return buf, serverHalf, err
385}
386
387func (f *stallingBlockServer) Put(
388	ctx context.Context, tlfID tlf.ID, id kbfsblock.ID,
389	bctx kbfsblock.Context, buf []byte,
390	serverHalf kbfscrypto.BlockCryptKeyServerHalf,
391	cacheType DiskBlockCacheType) error {
392	f.maybeStall(ctx, StallableBlockPut)
393	return runWithContextCheck(ctx, func(ctx context.Context) error {
394		return f.BlockServer.Put(
395			ctx, tlfID, id, bctx, buf, serverHalf, cacheType)
396	})
397}
398
399// stallingMDOps is an implementation of MDOps whose operations
400// sometimes stall. In particular, if the operation name matches
401// stallOpName, and ctx.Value(stallKey) is a key in the corresponding
402// staller is used to stall the operation.
403type stallingMDOps struct {
404	stallOpName StallableMDOp
405	// stallKey is a key for switching on/off stalling. If it's present in ctx,
406	// and equal to `true`, the operation is stalled. This allows us to use the
407	// ctx to control stallings
408	stallKey stallKeyType
409	staller  staller
410	delegate MDOps
411}
412
413var _ MDOps = (*stallingMDOps)(nil)
414
415func (m *stallingMDOps) maybeStall(ctx context.Context, opName StallableMDOp) {
416	maybeStall(ctx, stallableOp(opName), stallableOp(m.stallOpName),
417		m.stallKey, m.staller)
418}
419
420func (m *stallingMDOps) GetIDForHandle(
421	ctx context.Context, handle *tlfhandle.Handle) (tlfID tlf.ID, err error) {
422	return m.delegate.GetIDForHandle(ctx, handle)
423}
424
425func (m *stallingMDOps) GetForTLF(ctx context.Context, id tlf.ID,
426	lockBeforeGet *keybase1.LockID) (md ImmutableRootMetadata, err error) {
427	m.maybeStall(ctx, StallableMDGetForTLF)
428	err = runWithContextCheck(ctx, func(ctx context.Context) error {
429		var errGetForTLF error
430		md, errGetForTLF = m.delegate.GetForTLF(ctx, id, lockBeforeGet)
431		return errGetForTLF
432	})
433	return md, err
434}
435
436func (m *stallingMDOps) GetForTLFByTime(
437	ctx context.Context, id tlf.ID, serverTime time.Time) (
438	md ImmutableRootMetadata, err error) {
439	m.maybeStall(ctx, StallableMDGetForTLFByTime)
440	err = runWithContextCheck(ctx, func(ctx context.Context) error {
441		var errGetForTLF error
442		md, errGetForTLF = m.delegate.GetForTLFByTime(ctx, id, serverTime)
443		return errGetForTLF
444	})
445	return md, err
446}
447
448func (m *stallingMDOps) GetLatestHandleForTLF(ctx context.Context, id tlf.ID) (
449	h tlf.Handle, err error) {
450	m.maybeStall(ctx, StallableMDGetLatestHandleForTLF)
451	err = runWithContextCheck(ctx, func(ctx context.Context) error {
452		var errGetLatestHandleForTLF error
453		h, errGetLatestHandleForTLF = m.delegate.GetLatestHandleForTLF(
454			ctx, id)
455		return errGetLatestHandleForTLF
456	})
457	return h, err
458}
459
460func (m *stallingMDOps) ValidateLatestHandleNotFinal(
461	ctx context.Context, h *tlfhandle.Handle) (b bool, err error) {
462	m.maybeStall(ctx, StallableMDValidateLatestHandleNotFinal)
463	err = runWithContextCheck(ctx, func(ctx context.Context) error {
464		var errValidateLatestHandleNotFinal error
465		b, errValidateLatestHandleNotFinal =
466			m.delegate.ValidateLatestHandleNotFinal(ctx, h)
467		return errValidateLatestHandleNotFinal
468	})
469	return b, err
470}
471
472func (m *stallingMDOps) GetUnmergedForTLF(ctx context.Context, id tlf.ID,
473	bid kbfsmd.BranchID) (md ImmutableRootMetadata, err error) {
474	m.maybeStall(ctx, StallableMDGetUnmergedForTLF)
475	err = runWithContextCheck(ctx, func(ctx context.Context) error {
476		var errGetUnmergedForTLF error
477		md, errGetUnmergedForTLF = m.delegate.GetUnmergedForTLF(ctx, id, bid)
478		return errGetUnmergedForTLF
479	})
480	return md, err
481}
482
483func (m *stallingMDOps) GetRange(ctx context.Context, id tlf.ID,
484	start, stop kbfsmd.Revision, lockBeforeGet *keybase1.LockID) (
485	mds []ImmutableRootMetadata, err error) {
486	m.maybeStall(ctx, StallableMDGetRange)
487	err = runWithContextCheck(ctx, func(ctx context.Context) error {
488		var errGetRange error
489		mds, errGetRange = m.delegate.GetRange(
490			ctx, id, start, stop, lockBeforeGet)
491		m.maybeStall(ctx, StallableMDAfterGetRange)
492		return errGetRange
493	})
494	return mds, err
495}
496
497func (m *stallingMDOps) GetUnmergedRange(ctx context.Context, id tlf.ID,
498	bid kbfsmd.BranchID, start, stop kbfsmd.Revision) (mds []ImmutableRootMetadata, err error) {
499	m.maybeStall(ctx, StallableMDGetUnmergedRange)
500	err = runWithContextCheck(ctx, func(ctx context.Context) error {
501		var errGetUnmergedRange error
502		mds, errGetUnmergedRange = m.delegate.GetUnmergedRange(
503			ctx, id, bid, start, stop)
504		return errGetUnmergedRange
505	})
506	return mds, err
507}
508
509func (m *stallingMDOps) Put(
510	ctx context.Context, md *RootMetadata, verifyingKey kbfscrypto.VerifyingKey,
511	lockContext *keybase1.LockContext, priority keybase1.MDPriority,
512	bps data.BlockPutState) (irmd ImmutableRootMetadata, err error) {
513	m.maybeStall(ctx, StallableMDPut)
514	err = runWithContextCheck(ctx, func(ctx context.Context) error {
515		irmd, err = m.delegate.Put(
516			ctx, md, verifyingKey, lockContext, priority, bps)
517		m.maybeStall(ctx, StallableMDAfterPut)
518		return err
519	})
520	return irmd, err
521}
522
523func (m *stallingMDOps) PutUnmerged(
524	ctx context.Context, md *RootMetadata,
525	verifyingKey kbfscrypto.VerifyingKey, bps data.BlockPutState) (
526	irmd ImmutableRootMetadata, err error) {
527	m.maybeStall(ctx, StallableMDPutUnmerged)
528	err = runWithContextCheck(ctx, func(ctx context.Context) error {
529		irmd, err = m.delegate.PutUnmerged(ctx, md, verifyingKey, bps)
530		m.maybeStall(ctx, StallableMDAfterPutUnmerged)
531		return err
532	})
533	return irmd, err
534}
535
536func (m *stallingMDOps) PruneBranch(
537	ctx context.Context, id tlf.ID, bid kbfsmd.BranchID) error {
538	m.maybeStall(ctx, StallableMDPruneBranch)
539	return runWithContextCheck(ctx, func(ctx context.Context) error {
540		return m.delegate.PruneBranch(ctx, id, bid)
541	})
542}
543
544func (m *stallingMDOps) ResolveBranch(
545	ctx context.Context, id tlf.ID, bid kbfsmd.BranchID,
546	blocksToDelete []kbfsblock.ID, rmd *RootMetadata,
547	verifyingKey kbfscrypto.VerifyingKey, bps data.BlockPutState) (
548	irmd ImmutableRootMetadata, err error) {
549	m.maybeStall(ctx, StallableMDResolveBranch)
550	err = runWithContextCheck(ctx, func(ctx context.Context) error {
551		irmd, err = m.delegate.ResolveBranch(
552			ctx, id, bid, blocksToDelete, rmd, verifyingKey, bps)
553		return err
554	})
555	return irmd, err
556}
557