1package raft
2
3import (
4	"sync"
5	"sync/atomic"
6)
7
// RaftState captures the state of a Raft node: Follower, Candidate, Leader,
// or Shutdown.
type RaftState uint32

// The numeric encoding of each state is written out explicitly rather than
// derived from iota. Follower is deliberately 0, so a zero-valued RaftState
// is a Follower.
const (
	// Follower is the initial state of a Raft node.
	Follower RaftState = 0

	// Candidate is one of the valid states of a Raft node.
	Candidate RaftState = 1

	// Leader is one of the valid states of a Raft node.
	Leader RaftState = 2

	// Shutdown is the terminal state of a Raft node.
	Shutdown RaftState = 3
)
25
26func (s RaftState) String() string {
27	switch s {
28	case Follower:
29		return "Follower"
30	case Candidate:
31		return "Candidate"
32	case Leader:
33		return "Leader"
34	case Shutdown:
35		return "Shutdown"
36	default:
37		return "Unknown"
38	}
39}
40
// raftState is used to maintain various state variables
// and provides an interface to set/get the variables in a
// thread safe manner.
//
// Field order matters: currentTerm, commitIndex and lastApplied are accessed
// with the sync/atomic 64-bit functions, which on 32-bit platforms (386, ARM,
// 32-bit MIPS) require 64-bit alignment. Only the first word of an allocated
// struct is guaranteed to be 64-bit aligned, so these three fields are kept
// at the top. Do not move them below smaller fields.
// NOTE(review): if raftState is embedded in another struct, confirm it sits
// where it remains 64-bit aligned (e.g. as the first field).
//
// The struct contains a sync.Mutex and a sync.WaitGroup, so it must not be
// copied after first use; the methods defined alongside it take a pointer
// receiver.
type raftState struct {
	// The current term, cache of StableStore.
	// Read/written atomically via getCurrentTerm/setCurrentTerm.
	currentTerm uint64

	// Highest committed log entry.
	// Read/written atomically via getCommitIndex/setCommitIndex.
	commitIndex uint64

	// Last applied log to the FSM.
	// Read/written atomically via getLastApplied/setLastApplied.
	lastApplied uint64

	// Guards lastSnapshotIndex, lastSnapshotTerm, lastLogIndex and
	// lastLogTerm, so each index/term pair is read and written together.
	lastLock sync.Mutex

	// Cache the latest snapshot index/term (guarded by lastLock).
	lastSnapshotIndex uint64
	lastSnapshotTerm  uint64

	// Cache the latest log from LogStore (guarded by lastLock).
	lastLogIndex uint64
	lastLogTerm  uint64

	// Tracks goroutines started with goFunc; waitShutdown waits on it.
	routinesGroup sync.WaitGroup

	// The current state. Read/written atomically as a uint32 via
	// getState/setState.
	state RaftState
}
71
72func (r *raftState) getState() RaftState {
73	stateAddr := (*uint32)(&r.state)
74	return RaftState(atomic.LoadUint32(stateAddr))
75}
76
77func (r *raftState) setState(s RaftState) {
78	stateAddr := (*uint32)(&r.state)
79	atomic.StoreUint32(stateAddr, uint32(s))
80}
81
82func (r *raftState) getCurrentTerm() uint64 {
83	return atomic.LoadUint64(&r.currentTerm)
84}
85
86func (r *raftState) setCurrentTerm(term uint64) {
87	atomic.StoreUint64(&r.currentTerm, term)
88}
89
90func (r *raftState) getLastLog() (index, term uint64) {
91	r.lastLock.Lock()
92	index = r.lastLogIndex
93	term = r.lastLogTerm
94	r.lastLock.Unlock()
95	return
96}
97
98func (r *raftState) setLastLog(index, term uint64) {
99	r.lastLock.Lock()
100	r.lastLogIndex = index
101	r.lastLogTerm = term
102	r.lastLock.Unlock()
103}
104
105func (r *raftState) getLastSnapshot() (index, term uint64) {
106	r.lastLock.Lock()
107	index = r.lastSnapshotIndex
108	term = r.lastSnapshotTerm
109	r.lastLock.Unlock()
110	return
111}
112
113func (r *raftState) setLastSnapshot(index, term uint64) {
114	r.lastLock.Lock()
115	r.lastSnapshotIndex = index
116	r.lastSnapshotTerm = term
117	r.lastLock.Unlock()
118}
119
120func (r *raftState) getCommitIndex() uint64 {
121	return atomic.LoadUint64(&r.commitIndex)
122}
123
124func (r *raftState) setCommitIndex(index uint64) {
125	atomic.StoreUint64(&r.commitIndex, index)
126}
127
128func (r *raftState) getLastApplied() uint64 {
129	return atomic.LoadUint64(&r.lastApplied)
130}
131
132func (r *raftState) setLastApplied(index uint64) {
133	atomic.StoreUint64(&r.lastApplied, index)
134}
135
136// Start a goroutine and properly handle the race between a routine
137// starting and incrementing, and exiting and decrementing.
138func (r *raftState) goFunc(f func()) {
139	r.routinesGroup.Add(1)
140	go func() {
141		defer r.routinesGroup.Done()
142		f()
143	}()
144}
145
146func (r *raftState) waitShutdown() {
147	r.routinesGroup.Wait()
148}
149
150// getLastIndex returns the last index in stable storage.
151// Either from the last log or from the last snapshot.
152func (r *raftState) getLastIndex() uint64 {
153	r.lastLock.Lock()
154	defer r.lastLock.Unlock()
155	return max(r.lastLogIndex, r.lastSnapshotIndex)
156}
157
158// getLastEntry returns the last index and term in stable storage.
159// Either from the last log or from the last snapshot.
160func (r *raftState) getLastEntry() (uint64, uint64) {
161	r.lastLock.Lock()
162	defer r.lastLock.Unlock()
163	if r.lastLogIndex >= r.lastSnapshotIndex {
164		return r.lastLogIndex, r.lastLogTerm
165	}
166	return r.lastSnapshotIndex, r.lastSnapshotTerm
167}
168