1package raft
2
3import (
4	"sync"
5	"sync/atomic"
6)
7
// RaftState identifies the role a Raft node is currently in. The four
// defined values are Follower, Candidate, Leader and Shutdown; the zero
// value is Follower.
type RaftState uint32

// The values are spelled out explicitly (0 through 3) rather than derived
// from iota so that the numbering is visible at a glance.
const (
	// Follower is the state a Raft node starts in.
	Follower RaftState = 0

	// Candidate is one of the valid states of a Raft node.
	Candidate RaftState = 1

	// Leader is one of the valid states of a Raft node.
	Leader RaftState = 2

	// Shutdown is the final state of a Raft node; it is never left.
	Shutdown RaftState = 3
)
25
26func (s RaftState) String() string {
27	switch s {
28	case Follower:
29		return "Follower"
30	case Candidate:
31		return "Candidate"
32	case Leader:
33		return "Leader"
34	case Shutdown:
35		return "Shutdown"
36	default:
37		return "Unknown"
38	}
39}
40
// raftState holds state variables of a Raft node and provides getters and
// setters for them that are safe for concurrent use. currentTerm,
// commitIndex, lastApplied and state are accessed with sync/atomic
// operations. The cached last-log and last-snapshot index/term pairs are
// guarded by lastLock so that each pair is read and written together.
type raftState struct {
	// currentTerm, commitIndex and lastApplied must be kept at the top of
	// the struct so they are 64-bit aligned, which is a requirement for
	// 64-bit atomic operations on 32-bit platforms. Do not reorder them.

	// The current term, cache of StableStore
	currentTerm uint64

	// Highest committed log entry
	commitIndex uint64

	// Last applied log to the FSM
	lastApplied uint64

	// lastLock protects the four fields that follow: lastSnapshotIndex,
	// lastSnapshotTerm, lastLogIndex and lastLogTerm.
	lastLock sync.Mutex

	// Cache the latest snapshot index/term
	lastSnapshotIndex uint64
	lastSnapshotTerm  uint64

	// Cache the latest log from LogStore
	lastLogIndex uint64
	lastLogTerm  uint64

	// Tracks goroutines launched via goFunc so waitShutdown can wait for
	// them to finish.
	routinesGroup sync.WaitGroup

	// The current state; read and written atomically as a uint32 by
	// getState/setState.
	state RaftState
}
75
76func (r *raftState) getState() RaftState {
77	stateAddr := (*uint32)(&r.state)
78	return RaftState(atomic.LoadUint32(stateAddr))
79}
80
81func (r *raftState) setState(s RaftState) {
82	stateAddr := (*uint32)(&r.state)
83	atomic.StoreUint32(stateAddr, uint32(s))
84}
85
86func (r *raftState) getCurrentTerm() uint64 {
87	return atomic.LoadUint64(&r.currentTerm)
88}
89
90func (r *raftState) setCurrentTerm(term uint64) {
91	atomic.StoreUint64(&r.currentTerm, term)
92}
93
94func (r *raftState) getLastLog() (index, term uint64) {
95	r.lastLock.Lock()
96	index = r.lastLogIndex
97	term = r.lastLogTerm
98	r.lastLock.Unlock()
99	return
100}
101
102func (r *raftState) setLastLog(index, term uint64) {
103	r.lastLock.Lock()
104	r.lastLogIndex = index
105	r.lastLogTerm = term
106	r.lastLock.Unlock()
107}
108
109func (r *raftState) getLastSnapshot() (index, term uint64) {
110	r.lastLock.Lock()
111	index = r.lastSnapshotIndex
112	term = r.lastSnapshotTerm
113	r.lastLock.Unlock()
114	return
115}
116
117func (r *raftState) setLastSnapshot(index, term uint64) {
118	r.lastLock.Lock()
119	r.lastSnapshotIndex = index
120	r.lastSnapshotTerm = term
121	r.lastLock.Unlock()
122}
123
124func (r *raftState) getCommitIndex() uint64 {
125	return atomic.LoadUint64(&r.commitIndex)
126}
127
128func (r *raftState) setCommitIndex(index uint64) {
129	atomic.StoreUint64(&r.commitIndex, index)
130}
131
132func (r *raftState) getLastApplied() uint64 {
133	return atomic.LoadUint64(&r.lastApplied)
134}
135
136func (r *raftState) setLastApplied(index uint64) {
137	atomic.StoreUint64(&r.lastApplied, index)
138}
139
140// Start a goroutine and properly handle the race between a routine
141// starting and incrementing, and exiting and decrementing.
142func (r *raftState) goFunc(f func()) {
143	r.routinesGroup.Add(1)
144	go func() {
145		defer r.routinesGroup.Done()
146		f()
147	}()
148}
149
150func (r *raftState) waitShutdown() {
151	r.routinesGroup.Wait()
152}
153
154// getLastIndex returns the last index in stable storage.
155// Either from the last log or from the last snapshot.
156func (r *raftState) getLastIndex() uint64 {
157	r.lastLock.Lock()
158	defer r.lastLock.Unlock()
159	return max(r.lastLogIndex, r.lastSnapshotIndex)
160}
161
162// getLastEntry returns the last index and term in stable storage.
163// Either from the last log or from the last snapshot.
164func (r *raftState) getLastEntry() (uint64, uint64) {
165	r.lastLock.Lock()
166	defer r.lastLock.Unlock()
167	if r.lastLogIndex >= r.lastSnapshotIndex {
168		return r.lastLogIndex, r.lastLogTerm
169	}
170	return r.lastSnapshotIndex, r.lastSnapshotTerm
171}
172