1package zk
2
3import (
4	"errors"
5	"fmt"
6	"strconv"
7	"strings"
8)
9
// Sentinel errors returned by Lock and Unlock.
var (
	// ErrDeadlock is returned by Lock when trying to lock twice without unlocking first.
	ErrDeadlock = errors.New("zk: trying to acquire a lock twice")
	// ErrNotLocked is returned by Unlock when trying to release a lock that has not first been acquired.
	ErrNotLocked = errors.New("zk: not locked")
)
16
// Lock is a mutual exclusion lock.
//
// A Lock is held while lockPath is non-empty: Lock sets it on success and
// Unlock clears it again after deleting the node.
type Lock struct {
	// c is the connection used for every ZooKeeper call made by the lock.
	c        *Conn
	// path is the parent node under which the lock nodes are created.
	path     string
	// acl is passed when creating the lock node and any missing parent nodes.
	acl      []ACL
	// lockPath is the path of the node created by Lock; it is empty while
	// the lock is not held.
	lockPath string
	// seq is the sequence number parsed from lockPath; Unlock resets it to 0.
	seq      int
}
25
26// NewLock creates a new lock instance using the provided connection, path, and acl.
27// The path must be a node that is only used by this lock. A lock instances starts
28// unlocked until Lock() is called.
29func NewLock(c *Conn, path string, acl []ACL) *Lock {
30	return &Lock{
31		c:    c,
32		path: path,
33		acl:  acl,
34	}
35}
36
// parseSeq extracts the sequence number from a lock node name or path.
//
// The number is taken to be everything after the last "-" in path; when path
// contains no "-" the whole string is parsed. The error is whatever
// strconv.Atoi reports for that suffix.
func parseSeq(path string) (int, error) {
	// LastIndex yields -1 when there is no "-", so the slice below then
	// starts at 0 and covers the entire string.
	cut := strings.LastIndex(path, "-")
	suffix := path[cut+1:]
	return strconv.Atoi(suffix)
}
41
42// Lock attempts to acquire the lock. It will wait to return until the lock
43// is acquired or an error occurs. If this instance already has the lock
44// then ErrDeadlock is returned.
45func (l *Lock) Lock() error {
46	if l.lockPath != "" {
47		return ErrDeadlock
48	}
49
50	prefix := fmt.Sprintf("%s/lock-", l.path)
51
52	path := ""
53	var err error
54	for i := 0; i < 3; i++ {
55		path, err = l.c.CreateProtectedEphemeralSequential(prefix, []byte{}, l.acl)
56		if err == ErrNoNode {
57			// Create parent node.
58			parts := strings.Split(l.path, "/")
59			pth := ""
60			for _, p := range parts[1:] {
61				var exists bool
62				pth += "/" + p
63				exists, _, err = l.c.Exists(pth)
64				if err != nil {
65					return err
66				}
67				if exists == true {
68					continue
69				}
70				_, err = l.c.Create(pth, []byte{}, 0, l.acl)
71				if err != nil && err != ErrNodeExists {
72					return err
73				}
74			}
75		} else if err == nil {
76			break
77		} else {
78			return err
79		}
80	}
81	if err != nil {
82		return err
83	}
84
85	seq, err := parseSeq(path)
86	if err != nil {
87		return err
88	}
89
90	for {
91		children, _, err := l.c.Children(l.path)
92		if err != nil {
93			return err
94		}
95
96		lowestSeq := seq
97		prevSeq := -1
98		prevSeqPath := ""
99		for _, p := range children {
100			s, err := parseSeq(p)
101			if err != nil {
102				return err
103			}
104			if s < lowestSeq {
105				lowestSeq = s
106			}
107			if s < seq && s > prevSeq {
108				prevSeq = s
109				prevSeqPath = p
110			}
111		}
112
113		if seq == lowestSeq {
114			// Acquired the lock
115			break
116		}
117
118		// Wait on the node next in line for the lock
119		_, _, ch, err := l.c.GetW(l.path + "/" + prevSeqPath)
120		if err != nil && err != ErrNoNode {
121			return err
122		} else if err != nil && err == ErrNoNode {
123			// try again
124			continue
125		}
126
127		ev := <-ch
128		if ev.Err != nil {
129			return ev.Err
130		}
131	}
132
133	l.seq = seq
134	l.lockPath = path
135	return nil
136}
137
138// Unlock releases an acquired lock. If the lock is not currently acquired by
139// this Lock instance than ErrNotLocked is returned.
140func (l *Lock) Unlock() error {
141	if l.lockPath == "" {
142		return ErrNotLocked
143	}
144	if err := l.c.Delete(l.lockPath, -1); err != nil {
145		return err
146	}
147	l.lockPath = ""
148	l.seq = 0
149	return nil
150}
151