1package api
2
3import (
4	"encoding/json"
5	"fmt"
6	"path"
7	"sync"
8	"time"
9)
10
const (
	// DefaultSemaphoreSessionName is the Session Name we assign if none is provided
	DefaultSemaphoreSessionName = "Consul API Semaphore"

	// DefaultSemaphoreSessionTTL is the default session TTL if no Session is provided
	// when creating a new Semaphore. This is used because we do not have
	// any other check to depend upon.
	DefaultSemaphoreSessionTTL = "15s"

	// DefaultSemaphoreWaitTime is how long we block for at a time to check if semaphore
	// acquisition is possible. This affects the minimum time it takes to cancel
	// a Semaphore acquisition.
	DefaultSemaphoreWaitTime = 15 * time.Second

	// DefaultSemaphoreKey is the key used within the prefix to
	// use for coordination between all the contenders.
	DefaultSemaphoreKey = ".lock"

	// SemaphoreFlagValue is a magic flag we set to indicate a key
	// is being used for a semaphore. It is used to detect a potential
	// conflict with a lock. It is written on both the coordination key
	// and every contender entry.
	SemaphoreFlagValue = 0xe0f69a2baa414de0
)
34
var (
	// ErrSemaphoreHeld is returned if we attempt to double lock. Destroy
	// also returns it when called while this Semaphore holds a slot.
	ErrSemaphoreHeld = fmt.Errorf("Semaphore already held")

	// ErrSemaphoreNotHeld is returned if we attempt to unlock a semaphore
	// that we do not hold.
	ErrSemaphoreNotHeld = fmt.Errorf("Semaphore not held")

	// ErrSemaphoreInUse is returned if we attempt to destroy a semaphore
	// that is in use.
	ErrSemaphoreInUse = fmt.Errorf("Semaphore in use")

	// ErrSemaphoreConflict is returned if the flags on a key
	// used for a semaphore do not match expectation, i.e. they differ
	// from SemaphoreFlagValue.
	ErrSemaphoreConflict = fmt.Errorf("Existing key does not match semaphore use")
)
51
// Semaphore is used to implement a distributed semaphore
// using the Consul KV primitives.
type Semaphore struct {
	c    *Client           // client used for every KV and session call
	opts *SemaphoreOptions // options given at construction, with defaults filled in

	isHeld       bool          // set by a successful Acquire, cleared by Release
	sessionRenew chan struct{} // closed to stop renewing a session created by Acquire; nil when none is active
	lockSession  string        // session ID used for the contender entry and the holder set
	l            sync.Mutex    // held for the full duration of Acquire, Release and Destroy
}
63
// SemaphoreOptions is used to parameterize the Semaphore
type SemaphoreOptions struct {
	Prefix            string        // Must be set and have write permissions
	Limit             int           // Must be set, and be positive
	Value             []byte        // Optional, value to associate with the contender entry
	Session           string        // Optional, created if not specified
	SessionName       string        // Optional, defaults to DefaultSemaphoreSessionName
	SessionTTL        string        // Optional, defaults to DefaultSemaphoreSessionTTL; must parse as a duration when set
	MonitorRetries    int           // Optional, defaults to 0 which means no retries
	MonitorRetryTime  time.Duration // Optional, defaults to DefaultMonitorRetryTime
	SemaphoreWaitTime time.Duration // Optional, defaults to DefaultSemaphoreWaitTime
	SemaphoreTryOnce  bool          // Optional, defaults to false which means try forever
}
77
// semaphoreLock is written under the DefaultSemaphoreKey and
// is used to coordinate between all the contenders. It is stored
// as the JSON encoding of this struct.
type semaphoreLock struct {
	// Limit is the integer limit of holders. This is used to
	// verify that all the holders agree on the value.
	Limit int

	// Holders is a list of all the semaphore holders.
	// It maps the session ID to true. It is used as a set effectively.
	Holders map[string]bool
}
89
90// SemaphorePrefix is used to created a Semaphore which will operate
91// at the given KV prefix and uses the given limit for the semaphore.
92// The prefix must have write privileges, and the limit must be agreed
93// upon by all contenders.
94func (c *Client) SemaphorePrefix(prefix string, limit int) (*Semaphore, error) {
95	opts := &SemaphoreOptions{
96		Prefix: prefix,
97		Limit:  limit,
98	}
99	return c.SemaphoreOpts(opts)
100}
101
102// SemaphoreOpts is used to create a Semaphore with the given options.
103// The prefix must have write privileges, and the limit must be agreed
104// upon by all contenders. If a Session is not provided, one will be created.
105func (c *Client) SemaphoreOpts(opts *SemaphoreOptions) (*Semaphore, error) {
106	if opts.Prefix == "" {
107		return nil, fmt.Errorf("missing prefix")
108	}
109	if opts.Limit <= 0 {
110		return nil, fmt.Errorf("semaphore limit must be positive")
111	}
112	if opts.SessionName == "" {
113		opts.SessionName = DefaultSemaphoreSessionName
114	}
115	if opts.SessionTTL == "" {
116		opts.SessionTTL = DefaultSemaphoreSessionTTL
117	} else {
118		if _, err := time.ParseDuration(opts.SessionTTL); err != nil {
119			return nil, fmt.Errorf("invalid SessionTTL: %v", err)
120		}
121	}
122	if opts.MonitorRetryTime == 0 {
123		opts.MonitorRetryTime = DefaultMonitorRetryTime
124	}
125	if opts.SemaphoreWaitTime == 0 {
126		opts.SemaphoreWaitTime = DefaultSemaphoreWaitTime
127	}
128	s := &Semaphore{
129		c:    c,
130		opts: opts,
131	}
132	return s, nil
133}
134
135// Acquire attempts to reserve a slot in the semaphore, blocking until
136// success, interrupted via the stopCh or an error is encountered.
137// Providing a non-nil stopCh can be used to abort the attempt.
138// On success, a channel is returned that represents our slot.
139// This channel could be closed at any time due to session invalidation,
140// communication errors, operator intervention, etc. It is NOT safe to
141// assume that the slot is held until Release() unless the Session is specifically
142// created without any associated health checks. By default Consul sessions
143// prefer liveness over safety and an application must be able to handle
144// the session being lost.
145func (s *Semaphore) Acquire(stopCh <-chan struct{}) (<-chan struct{}, error) {
146	// Hold the lock as we try to acquire
147	s.l.Lock()
148	defer s.l.Unlock()
149
150	// Check if we already hold the semaphore
151	if s.isHeld {
152		return nil, ErrSemaphoreHeld
153	}
154
155	// Check if we need to create a session first
156	s.lockSession = s.opts.Session
157	if s.lockSession == "" {
158		sess, err := s.createSession()
159		if err != nil {
160			return nil, fmt.Errorf("failed to create session: %v", err)
161		}
162
163		s.sessionRenew = make(chan struct{})
164		s.lockSession = sess
165		session := s.c.Session()
166		go session.RenewPeriodic(s.opts.SessionTTL, sess, nil, s.sessionRenew)
167
168		// If we fail to acquire the lock, cleanup the session
169		defer func() {
170			if !s.isHeld {
171				close(s.sessionRenew)
172				s.sessionRenew = nil
173			}
174		}()
175	}
176
177	// Create the contender entry
178	kv := s.c.KV()
179	made, _, err := kv.Acquire(s.contenderEntry(s.lockSession), nil)
180	if err != nil || !made {
181		return nil, fmt.Errorf("failed to make contender entry: %v", err)
182	}
183
184	// Setup the query options
185	qOpts := &QueryOptions{
186		WaitTime: s.opts.SemaphoreWaitTime,
187	}
188
189	start := time.Now()
190	attempts := 0
191WAIT:
192	// Check if we should quit
193	select {
194	case <-stopCh:
195		return nil, nil
196	default:
197	}
198
199	// Handle the one-shot mode.
200	if s.opts.SemaphoreTryOnce && attempts > 0 {
201		elapsed := time.Since(start)
202		if elapsed > s.opts.SemaphoreWaitTime {
203			return nil, nil
204		}
205
206		// Query wait time should not exceed the semaphore wait time
207		qOpts.WaitTime = s.opts.SemaphoreWaitTime - elapsed
208	}
209	attempts++
210
211	// Read the prefix
212	pairs, meta, err := kv.List(s.opts.Prefix, qOpts)
213	if err != nil {
214		return nil, fmt.Errorf("failed to read prefix: %v", err)
215	}
216
217	// Decode the lock
218	lockPair := s.findLock(pairs)
219	if lockPair.Flags != SemaphoreFlagValue {
220		return nil, ErrSemaphoreConflict
221	}
222	lock, err := s.decodeLock(lockPair)
223	if err != nil {
224		return nil, err
225	}
226
227	// Verify we agree with the limit
228	if lock.Limit != s.opts.Limit {
229		return nil, fmt.Errorf("semaphore limit conflict (lock: %d, local: %d)",
230			lock.Limit, s.opts.Limit)
231	}
232
233	// Prune the dead holders
234	s.pruneDeadHolders(lock, pairs)
235
236	// Check if the lock is held
237	if len(lock.Holders) >= lock.Limit {
238		qOpts.WaitIndex = meta.LastIndex
239		goto WAIT
240	}
241
242	// Create a new lock with us as a holder
243	lock.Holders[s.lockSession] = true
244	newLock, err := s.encodeLock(lock, lockPair.ModifyIndex)
245	if err != nil {
246		return nil, err
247	}
248
249	// Attempt the acquisition
250	didSet, _, err := kv.CAS(newLock, nil)
251	if err != nil {
252		return nil, fmt.Errorf("failed to update lock: %v", err)
253	}
254	if !didSet {
255		// Update failed, could have been a race with another contender,
256		// retry the operation
257		goto WAIT
258	}
259
260	// Watch to ensure we maintain ownership of the slot
261	lockCh := make(chan struct{})
262	go s.monitorLock(s.lockSession, lockCh)
263
264	// Set that we own the lock
265	s.isHeld = true
266
267	// Acquired! All done
268	return lockCh, nil
269}
270
271// Release is used to voluntarily give up our semaphore slot. It is
272// an error to call this if the semaphore has not been acquired.
273func (s *Semaphore) Release() error {
274	// Hold the lock as we try to release
275	s.l.Lock()
276	defer s.l.Unlock()
277
278	// Ensure the lock is actually held
279	if !s.isHeld {
280		return ErrSemaphoreNotHeld
281	}
282
283	// Set that we no longer own the lock
284	s.isHeld = false
285
286	// Stop the session renew
287	if s.sessionRenew != nil {
288		defer func() {
289			close(s.sessionRenew)
290			s.sessionRenew = nil
291		}()
292	}
293
294	// Get and clear the lock session
295	lockSession := s.lockSession
296	s.lockSession = ""
297
298	// Remove ourselves as a lock holder
299	kv := s.c.KV()
300	key := path.Join(s.opts.Prefix, DefaultSemaphoreKey)
301READ:
302	pair, _, err := kv.Get(key, nil)
303	if err != nil {
304		return err
305	}
306	if pair == nil {
307		pair = &KVPair{}
308	}
309	lock, err := s.decodeLock(pair)
310	if err != nil {
311		return err
312	}
313
314	// Create a new lock without us as a holder
315	if _, ok := lock.Holders[lockSession]; ok {
316		delete(lock.Holders, lockSession)
317		newLock, err := s.encodeLock(lock, pair.ModifyIndex)
318		if err != nil {
319			return err
320		}
321
322		// Swap the locks
323		didSet, _, err := kv.CAS(newLock, nil)
324		if err != nil {
325			return fmt.Errorf("failed to update lock: %v", err)
326		}
327		if !didSet {
328			goto READ
329		}
330	}
331
332	// Destroy the contender entry
333	contenderKey := path.Join(s.opts.Prefix, lockSession)
334	if _, err := kv.Delete(contenderKey, nil); err != nil {
335		return err
336	}
337	return nil
338}
339
340// Destroy is used to cleanup the semaphore entry. It is not necessary
341// to invoke. It will fail if the semaphore is in use.
342func (s *Semaphore) Destroy() error {
343	// Hold the lock as we try to acquire
344	s.l.Lock()
345	defer s.l.Unlock()
346
347	// Check if we already hold the semaphore
348	if s.isHeld {
349		return ErrSemaphoreHeld
350	}
351
352	// List for the semaphore
353	kv := s.c.KV()
354	pairs, _, err := kv.List(s.opts.Prefix, nil)
355	if err != nil {
356		return fmt.Errorf("failed to read prefix: %v", err)
357	}
358
359	// Find the lock pair, bail if it doesn't exist
360	lockPair := s.findLock(pairs)
361	if lockPair.ModifyIndex == 0 {
362		return nil
363	}
364	if lockPair.Flags != SemaphoreFlagValue {
365		return ErrSemaphoreConflict
366	}
367
368	// Decode the lock
369	lock, err := s.decodeLock(lockPair)
370	if err != nil {
371		return err
372	}
373
374	// Prune the dead holders
375	s.pruneDeadHolders(lock, pairs)
376
377	// Check if there are any holders
378	if len(lock.Holders) > 0 {
379		return ErrSemaphoreInUse
380	}
381
382	// Attempt the delete
383	didRemove, _, err := kv.DeleteCAS(lockPair, nil)
384	if err != nil {
385		return fmt.Errorf("failed to remove semaphore: %v", err)
386	}
387	if !didRemove {
388		return ErrSemaphoreInUse
389	}
390	return nil
391}
392
393// createSession is used to create a new managed session
394func (s *Semaphore) createSession() (string, error) {
395	session := s.c.Session()
396	se := &SessionEntry{
397		Name:     s.opts.SessionName,
398		TTL:      s.opts.SessionTTL,
399		Behavior: SessionBehaviorDelete,
400	}
401	id, _, err := session.Create(se, nil)
402	if err != nil {
403		return "", err
404	}
405	return id, nil
406}
407
408// contenderEntry returns a formatted KVPair for the contender
409func (s *Semaphore) contenderEntry(session string) *KVPair {
410	return &KVPair{
411		Key:     path.Join(s.opts.Prefix, session),
412		Value:   s.opts.Value,
413		Session: session,
414		Flags:   SemaphoreFlagValue,
415	}
416}
417
418// findLock is used to find the KV Pair which is used for coordination
419func (s *Semaphore) findLock(pairs KVPairs) *KVPair {
420	key := path.Join(s.opts.Prefix, DefaultSemaphoreKey)
421	for _, pair := range pairs {
422		if pair.Key == key {
423			return pair
424		}
425	}
426	return &KVPair{Flags: SemaphoreFlagValue}
427}
428
429// decodeLock is used to decode a semaphoreLock from an
430// entry in Consul
431func (s *Semaphore) decodeLock(pair *KVPair) (*semaphoreLock, error) {
432	// Handle if there is no lock
433	if pair == nil || pair.Value == nil {
434		return &semaphoreLock{
435			Limit:   s.opts.Limit,
436			Holders: make(map[string]bool),
437		}, nil
438	}
439
440	l := &semaphoreLock{}
441	if err := json.Unmarshal(pair.Value, l); err != nil {
442		return nil, fmt.Errorf("lock decoding failed: %v", err)
443	}
444	return l, nil
445}
446
447// encodeLock is used to encode a semaphoreLock into a KVPair
448// that can be PUT
449func (s *Semaphore) encodeLock(l *semaphoreLock, oldIndex uint64) (*KVPair, error) {
450	enc, err := json.Marshal(l)
451	if err != nil {
452		return nil, fmt.Errorf("lock encoding failed: %v", err)
453	}
454	pair := &KVPair{
455		Key:         path.Join(s.opts.Prefix, DefaultSemaphoreKey),
456		Value:       enc,
457		Flags:       SemaphoreFlagValue,
458		ModifyIndex: oldIndex,
459	}
460	return pair, nil
461}
462
463// pruneDeadHolders is used to remove all the dead lock holders
464func (s *Semaphore) pruneDeadHolders(lock *semaphoreLock, pairs KVPairs) {
465	// Gather all the live holders
466	alive := make(map[string]struct{}, len(pairs))
467	for _, pair := range pairs {
468		if pair.Session != "" {
469			alive[pair.Session] = struct{}{}
470		}
471	}
472
473	// Remove any holders that are dead
474	for holder := range lock.Holders {
475		if _, ok := alive[holder]; !ok {
476			delete(lock.Holders, holder)
477		}
478	}
479}
480
481// monitorLock is a long running routine to monitor a semaphore ownership
482// It closes the stopCh if we lose our slot.
483func (s *Semaphore) monitorLock(session string, stopCh chan struct{}) {
484	defer close(stopCh)
485	kv := s.c.KV()
486	opts := &QueryOptions{RequireConsistent: true}
487WAIT:
488	retries := s.opts.MonitorRetries
489RETRY:
490	pairs, meta, err := kv.List(s.opts.Prefix, opts)
491	if err != nil {
492		// If configured we can try to ride out a brief Consul unavailability
493		// by doing retries. Note that we have to attempt the retry in a non-
494		// blocking fashion so that we have a clean place to reset the retry
495		// counter if service is restored.
496		if retries > 0 && IsRetryableError(err) {
497			time.Sleep(s.opts.MonitorRetryTime)
498			retries--
499			opts.WaitIndex = 0
500			goto RETRY
501		}
502		return
503	}
504	lockPair := s.findLock(pairs)
505	lock, err := s.decodeLock(lockPair)
506	if err != nil {
507		return
508	}
509	s.pruneDeadHolders(lock, pairs)
510	if _, ok := lock.Holders[session]; ok {
511		opts.WaitIndex = meta.LastIndex
512		goto WAIT
513	}
514}
515