package raft

import (
	"sync"
	"sync/atomic"
)

// RaftState captures the state of a Raft node: Follower, Candidate, Leader,
// or Shutdown.
type RaftState uint32

const (
	// Follower is the initial state of a Raft node.
	Follower RaftState = iota

	// Candidate is one of the valid states of a Raft node.
	Candidate

	// Leader is one of the valid states of a Raft node.
	Leader

	// Shutdown is the terminal state of a Raft node.
	Shutdown
)

// String returns the human-readable name of the state, or "Unknown" for a
// value that is not one of the declared RaftState constants. It makes
// RaftState satisfy fmt.Stringer.
func (s RaftState) String() string {
	switch s {
	case Follower:
		return "Follower"
	case Candidate:
		return "Candidate"
	case Leader:
		return "Leader"
	case Shutdown:
		return "Shutdown"
	default:
		return "Unknown"
	}
}

// raftState is used to maintain various state variables
// and provides an interface to set/get the variables in a
// thread safe manner.
//
// Within this file, currentTerm, commitIndex, lastApplied and state are
// read and written only via sync/atomic, while the last log/snapshot
// index/term pairs are guarded by lastLock.
type raftState struct {
	// currentTerm, commitIndex and lastApplied must be kept at the top of
	// the struct so they are 64-bit aligned, which is a requirement for
	// 64-bit atomic operations on 32-bit platforms. This relies on the
	// raftState value itself starting on a 64-bit boundary (e.g. being the
	// first field of an allocated struct).
	// NOTE(review): confirm where raftState is embedded by its callers.

	// The current term, cache of StableStore
	currentTerm uint64

	// Highest committed log entry
	commitIndex uint64

	// Last applied log to the FSM
	lastApplied uint64

	// lastLock protects the next 4 fields (the snapshot and log
	// index/term pairs).
	lastLock sync.Mutex

	// Cache the latest snapshot index/term
	lastSnapshotIndex uint64
	lastSnapshotTerm  uint64

	// Cache the latest log from LogStore
	lastLogIndex uint64
	lastLogTerm  uint64

	// Tracks running goroutines started via goFunc
	routinesGroup sync.WaitGroup

	// The current state; stored as a uint32 and accessed atomically by
	// getState/setState.
	state RaftState
}

// getState atomically loads the current RaftState.
func (r *raftState) getState() RaftState {
	// RaftState's underlying type is uint32, so the pointer conversion is
	// valid for the atomic load.
	stateAddr := (*uint32)(&r.state)
	return RaftState(atomic.LoadUint32(stateAddr))
}

// setState atomically stores s as the current RaftState.
func (r *raftState) setState(s RaftState) {
	stateAddr := (*uint32)(&r.state)
	atomic.StoreUint32(stateAddr, uint32(s))
}

// getCurrentTerm atomically loads the cached current term.
func (r *raftState) getCurrentTerm() uint64 {
	return atomic.LoadUint64(&r.currentTerm)
}

// setCurrentTerm atomically stores term as the cached current term.
func (r *raftState) setCurrentTerm(term uint64) {
	atomic.StoreUint64(&r.currentTerm, term)
}

// getLastLog returns the cached index and term of the latest log entry,
// read together under lastLock so the pair is consistent.
func (r *raftState) getLastLog() (index, term uint64) {
	r.lastLock.Lock()
	defer r.lastLock.Unlock()
	return r.lastLogIndex, r.lastLogTerm
}

// setLastLog updates the cached index and term of the latest log entry
// together under lastLock.
func (r *raftState) setLastLog(index, term uint64) {
	r.lastLock.Lock()
	defer r.lastLock.Unlock()
	r.lastLogIndex = index
	r.lastLogTerm = term
}

// getLastSnapshot returns the cached index and term of the latest snapshot,
// read together under lastLock so the pair is consistent.
func (r *raftState) getLastSnapshot() (index, term uint64) {
	r.lastLock.Lock()
	defer r.lastLock.Unlock()
	return r.lastSnapshotIndex, r.lastSnapshotTerm
}

// setLastSnapshot updates the cached index and term of the latest snapshot
// together under lastLock.
func (r *raftState) setLastSnapshot(index, term uint64) {
	r.lastLock.Lock()
	defer r.lastLock.Unlock()
	r.lastSnapshotIndex = index
	r.lastSnapshotTerm = term
}

// getCommitIndex atomically loads the highest committed log index.
func (r *raftState) getCommitIndex() uint64 {
	return atomic.LoadUint64(&r.commitIndex)
}

// setCommitIndex atomically stores index as the highest committed log index.
func (r *raftState) setCommitIndex(index uint64) {
	atomic.StoreUint64(&r.commitIndex, index)
}

// getLastApplied atomically loads the index of the last log applied to the FSM.
func (r *raftState) getLastApplied() uint64 {
	return atomic.LoadUint64(&r.lastApplied)
}

// setLastApplied atomically stores index as the last log applied to the FSM.
func (r *raftState) setLastApplied(index uint64) {
	atomic.StoreUint64(&r.lastApplied, index)
}

// goFunc runs f in a new goroutine tracked by routinesGroup. The counter is
// incremented before the goroutine is launched (so waitShutdown can never
// miss a routine that has been requested but not yet scheduled), and Done is
// deferred so the count is released however f returns.
func (r *raftState) goFunc(f func()) {
	r.routinesGroup.Add(1)
	go func() {
		defer r.routinesGroup.Done()
		f()
	}()
}

// waitShutdown blocks until every goroutine started via goFunc has returned.
func (r *raftState) waitShutdown() {
	r.routinesGroup.Wait()
}

// getLastIndex returns the last index in stable storage: the larger of the
// cached last log index and last snapshot index.
func (r *raftState) getLastIndex() uint64 {
	r.lastLock.Lock()
	defer r.lastLock.Unlock()
	return max(r.lastLogIndex, r.lastSnapshotIndex)
}

// getLastEntry returns the last index and term in stable storage, taken
// from the last log when its index is at least the last snapshot index
// (ties go to the log), and from the last snapshot otherwise.
func (r *raftState) getLastEntry() (uint64, uint64) {
	r.lastLock.Lock()
	defer r.lastLock.Unlock()
	if r.lastLogIndex >= r.lastSnapshotIndex {
		return r.lastLogIndex, r.lastLogTerm
	}
	return r.lastSnapshotIndex, r.lastSnapshotTerm
}