1package zk
2
3import (
4	"errors"
5	"fmt"
6	"strconv"
7	"strings"
8)
9
var (
	// ErrDeadlock is returned by Lock when trying to lock twice without
	// unlocking first.
	ErrDeadlock = errors.New("zk: trying to acquire a lock twice")
	// ErrNotLocked is returned by Unlock when trying to release a lock that
	// has not first been acquired.
	ErrNotLocked = errors.New("zk: not locked")
)
16
// Lock is a mutual exclusion lock.
//
// An instance is considered locked while lockPath is non-empty: LockWithData
// sets lockPath and seq on success and Unlock resets them.
type Lock struct {
	// c is the connection used for every request made by the lock.
	c        *Conn
	// path is the parent node under which the lock nodes are created.
	path     string
	// acl is passed to every node creation performed by the lock.
	acl      []ACL
	// lockPath is the path of the node created for the held lock, or "" when
	// the lock is not held by this instance.
	lockPath string
	// seq is the sequence number parsed from lockPath; reset to 0 by Unlock.
	seq      int
}
25
26// NewLock creates a new lock instance using the provided connection, path, and acl.
27// The path must be a node that is only used by this lock. A lock instances starts
28// unlocked until Lock() is called.
29func NewLock(c *Conn, path string, acl []ACL) *Lock {
30	return &Lock{
31		c:    c,
32		path: path,
33		acl:  acl,
34	}
35}
36
// parseSeq extracts the sequence number from the name (or full path) of a
// sequential lock node.
//
// The number is whatever follows the last "-" in path. ZooKeeper's sequence
// counter is a signed 32-bit integer that wraps to negative values on
// overflow, producing names such as "lock--2147483648"; when the last "-" is
// immediately preceded by another "-", it is treated as the sign of the
// number instead of as part of the separator.
//
// Names without any "-" are split on "__" instead, because the python client
// uses a __LOCK__ prefix.
//
// The error returned is the one produced by strconv.Atoi when the extracted
// suffix is not a valid integer.
func parseSeq(path string) (int, error) {
	if i := strings.LastIndex(path, "-"); i >= 0 {
		if i > 0 && path[i-1] == '-' {
			// "--N": keep the second hyphen as the minus sign.
			return strconv.Atoi(path[i:])
		}
		return strconv.Atoi(path[i+1:])
	}
	// python client uses a __LOCK__ prefix
	parts := strings.Split(path, "__")
	return strconv.Atoi(parts[len(parts)-1])
}
45
46// Lock attempts to acquire the lock. It works like LockWithData, but it doesn't
47// write any data to the lock node.
48func (l *Lock) Lock() error {
49	return l.LockWithData([]byte{})
50}
51
52// LockWithData attempts to acquire the lock, writing data into the lock node.
53// It will wait to return until the lock is acquired or an error occurs. If
54// this instance already has the lock then ErrDeadlock is returned.
55func (l *Lock) LockWithData(data []byte) error {
56	if l.lockPath != "" {
57		return ErrDeadlock
58	}
59
60	prefix := fmt.Sprintf("%s/lock-", l.path)
61
62	path := ""
63	var err error
64	for i := 0; i < 3; i++ {
65		path, err = l.c.CreateProtectedEphemeralSequential(prefix, data, l.acl)
66		if err == ErrNoNode {
67			// Create parent node.
68			parts := strings.Split(l.path, "/")
69			pth := ""
70			for _, p := range parts[1:] {
71				var exists bool
72				pth += "/" + p
73				exists, _, err = l.c.Exists(pth)
74				if err != nil {
75					return err
76				}
77				if exists == true {
78					continue
79				}
80				_, err = l.c.Create(pth, []byte{}, 0, l.acl)
81				if err != nil && err != ErrNodeExists {
82					return err
83				}
84			}
85		} else if err == nil {
86			break
87		} else {
88			return err
89		}
90	}
91	if err != nil {
92		return err
93	}
94
95	seq, err := parseSeq(path)
96	if err != nil {
97		return err
98	}
99
100	for {
101		children, _, err := l.c.Children(l.path)
102		if err != nil {
103			return err
104		}
105
106		lowestSeq := seq
107		prevSeq := -1
108		prevSeqPath := ""
109		for _, p := range children {
110			s, err := parseSeq(p)
111			if err != nil {
112				return err
113			}
114			if s < lowestSeq {
115				lowestSeq = s
116			}
117			if s < seq && s > prevSeq {
118				prevSeq = s
119				prevSeqPath = p
120			}
121		}
122
123		if seq == lowestSeq {
124			// Acquired the lock
125			break
126		}
127
128		// Wait on the node next in line for the lock
129		_, _, ch, err := l.c.GetW(l.path + "/" + prevSeqPath)
130		if err != nil && err != ErrNoNode {
131			return err
132		} else if err != nil && err == ErrNoNode {
133			// try again
134			continue
135		}
136
137		ev := <-ch
138		if ev.Err != nil {
139			return ev.Err
140		}
141	}
142
143	l.seq = seq
144	l.lockPath = path
145	return nil
146}
147
148// Unlock releases an acquired lock. If the lock is not currently acquired by
149// this Lock instance than ErrNotLocked is returned.
150func (l *Lock) Unlock() error {
151	if l.lockPath == "" {
152		return ErrNotLocked
153	}
154	if err := l.c.Delete(l.lockPath, -1); err != nil {
155		return err
156	}
157	l.lockPath = ""
158	l.seq = 0
159	return nil
160}
161