1package api
2
3import (
4	"fmt"
5	"sync"
6	"time"
7)
8
const (
	// DefaultLockSessionName is the Session Name we assign if none is provided.
	DefaultLockSessionName = "Consul API Lock"

	// DefaultLockSessionTTL is the default session TTL if no Session is provided
	// when creating a new Lock. This is used because we do not have any
	// other check to depend upon.
	DefaultLockSessionTTL = "15s"

	// DefaultLockWaitTime is how long we block for at a time to check if lock
	// acquisition is possible. This affects the minimum time it takes to cancel
	// a Lock acquisition.
	DefaultLockWaitTime = 15 * time.Second

	// DefaultLockRetryTime is how long we wait after a failed lock acquisition
	// before attempting to do the lock again. This is so that once a lock-delay
	// is in effect, we do not hot loop retrying the acquisition.
	DefaultLockRetryTime = 5 * time.Second

	// DefaultMonitorRetryTime is how long we wait after a failed monitor check
	// of a lock (500 response code). This allows the monitor to ride out brief
	// periods of unavailability, subject to the MonitorRetries setting in the
	// lock options which is by default set to 0, disabling this feature. This
	// affects locks and semaphores.
	DefaultMonitorRetryTime = 2 * time.Second

	// LockFlagValue is a magic flag we set to indicate a key
	// is being used for a lock. It is used to detect a potential
	// conflict with a semaphore.
	LockFlagValue = 0x2ddccbc058a50c18
)
40
// ErrLockHeld is returned when Lock is called on a handle that already
// holds the lock, or when Destroy is called while the lock is held.
var ErrLockHeld = fmt.Errorf("Lock already held")

// ErrLockNotHeld is returned when Unlock is called on a handle that does
// not currently hold the lock.
var ErrLockNotHeld = fmt.Errorf("Lock not held")

// ErrLockInUse is returned when Destroy finds the lock key still in use.
var ErrLockInUse = fmt.Errorf("Lock in use")

