1package libkb
2
3import (
4	"encoding/hex"
5	"sync"
6	"time"
7
8	keybase1 "github.com/keybase/client/go/protocol/keybase1"
9	"golang.org/x/sync/errgroup"
10)
11
// Identify3Session corresponds to a single screen showing a user profile.
// It maps 1:1 with an Identify3GUIID, and is labeled as such. Here we'll keep
// track of whatever context we need to pass across calls, and also the
// IdentifyOutcome for the final result.
//
// The embedded Mutex guards the fields below. The exported accessors and
// setters take it themselves; OutcomeLocked does not and is meant for callers
// that already hold it.
type Identify3Session struct {
	sync.Mutex
	// When the session was constructed (read from the global clock). The
	// owning Identify3State treats it as due for expiration expireTime later.
	created     time.Time
	// The GUI ID this session is labeled with; also its key in Identify3State's cache.
	id          keybase1.Identify3GUIID
	// Final identify outcome, set via SetOutcome.
	outcome     *IdentifyOutcome
	// When set, ResultType reports BROKEN (this takes precedence over needUpgrade).
	trackBroken bool
	// When set (and trackBroken is not), ResultType reports NEEDS_UPGRADE.
	needUpgrade bool
	didExpire   bool // true if we ran an expire on this session (so we don't repeat)
}
25
26func NewIdentify3GUIID() (keybase1.Identify3GUIID, error) {
27	var b []byte
28	l := 12
29	b, err := RandBytes(l)
30	if err != nil {
31		return keybase1.Identify3GUIID(""), err
32	}
33	b[l-1] = 0x34
34	return keybase1.Identify3GUIID(hex.EncodeToString(b)), nil
35}
36
37func NewIdentify3SessionWithID(mctx MetaContext, id keybase1.Identify3GUIID) *Identify3Session {
38	return &Identify3Session{
39		created: mctx.G().GetClock().Now(),
40		id:      id,
41	}
42}
43
44func NewIdentify3Session(mctx MetaContext) (*Identify3Session, error) {
45	id, err := NewIdentify3GUIID()
46	if err != nil {
47		return nil, err
48	}
49	ret := &Identify3Session{
50		created: mctx.G().GetClock().Now(),
51		id:      id,
52	}
53	mctx.Debug("generated new identify3 session: %s", id)
54	return ret, nil
55}
56
57func (s *Identify3Session) ID() keybase1.Identify3GUIID {
58	s.Lock()
59	defer s.Unlock()
60	return s.id
61}
62
63func (s *Identify3Session) ResultType() keybase1.Identify3ResultType {
64	s.Lock()
65	defer s.Unlock()
66	switch {
67	case s.trackBroken:
68		return keybase1.Identify3ResultType_BROKEN
69	case s.needUpgrade:
70		return keybase1.Identify3ResultType_NEEDS_UPGRADE
71	default:
72		return keybase1.Identify3ResultType_OK
73	}
74}
75
76func (s *Identify3Session) Outcome() *IdentifyOutcome {
77	s.Lock()
78	defer s.Unlock()
79	return s.outcome
80}
81
82func (s *Identify3Session) OutcomeLocked() *IdentifyOutcome {
83	return s.outcome
84}
85
86func (s *Identify3Session) SetTrackBroken() {
87	s.Lock()
88	defer s.Unlock()
89	s.trackBroken = true
90}
91
92func (s *Identify3Session) SetNeedUpgrade() {
93	s.Lock()
94	defer s.Unlock()
95	s.needUpgrade = true
96}
97
98func (s *Identify3Session) SetOutcome(o *IdentifyOutcome) {
99	s.Lock()
100	defer s.Unlock()
101	s.outcome = o
102}
103
// Identify3State keeps track of all active ID3 state across the whole app. It
// holds a cache of sessions keyed by GUI ID. A background goroutine
// (runExpireThread) periodically cleans it up, expiring sessions once
// expireTime has passed since their creation.
type Identify3State struct {
	// Guards cache and expirationQueue.
	sync.Mutex

	// Pokes the background expiration thread (see pokeExpireThread).
	expireCh   chan<- struct{}
	// Closed by markShutdown to stop the background expiration thread;
	// Shutdown sets this field to nil after that thread has exited.
	shutdownCh chan struct{}
	// Runs the background expiration thread; Shutdown waits on it.
	eg         errgroup.Group

	// Table of keybase1.Identify3GUIID -> *Identify3Session.
	cache           map[keybase1.Identify3GUIID](*Identify3Session)
	// Sessions in the order they were Put. The background thread only looks at
	// the head, so this assumes that order matches expiration order. Entries
	// whose session is no longer cached are skipped when popped.
	expirationQueue [](*Identify3Session)

	// How long the background thread sleeps when the expiration queue is empty.
	defaultWaitTime time.Duration
	// How long after its creation time a session gets expired.
	expireTime      time.Duration

	// Held by the background thread while it reads the clock, expires due
	// sessions, and computes its next wakeup (described in code as a test aid).
	bgThreadTimeMu   sync.Mutex
	// nil in production; in tests, receives the `now` of each expiration pass.
	testCompletionCh chan<- time.Time

	// Guards shutdown.
	shutdownMu sync.Mutex
	shutdown   bool
}
126
127func NewIdentify3State(g *GlobalContext) *Identify3State {
128	return newIdentify3State(g, nil)
129}
130
131func NewIdentify3StateForTest(g *GlobalContext) (*Identify3State, <-chan time.Time) {
132	ch := make(chan time.Time, 1000)
133	state := newIdentify3State(g, ch)
134	return state, ch
135}
136
137func newIdentify3State(g *GlobalContext, testCompletionCh chan<- time.Time) *Identify3State {
138	expireCh := make(chan struct{})
139	shutdownCh := make(chan struct{})
140	ret := &Identify3State{
141		expireCh:         expireCh,
142		shutdownCh:       shutdownCh,
143		cache:            make(map[keybase1.Identify3GUIID](*Identify3Session)),
144		defaultWaitTime:  time.Hour,
145		expireTime:       24 * time.Hour,
146		testCompletionCh: testCompletionCh,
147	}
148	ret.makeNewCache()
149	ret.eg.Go(func() error { return ret.runExpireThread(g, expireCh, shutdownCh) })
150	ret.pokeExpireThread()
151	return ret
152}
153
154func (s *Identify3State) Shutdown() chan struct{} {
155	ch := make(chan struct{})
156	if s.markShutdown() {
157		go func() {
158			_ = s.eg.Wait()
159			s.shutdownCh = nil
160			close(ch)
161		}()
162	} else {
163		close(ch)
164	}
165	return ch
166}
167
168func (s *Identify3State) isShutdown() bool {
169	s.shutdownMu.Lock()
170	defer s.shutdownMu.Unlock()
171	return s.shutdown
172}
173
174// markShutdown marks this state as having shutdown. Will return true the first
175// time through, and false every other time.
176func (s *Identify3State) markShutdown() bool {
177	s.shutdownMu.Lock()
178	defer s.shutdownMu.Unlock()
179	if s.shutdown {
180		return false
181	}
182	close(s.shutdownCh)
183	s.shutdown = true
184	return true
185}
186
187func (s *Identify3State) makeNewCache() {
188	s.Lock()
189	s.cache = make(map[keybase1.Identify3GUIID](*Identify3Session))
190	s.expirationQueue = nil
191	s.Unlock()
192}
193
194func (s *Identify3State) OnLogout() {
195	s.makeNewCache()
196	s.pokeExpireThread()
197}
198
// runExpireThread is the body of the background expiration goroutine started
// by newIdentify3State. Each iteration waits for one of three events:
//   - shutdownCh closing, in which case it returns nil;
//   - a poke on expireCh (see pokeExpireThread);
//   - the channel from Clock().AfterTime(wakeupTime) firing.
//
// After each poke or timeout it expires due sessions via expireSessions. The
// duration that returns schedules the next timed wakeup. It always returns nil.
func (s *Identify3State) runExpireThread(g *GlobalContext, expireCh <-chan struct{},
	shutdownCh chan struct{}) error {

	mctx := NewMetaContextBackground(g)
	wait := s.defaultWaitTime

	// Time is read through the GlobalContext clock rather than time.Now,
	// presumably so tests can drive it with a fake clock — confirm.
	nowFn := func() time.Time { return mctx.G().Clock().Now() }
	now := nowFn()
	wakeupTime := now.Add(wait)

	for {
		select {
		case <-shutdownCh:
			mctx.Debug("identify3State#runExpireThread: exiting on shutdown")
			return nil
		case <-expireCh:
			// Poked (by newIdentify3State, Put, or OnLogout): re-scan right away.
		case <-mctx.G().Clock().AfterTime(wakeupTime):
			mctx.Debug("identify3State#runExpireThread: wakeup after %v timeout (at %v)", wait, wakeupTime)
		}

		// Guard all time manipulation in a lock for the purposes of testing.
		// In real life, this shouldn't matter much, but it can't really hurt.
		s.bgThreadTimeMu.Lock()
		now = nowFn()
		wait = s.expireSessions(mctx, now)
		wakeupTime = now.Add(wait)
		s.bgThreadTimeMu.Unlock()

		// Also for the purposes of test, broadcast how far we've processed in time.
		// In real life, this will be a noop, since s.testCompletionCh will be nil.
		// Note: this send blocks once the test channel's buffer is full.
		if s.testCompletionCh != nil {
			s.testCompletionCh <- now
		}
	}
}
234
235func (s *Identify3Session) doExpireSession(mctx MetaContext) {
236	defer mctx.Trace("Identify3Session#doExpireSession", nil)()
237	s.Lock()
238	defer s.Unlock()
239	mctx.Debug("Identify3Session#doExpireSession(%s)", s.id)
240
241	if s.didExpire {
242		mctx.Warning("not repeating session expire for %s", s.id)
243		return
244	}
245	s.didExpire = true
246
247	cli, err := mctx.G().UIRouter.GetIdentify3UI(mctx)
248	if err != nil {
249		mctx.Warning("failed to get an electron UI to expire %s: %s", s.id, err)
250		return
251	}
252	if cli == nil {
253		mctx.Warning("failed to get an electron UI to expire %s: got nil", s.id)
254		return
255	}
256	err = cli.Identify3TrackerTimedOut(mctx.Ctx(), s.id)
257	if err != nil {
258		mctx.Warning("error timing ID3 session %s: %s", s.id, err)
259	}
260}
261
262func (s *Identify3State) expireSessions(mctx MetaContext, now time.Time) time.Duration {
263	defer mctx.Trace("Identify3State#expireSessions", nil)()
264
265	// getSesionsToExpire holds the Identify3State Mutex.
266	toExpire, diff := s.getSessionsToExpire(mctx, now)
267
268	// doExpireSessions does not hold the Identify3State Mutex, because it
269	// calls out to the front end via Identify3TrackedTimedOut.
270	s.doExpireSessions(mctx, toExpire)
271
272	return diff
273}
274
275func (s *Identify3State) doExpireSessions(mctx MetaContext, toExpire []*Identify3Session) {
276	for _, sess := range toExpire {
277		sess.doExpireSession(mctx)
278	}
279}
280
281func (s *Identify3State) getSessionsToExpire(mctx MetaContext, now time.Time) (ret []*Identify3Session, diff time.Duration) {
282	s.Lock()
283	defer s.Unlock()
284
285	for {
286		if len(s.expirationQueue) == 0 {
287			return ret, s.defaultWaitTime
288		}
289		var sess *Identify3Session
290		sess, diff = s.getSessionToExpire(mctx, now)
291		if diff > 0 {
292			return ret, diff
293		}
294		if sess != nil {
295			ret = append(ret, sess)
296		}
297	}
298}
299
300// getSessionToExpire should be called when holding the Identify3State Mutex. It looks in the
301// expiration queue and pops off those sessions that are ready to be marked expired.
302func (s *Identify3State) getSessionToExpire(mctx MetaContext, now time.Time) (*Identify3Session, time.Duration) {
303	sess := s.expirationQueue[0]
304	sess.Lock()
305	defer sess.Unlock()
306	expireAt := sess.created.Add(s.expireTime)
307	diff := expireAt.Sub(now)
308	if diff > 0 {
309		return nil, diff
310	}
311	s.expirationQueue = s.expirationQueue[1:]
312
313	// Only send the expiration if the session is still in the cache table.
314	// If not, that means it was already acted upon
315	if _, found := s.cache[sess.id]; !found {
316		return nil, diff
317	}
318	mctx.Debug("Identify3State#getSessionToExpire: removing %s", sess.id)
319	s.removeFromTableLocked(sess.id)
320	return sess, diff
321}
322
323// get an identify3Session out of the cache, as keyed by a Identify3GUIID. Return
324// (nil, nil) if not found. Return (nil, Error) if there was an expected error.
325// Return (i, nil) if found, where i is the **unlocked** object.
326func (s *Identify3State) Get(key keybase1.Identify3GUIID) (ret *Identify3Session, err error) {
327	s.Lock()
328	defer s.Unlock()
329	return s.getLocked(key)
330}
331
332func (s *Identify3State) getLocked(key keybase1.Identify3GUIID) (ret *Identify3Session, err error) {
333	ret, found := s.cache[key]
334	if !found {
335		return nil, nil
336	}
337	return ret, nil
338}
339
340func (s *Identify3State) Put(sess *Identify3Session) error {
341	err := s.lockAndPut(sess)
342	s.pokeExpireThread()
343	return err
344}
345
346func (s *Identify3State) lockAndPut(sess *Identify3Session) error {
347	s.Lock()
348	defer s.Unlock()
349
350	id := sess.ID()
351	tmp, err := s.getLocked(id)
352	if err != nil {
353		return err
354	}
355	if tmp != nil {
356		return ExistsError{Msg: "Identify3 ID already exists"}
357	}
358	s.cache[id] = sess
359	s.expirationQueue = append(s.expirationQueue, sess)
360	return nil
361}
362
363func (s *Identify3State) Remove(key keybase1.Identify3GUIID) {
364	s.Lock()
365	defer s.Unlock()
366	s.removeFromTableLocked(key)
367}
368
369func (s *Identify3State) removeFromTableLocked(key keybase1.Identify3GUIID) {
370	delete(s.cache, key)
371}
372
373// pokeExpireThread should never be called when holding s.Mutex.
374func (s *Identify3State) pokeExpireThread() {
375	if s.isShutdown() {
376		return
377	}
378	s.expireCh <- struct{}{}
379}
380