1package api
2
3import (
4	"encoding/json"
5	"fmt"
6	"path"
7	"sync"
8	"time"
9)
10
const (
	// DefaultSemaphoreSessionName is the Session Name we assign if none is provided.
	DefaultSemaphoreSessionName = "Consul API Semaphore"

	// DefaultSemaphoreSessionTTL is the default session TTL if no Session is provided
	// when creating a new Semaphore. A TTL is used because we do not have any
	// other health check to tie the session to.
	DefaultSemaphoreSessionTTL = "15s"

	// DefaultSemaphoreWaitTime is how long we block for at a time to check if semaphore
	// acquisition is possible. Acquire only checks its stop channel between these
	// blocking queries, so this bounds the minimum time it takes to cancel a
	// Semaphore acquisition. With SemaphoreTryOnce it is also the total time
	// Acquire will wait before giving up.
	DefaultSemaphoreWaitTime = 15 * time.Second

	// DefaultSemaphoreKey is the key used within the prefix to
	// use for coordination between all the contenders.
	DefaultSemaphoreKey = ".lock"

	// SemaphoreFlagValue is a magic flag we set to indicate a key
	// is being used for a semaphore. It is used to detect a potential
	// conflict with a lock.
	SemaphoreFlagValue = 0xe0f69a2baa414de0
)
34
// Sentinel errors returned by Semaphore operations; compare against them
// directly to detect each condition.
var (
	// ErrSemaphoreHeld is returned when this Semaphore already holds a slot,
	// e.g. a second Acquire without an intervening Release, or Destroy while held.
	ErrSemaphoreHeld = fmt.Errorf("Semaphore already held")

	// ErrSemaphoreNotHeld is returned by Release when no slot is currently held.
	ErrSemaphoreNotHeld = fmt.Errorf("Semaphore not held")

	// ErrSemaphoreInUse is returned by Destroy while live holders remain, or
	// when the coordination key changed before it could be removed.
	ErrSemaphoreInUse = fmt.Errorf("Semaphore in use")

	// ErrSemaphoreConflict is returned when the coordination key exists but is
	// not tagged with SemaphoreFlagValue (for instance, it is in use as a lock).
	ErrSemaphoreConflict = fmt.Errorf("Existing key does not match semaphore use")
)
51
// Semaphore is used to implement a distributed semaphore
// using the Consul KV primitives.
//
// Construct one with Client.SemaphorePrefix or Client.SemaphoreOpts. Acquire,
// Release and Destroy each hold the mutex l for their whole duration.
type Semaphore struct {
	// c is the client used for KV and session calls; opts holds the options
	// as validated and defaulted by SemaphoreOpts.
	c    *Client
	opts *SemaphoreOptions

	// isHeld is set once Acquire wins a slot and cleared by Release.
	isHeld bool
	// sessionRenew is non-nil only while a session created by Acquire itself
	// is in use; closing it signals the RenewPeriodic goroutine started there.
	sessionRenew chan struct{}
	// lockSession is the session ID used for the contender entry and recorded
	// in the holder set; Release clears it.
	lockSession string
	// l serializes Acquire, Release and Destroy.
	l sync.Mutex
}
63
// SemaphoreOptions is used to parameterize the Semaphore.
//
// Client.SemaphoreOpts validates the required fields and fills zero-valued
// optional fields with their defaults, writing them back into the struct it
// is given.
type SemaphoreOptions struct {
	Prefix            string        // Must be set and have write permissions
	Limit             int           // Must be set, and be positive
	Value             []byte        // Optional, value to associate with the contender entry
	Session           string        // Optional, created if not specified
	SessionName       string        // Optional, defaults to DefaultSemaphoreSessionName
	SessionTTL        string        // Optional, defaults to DefaultSemaphoreSessionTTL; must parse with time.ParseDuration
	MonitorRetries    int           // Optional, defaults to 0 which means no retries
	MonitorRetryTime  time.Duration // Optional, defaults to DefaultMonitorRetryTime
	SemaphoreWaitTime time.Duration // Optional, defaults to DefaultSemaphoreWaitTime
	SemaphoreTryOnce  bool          // Optional, defaults to false which means try forever; if true, Acquire gives up after SemaphoreWaitTime
	Namespace         string        `json:",omitempty"` // Optional, defaults to API client config, namespace of ACL token, or "default" namespace
}
78
// semaphoreLock is written under the DefaultSemaphoreKey and
// is used to coordinate between all the contenders.
//
// It is stored as JSON (see encodeLock and decodeLock), so the field names
// below are the stored keys shared by every contender; renaming or retagging
// them changes the on-KV format.
type semaphoreLock struct {
	// Limit is the integer limit of holders. This is used to
	// verify that all the holders agree on the value.
	Limit int

	// Holders is the set of sessions currently holding a slot.
	// It maps the session ID to true; only the keys are meaningful.
	Holders map[string]bool
}
90
91// SemaphorePrefix is used to created a Semaphore which will operate
92// at the given KV prefix and uses the given limit for the semaphore.
93// The prefix must have write privileges, and the limit must be agreed
94// upon by all contenders.
95func (c *Client) SemaphorePrefix(prefix string, limit int) (*Semaphore, error) {
96	opts := &SemaphoreOptions{
97		Prefix: prefix,
98		Limit:  limit,
99	}
100	return c.SemaphoreOpts(opts)
101}
102
103// SemaphoreOpts is used to create a Semaphore with the given options.
104// The prefix must have write privileges, and the limit must be agreed
105// upon by all contenders. If a Session is not provided, one will be created.
106func (c *Client) SemaphoreOpts(opts *SemaphoreOptions) (*Semaphore, error) {
107	if opts.Prefix == "" {
108		return nil, fmt.Errorf("missing prefix")
109	}
110	if opts.Limit <= 0 {
111		return nil, fmt.Errorf("semaphore limit must be positive")
112	}
113	if opts.SessionName == "" {
114		opts.SessionName = DefaultSemaphoreSessionName
115	}
116	if opts.SessionTTL == "" {
117		opts.SessionTTL = DefaultSemaphoreSessionTTL
118	} else {
119		if _, err := time.ParseDuration(opts.SessionTTL); err != nil {
120			return nil, fmt.Errorf("invalid SessionTTL: %v", err)
121		}
122	}
123	if opts.MonitorRetryTime == 0 {
124		opts.MonitorRetryTime = DefaultMonitorRetryTime
125	}
126	if opts.SemaphoreWaitTime == 0 {
127		opts.SemaphoreWaitTime = DefaultSemaphoreWaitTime
128	}
129	s := &Semaphore{
130		c:    c,
131		opts: opts,
132	}
133	return s, nil
134}
135
136// Acquire attempts to reserve a slot in the semaphore, blocking until
137// success, interrupted via the stopCh or an error is encountered.
138// Providing a non-nil stopCh can be used to abort the attempt.
139// On success, a channel is returned that represents our slot.
140// This channel could be closed at any time due to session invalidation,
141// communication errors, operator intervention, etc. It is NOT safe to
142// assume that the slot is held until Release() unless the Session is specifically
143// created without any associated health checks. By default Consul sessions
144// prefer liveness over safety and an application must be able to handle
145// the session being lost.
146func (s *Semaphore) Acquire(stopCh <-chan struct{}) (<-chan struct{}, error) {
147	// Hold the lock as we try to acquire
148	s.l.Lock()
149	defer s.l.Unlock()
150
151	// Check if we already hold the semaphore
152	if s.isHeld {
153		return nil, ErrSemaphoreHeld
154	}
155
156	// Check if we need to create a session first
157	s.lockSession = s.opts.Session
158	if s.lockSession == "" {
159		sess, err := s.createSession()
160		if err != nil {
161			return nil, fmt.Errorf("failed to create session: %v", err)
162		}
163
164		s.sessionRenew = make(chan struct{})
165		s.lockSession = sess
166		session := s.c.Session()
167		go session.RenewPeriodic(s.opts.SessionTTL, sess, nil, s.sessionRenew)
168
169		// If we fail to acquire the lock, cleanup the session
170		defer func() {
171			if !s.isHeld {
172				close(s.sessionRenew)
173				s.sessionRenew = nil
174			}
175		}()
176	}
177
178	// Create the contender entry
179	kv := s.c.KV()
180	wOpts := WriteOptions{Namespace: s.opts.Namespace}
181
182	made, _, err := kv.Acquire(s.contenderEntry(s.lockSession), &wOpts)
183	if err != nil || !made {
184		return nil, fmt.Errorf("failed to make contender entry: %v", err)
185	}
186
187	// Setup the query options
188	qOpts := QueryOptions{
189		WaitTime:  s.opts.SemaphoreWaitTime,
190		Namespace: s.opts.Namespace,
191	}
192
193	start := time.Now()
194	attempts := 0
195WAIT:
196	// Check if we should quit
197	select {
198	case <-stopCh:
199		return nil, nil
200	default:
201	}
202
203	// Handle the one-shot mode.
204	if s.opts.SemaphoreTryOnce && attempts > 0 {
205		elapsed := time.Since(start)
206		if elapsed > s.opts.SemaphoreWaitTime {
207			return nil, nil
208		}
209
210		// Query wait time should not exceed the semaphore wait time
211		qOpts.WaitTime = s.opts.SemaphoreWaitTime - elapsed
212	}
213	attempts++
214
215	// Read the prefix
216	pairs, meta, err := kv.List(s.opts.Prefix, &qOpts)
217	if err != nil {
218		return nil, fmt.Errorf("failed to read prefix: %v", err)
219	}
220
221	// Decode the lock
222	lockPair := s.findLock(pairs)
223	if lockPair.Flags != SemaphoreFlagValue {
224		return nil, ErrSemaphoreConflict
225	}
226	lock, err := s.decodeLock(lockPair)
227	if err != nil {
228		return nil, err
229	}
230
231	// Verify we agree with the limit
232	if lock.Limit != s.opts.Limit {
233		return nil, fmt.Errorf("semaphore limit conflict (lock: %d, local: %d)",
234			lock.Limit, s.opts.Limit)
235	}
236
237	// Prune the dead holders
238	s.pruneDeadHolders(lock, pairs)
239
240	// Check if the lock is held
241	if len(lock.Holders) >= lock.Limit {
242		qOpts.WaitIndex = meta.LastIndex
243		goto WAIT
244	}
245
246	// Create a new lock with us as a holder
247	lock.Holders[s.lockSession] = true
248	newLock, err := s.encodeLock(lock, lockPair.ModifyIndex)
249	if err != nil {
250		return nil, err
251	}
252
253	// Attempt the acquisition
254	didSet, _, err := kv.CAS(newLock, &wOpts)
255	if err != nil {
256		return nil, fmt.Errorf("failed to update lock: %v", err)
257	}
258	if !didSet {
259		// Update failed, could have been a race with another contender,
260		// retry the operation
261		goto WAIT
262	}
263
264	// Watch to ensure we maintain ownership of the slot
265	lockCh := make(chan struct{})
266	go s.monitorLock(s.lockSession, lockCh)
267
268	// Set that we own the lock
269	s.isHeld = true
270
271	// Acquired! All done
272	return lockCh, nil
273}
274
275// Release is used to voluntarily give up our semaphore slot. It is
276// an error to call this if the semaphore has not been acquired.
277func (s *Semaphore) Release() error {
278	// Hold the lock as we try to release
279	s.l.Lock()
280	defer s.l.Unlock()
281
282	// Ensure the lock is actually held
283	if !s.isHeld {
284		return ErrSemaphoreNotHeld
285	}
286
287	// Set that we no longer own the lock
288	s.isHeld = false
289
290	// Stop the session renew
291	if s.sessionRenew != nil {
292		defer func() {
293			close(s.sessionRenew)
294			s.sessionRenew = nil
295		}()
296	}
297
298	// Get and clear the lock session
299	lockSession := s.lockSession
300	s.lockSession = ""
301
302	// Remove ourselves as a lock holder
303	kv := s.c.KV()
304	key := path.Join(s.opts.Prefix, DefaultSemaphoreKey)
305
306	wOpts := WriteOptions{Namespace: s.opts.Namespace}
307	qOpts := QueryOptions{Namespace: s.opts.Namespace}
308
309READ:
310	pair, _, err := kv.Get(key, &qOpts)
311	if err != nil {
312		return err
313	}
314	if pair == nil {
315		pair = &KVPair{}
316	}
317	lock, err := s.decodeLock(pair)
318	if err != nil {
319		return err
320	}
321
322	// Create a new lock without us as a holder
323	if _, ok := lock.Holders[lockSession]; ok {
324		delete(lock.Holders, lockSession)
325		newLock, err := s.encodeLock(lock, pair.ModifyIndex)
326		if err != nil {
327			return err
328		}
329
330		// Swap the locks
331		didSet, _, err := kv.CAS(newLock, &wOpts)
332		if err != nil {
333			return fmt.Errorf("failed to update lock: %v", err)
334		}
335		if !didSet {
336			goto READ
337		}
338	}
339
340	// Destroy the contender entry
341	contenderKey := path.Join(s.opts.Prefix, lockSession)
342	if _, err := kv.Delete(contenderKey, &wOpts); err != nil {
343		return err
344	}
345	return nil
346}
347
348// Destroy is used to cleanup the semaphore entry. It is not necessary
349// to invoke. It will fail if the semaphore is in use.
350func (s *Semaphore) Destroy() error {
351	// Hold the lock as we try to acquire
352	s.l.Lock()
353	defer s.l.Unlock()
354
355	// Check if we already hold the semaphore
356	if s.isHeld {
357		return ErrSemaphoreHeld
358	}
359
360	// List for the semaphore
361	kv := s.c.KV()
362
363	q := QueryOptions{Namespace: s.opts.Namespace}
364	pairs, _, err := kv.List(s.opts.Prefix, &q)
365	if err != nil {
366		return fmt.Errorf("failed to read prefix: %v", err)
367	}
368
369	// Find the lock pair, bail if it doesn't exist
370	lockPair := s.findLock(pairs)
371	if lockPair.ModifyIndex == 0 {
372		return nil
373	}
374	if lockPair.Flags != SemaphoreFlagValue {
375		return ErrSemaphoreConflict
376	}
377
378	// Decode the lock
379	lock, err := s.decodeLock(lockPair)
380	if err != nil {
381		return err
382	}
383
384	// Prune the dead holders
385	s.pruneDeadHolders(lock, pairs)
386
387	// Check if there are any holders
388	if len(lock.Holders) > 0 {
389		return ErrSemaphoreInUse
390	}
391
392	// Attempt the delete
393	w := WriteOptions{Namespace: s.opts.Namespace}
394	didRemove, _, err := kv.DeleteCAS(lockPair, &w)
395	if err != nil {
396		return fmt.Errorf("failed to remove semaphore: %v", err)
397	}
398	if !didRemove {
399		return ErrSemaphoreInUse
400	}
401	return nil
402}
403
404// createSession is used to create a new managed session
405func (s *Semaphore) createSession() (string, error) {
406	session := s.c.Session()
407	se := &SessionEntry{
408		Name:     s.opts.SessionName,
409		TTL:      s.opts.SessionTTL,
410		Behavior: SessionBehaviorDelete,
411	}
412
413	w := WriteOptions{Namespace: s.opts.Namespace}
414	id, _, err := session.Create(se, &w)
415	if err != nil {
416		return "", err
417	}
418	return id, nil
419}
420
421// contenderEntry returns a formatted KVPair for the contender
422func (s *Semaphore) contenderEntry(session string) *KVPair {
423	return &KVPair{
424		Key:     path.Join(s.opts.Prefix, session),
425		Value:   s.opts.Value,
426		Session: session,
427		Flags:   SemaphoreFlagValue,
428	}
429}
430
431// findLock is used to find the KV Pair which is used for coordination
432func (s *Semaphore) findLock(pairs KVPairs) *KVPair {
433	key := path.Join(s.opts.Prefix, DefaultSemaphoreKey)
434	for _, pair := range pairs {
435		if pair.Key == key {
436			return pair
437		}
438	}
439	return &KVPair{Flags: SemaphoreFlagValue}
440}
441
442// decodeLock is used to decode a semaphoreLock from an
443// entry in Consul
444func (s *Semaphore) decodeLock(pair *KVPair) (*semaphoreLock, error) {
445	// Handle if there is no lock
446	if pair == nil || pair.Value == nil {
447		return &semaphoreLock{
448			Limit:   s.opts.Limit,
449			Holders: make(map[string]bool),
450		}, nil
451	}
452
453	l := &semaphoreLock{}
454	if err := json.Unmarshal(pair.Value, l); err != nil {
455		return nil, fmt.Errorf("lock decoding failed: %v", err)
456	}
457	return l, nil
458}
459
460// encodeLock is used to encode a semaphoreLock into a KVPair
461// that can be PUT
462func (s *Semaphore) encodeLock(l *semaphoreLock, oldIndex uint64) (*KVPair, error) {
463	enc, err := json.Marshal(l)
464	if err != nil {
465		return nil, fmt.Errorf("lock encoding failed: %v", err)
466	}
467	pair := &KVPair{
468		Key:         path.Join(s.opts.Prefix, DefaultSemaphoreKey),
469		Value:       enc,
470		Flags:       SemaphoreFlagValue,
471		ModifyIndex: oldIndex,
472	}
473	return pair, nil
474}
475
476// pruneDeadHolders is used to remove all the dead lock holders
477func (s *Semaphore) pruneDeadHolders(lock *semaphoreLock, pairs KVPairs) {
478	// Gather all the live holders
479	alive := make(map[string]struct{}, len(pairs))
480	for _, pair := range pairs {
481		if pair.Session != "" {
482			alive[pair.Session] = struct{}{}
483		}
484	}
485
486	// Remove any holders that are dead
487	for holder := range lock.Holders {
488		if _, ok := alive[holder]; !ok {
489			delete(lock.Holders, holder)
490		}
491	}
492}
493
// monitorLock is a long running routine to monitor a semaphore ownership.
// It closes the stopCh if we lose our slot.
//
// It repeatedly lists the prefix with consistent reads, blocking on the last
// seen index, and returns (closing stopCh) when:
//   - session is no longer in the pruned holder set (including after Release
//     removes it),
//   - the lock value cannot be decoded, or
//   - List fails with a non-retryable error or the retry budget is exhausted.
//
// The retry budget (MonitorRetries) is refilled after every successful read.
func (s *Semaphore) monitorLock(session string, stopCh chan struct{}) {
	defer close(stopCh)
	kv := s.c.KV()
	opts := QueryOptions{
		RequireConsistent: true,
		Namespace:         s.opts.Namespace,
	}
WAIT:
	// Reached initially and after every successful read: refill retries.
	retries := s.opts.MonitorRetries
RETRY:
	pairs, meta, err := kv.List(s.opts.Prefix, &opts)
	if err != nil {
		// If configured we can try to ride out a brief Consul unavailability
		// by doing retries. Note that we have to attempt the retry in a non-
		// blocking fashion so that we have a clean place to reset the retry
		// counter if service is restored.
		if retries > 0 && IsRetryableError(err) {
			time.Sleep(s.opts.MonitorRetryTime)
			retries--
			opts.WaitIndex = 0
			goto RETRY
		}
		return
	}
	lockPair := s.findLock(pairs)
	lock, err := s.decodeLock(lockPair)
	if err != nil {
		return
	}
	s.pruneDeadHolders(lock, pairs)
	if _, ok := lock.Holders[session]; ok {
		// Still a holder: block until the prefix changes past this index.
		opts.WaitIndex = meta.LastIndex
		goto WAIT
	}
}
531