// ErrLockConflict is returned when the flags stored on the lock key do not
// match the value a lock expects.
var ErrLockConflict = fmt.Errorf("Existing key does not match lock use")
57
58// Lock is used to implement client-side leader election. It is follows the
59// algorithm as described here: https://www.consul.io/docs/guides/leader-election.html.
60type Lock struct {
61	c    *Client
62	opts *LockOptions
63
64	isHeld       bool
65	sessionRenew chan struct{}
66	lockSession  string
67	l            sync.Mutex
68}
69
70// LockOptions is used to parameterize the Lock behavior.
71type LockOptions struct {
72	Key              string        // Must be set and have write permissions
73	Value            []byte        // Optional, value to associate with the lock
74	Session          string        // Optional, created if not specified
75	SessionOpts      *SessionEntry // Optional, options to use when creating a session
76	SessionName      string        // Optional, defaults to DefaultLockSessionName (ignored if SessionOpts is given)
77	SessionTTL       string        // Optional, defaults to DefaultLockSessionTTL (ignored if SessionOpts is given)
78	MonitorRetries   int           // Optional, defaults to 0 which means no retries
79	MonitorRetryTime time.Duration // Optional, defaults to DefaultMonitorRetryTime
80	LockWaitTime     time.Duration // Optional, defaults to DefaultLockWaitTime
81	LockTryOnce      bool          // Optional, defaults to false which means try forever
82	Namespace        string        `json:",omitempty"` // Optional, defaults to API client config, namespace of ACL token, or "default" namespace
83}
84
85// LockKey returns a handle to a lock struct which can be used
86// to acquire and release the mutex. The key used must have
87// write permissions.
88func (c *Client) LockKey(key string) (*Lock, error) {
89	opts := &LockOptions{
90		Key: key,
91	}
92	return c.LockOpts(opts)
93}
94
95// LockOpts returns a handle to a lock struct which can be used
96// to acquire and release the mutex. The key used must have
97// write permissions.
98func (c *Client) LockOpts(opts *LockOptions) (*Lock, error) {
99	if opts.Key == "" {
100		return nil, fmt.Errorf("missing key")
101	}
102	if opts.SessionName == "" {
103		opts.SessionName = DefaultLockSessionName
104	}
105	if opts.SessionTTL == "" {
106		opts.SessionTTL = DefaultLockSessionTTL
107	} else {
108		if _, err := time.ParseDuration(opts.SessionTTL); err != nil {
109			return nil, fmt.Errorf("invalid SessionTTL: %v", err)
110		}
111	}
112	if opts.MonitorRetryTime == 0 {
113		opts.MonitorRetryTime = DefaultMonitorRetryTime
114	}
115	if opts.LockWaitTime == 0 {
116		opts.LockWaitTime = DefaultLockWaitTime
117	}
118	l := &Lock{
119		c:    c,
120		opts: opts,
121	}
122	return l, nil
123}
124
125// Lock attempts to acquire the lock and blocks while doing so.
126// Providing a non-nil stopCh can be used to abort the lock attempt.
127// Returns a channel that is closed if our lock is lost or an error.
128// This channel could be closed at any time due to session invalidation,
129// communication errors, operator intervention, etc. It is NOT safe to
130// assume that the lock is held until Unlock() unless the Session is specifically
131// created without any associated health checks. By default Consul sessions
132// prefer liveness over safety and an application must be able to handle
133// the lock being lost.
134func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
135	// Hold the lock as we try to acquire
136	l.l.Lock()
137	defer l.l.Unlock()
138
139	// Check if we already hold the lock
140	if l.isHeld {
141		return nil, ErrLockHeld
142	}
143
144	wOpts := WriteOptions{
145		Namespace: l.opts.Namespace,
146	}
147
148	// Check if we need to create a session first
149	l.lockSession = l.opts.Session
150	if l.lockSession == "" {
151		s, err := l.createSession()
152		if err != nil {
153			return nil, fmt.Errorf("failed to create session: %v", err)
154		}
155
156		l.sessionRenew = make(chan struct{})
157		l.lockSession = s
158
159		session := l.c.Session()
160		go session.RenewPeriodic(l.opts.SessionTTL, s, &wOpts, l.sessionRenew)
161
162		// If we fail to acquire the lock, cleanup the session
163		defer func() {
164			if !l.isHeld {
165				close(l.sessionRenew)
166				l.sessionRenew = nil
167			}
168		}()
169	}
170
171	// Setup the query options
172	kv := l.c.KV()
173	qOpts := QueryOptions{
174		WaitTime:  l.opts.LockWaitTime,
175		Namespace: l.opts.Namespace,
176	}
177
178	start := time.Now()
179	attempts := 0
180WAIT:
181	// Check if we should quit
182	select {
183	case <-stopCh:
184		return nil, nil
185	default:
186	}
187
188	// Handle the one-shot mode.
189	if l.opts.LockTryOnce && attempts > 0 {
190		elapsed := time.Since(start)
191		if elapsed > l.opts.LockWaitTime {
192			return nil, nil
193		}
194
195		// Query wait time should not exceed the lock wait time
196		qOpts.WaitTime = l.opts.LockWaitTime - elapsed
197	}
198	attempts++
199
200	// Look for an existing lock, blocking until not taken
201	pair, meta, err := kv.Get(l.opts.Key, &qOpts)
202	if err != nil {
203		return nil, fmt.Errorf("failed to read lock: %v", err)
204	}
205	if pair != nil && pair.Flags != LockFlagValue {
206		return nil, ErrLockConflict
207	}
208	locked := false
209	if pair != nil && pair.Session == l.lockSession {
210		goto HELD
211	}
212	if pair != nil && pair.Session != "" {
213		qOpts.WaitIndex = meta.LastIndex
214		goto WAIT
215	}
216
217	// Try to acquire the lock
218	pair = l.lockEntry(l.lockSession)
219
220	locked, _, err = kv.Acquire(pair, &wOpts)
221	if err != nil {
222		return nil, fmt.Errorf("failed to acquire lock: %v", err)
223	}
224
225	// Handle the case of not getting the lock
226	if !locked {
227		// Determine why the lock failed
228		qOpts.WaitIndex = 0
229		pair, meta, err = kv.Get(l.opts.Key, &qOpts)
230		if pair != nil && pair.Session != "" {
231			//If the session is not null, this means that a wait can safely happen
232			//using a long poll
233			qOpts.WaitIndex = meta.LastIndex
234			goto WAIT
235		} else {
236			// If the session is empty and the lock failed to acquire, then it means
237			// a lock-delay is in effect and a timed wait must be used
238			select {
239			case <-time.After(DefaultLockRetryTime):
240				goto WAIT
241			case <-stopCh:
242				return nil, nil
243			}
244		}
245	}
246
247HELD:
248	// Watch to ensure we maintain leadership
249	leaderCh := make(chan struct{})
250	go l.monitorLock(l.lockSession, leaderCh)
251
252	// Set that we own the lock
253	l.isHeld = true
254
255	// Locked! All done
256	return leaderCh, nil
257}
258
259// Unlock released the lock. It is an error to call this
260// if the lock is not currently held.
261func (l *Lock) Unlock() error {
262	// Hold the lock as we try to release
263	l.l.Lock()
264	defer l.l.Unlock()
265
266	// Ensure the lock is actually held
267	if !l.isHeld {
268		return ErrLockNotHeld
269	}
270
271	// Set that we no longer own the lock
272	l.isHeld = false
273
274	// Stop the session renew
275	if l.sessionRenew != nil {
276		defer func() {
277			close(l.sessionRenew)
278			l.sessionRenew = nil
279		}()
280	}
281
282	// Get the lock entry, and clear the lock session
283	lockEnt := l.lockEntry(l.lockSession)
284	l.lockSession = ""
285
286	// Release the lock explicitly
287	kv := l.c.KV()
288	w := WriteOptions{Namespace: l.opts.Namespace}
289
290	_, _, err := kv.Release(lockEnt, &w)
291	if err != nil {
292		return fmt.Errorf("failed to release lock: %v", err)
293	}
294	return nil
295}
296
297// Destroy is used to cleanup the lock entry. It is not necessary
298// to invoke. It will fail if the lock is in use.
299func (l *Lock) Destroy() error {
300	// Hold the lock as we try to release
301	l.l.Lock()
302	defer l.l.Unlock()
303
304	// Check if we already hold the lock
305	if l.isHeld {
306		return ErrLockHeld
307	}
308
309	// Look for an existing lock
310	kv := l.c.KV()
311	q := QueryOptions{Namespace: l.opts.Namespace}
312
313	pair, _, err := kv.Get(l.opts.Key, &q)
314	if err != nil {
315		return fmt.Errorf("failed to read lock: %v", err)
316	}
317
318	// Nothing to do if the lock does not exist
319	if pair == nil {
320		return nil
321	}
322
323	// Check for possible flag conflict
324	if pair.Flags != LockFlagValue {
325		return ErrLockConflict
326	}
327
328	// Check if it is in use
329	if pair.Session != "" {
330		return ErrLockInUse
331	}
332
333	// Attempt the delete
334	w := WriteOptions{Namespace: l.opts.Namespace}
335	didRemove, _, err := kv.DeleteCAS(pair, &w)
336	if err != nil {
337		return fmt.Errorf("failed to remove lock: %v", err)
338	}
339	if !didRemove {
340		return ErrLockInUse
341	}
342	return nil
343}
344
345// createSession is used to create a new managed session
346func (l *Lock) createSession() (string, error) {
347	session := l.c.Session()
348	se := l.opts.SessionOpts
349	if se == nil {
350		se = &SessionEntry{
351			Name: l.opts.SessionName,
352			TTL:  l.opts.SessionTTL,
353		}
354	}
355	w := WriteOptions{Namespace: l.opts.Namespace}
356	id, _, err := session.Create(se, &w)
357	if err != nil {
358		return "", err
359	}
360	return id, nil
361}
362
363// lockEntry returns a formatted KVPair for the lock
364func (l *Lock) lockEntry(session string) *KVPair {
365	return &KVPair{
366		Key:     l.opts.Key,
367		Value:   l.opts.Value,
368		Session: session,
369		Flags:   LockFlagValue,
370	}
371}
372
373// monitorLock is a long running routine to monitor a lock ownership
374// It closes the stopCh if we lose our leadership.
375func (l *Lock) monitorLock(session string, stopCh chan struct{}) {
376	defer close(stopCh)
377	kv := l.c.KV()
378	opts := QueryOptions{
379		RequireConsistent: true,
380		Namespace:         l.opts.Namespace,
381	}
382WAIT:
383	retries := l.opts.MonitorRetries
384RETRY:
385	pair, meta, err := kv.Get(l.opts.Key, &opts)
386	if err != nil {
387		// If configured we can try to ride out a brief Consul unavailability
388		// by doing retries. Note that we have to attempt the retry in a non-
389		// blocking fashion so that we have a clean place to reset the retry
390		// counter if service is restored.
391		if retries > 0 && IsRetryableError(err) {
392			time.Sleep(l.opts.MonitorRetryTime)
393			retries--
394			opts.WaitIndex = 0
395			goto RETRY
396		}
397		return
398	}
399	if pair != nil && pair.Session == session {
400		opts.WaitIndex = meta.LastIndex
401		goto WAIT
402	}
403}
404