1package lock
2
3import (
4	"database/sql"
5	"errors"
6	"fmt"
7	"hash/crc32"
8	"strconv"
9	"strings"
10	"sync"
11
12	"code.cloudfoundry.org/lager"
13)
14
// Lock types. The LockID constructors below place one of these values first
// in every ID, so locks of different kinds get different IDs even when they
// share the same numeric key.
const (
	// LockTypeResourceConfigChecking prefixes IDs from NewResourceConfigCheckingLockID.
	LockTypeResourceConfigChecking = iota
	// LockTypeBuildTracking prefixes IDs from NewBuildTrackingLockID.
	LockTypeBuildTracking
	// LockTypeBatch prefixes IDs from NewTaskLockID.
	LockTypeBatch
	// LockTypeVolumeCreating prefixes IDs from NewVolumeCreatingLockID.
	LockTypeVolumeCreating
	// LockTypeContainerCreating has no constructor in this file.
	LockTypeContainerCreating
	// LockTypeDatabaseMigration is the sole element of NewDatabaseMigrationLockID.
	LockTypeDatabaseMigration
	// LockTypeActiveTasks is the sole element of NewActiveTasksLockID.
	LockTypeActiveTasks
	// LockTypeResourceScanning is the sole element of NewResourceScanningLockID.
	LockTypeResourceScanning
	// LockTypeJobScheduling prefixes IDs from NewJobSchedulingLockID.
	LockTypeJobScheduling
)

var (
	// ErrLostLock is returned by a lock's Release when the database does not
	// report the lock as released.
	ErrLostLock = errors.New("lock was lost while held, possibly due to connection breakage")
)
28
29func NewBuildTrackingLockID(buildID int) LockID {
30	return LockID{LockTypeBuildTracking, buildID}
31}
32
33func NewResourceConfigCheckingLockID(resourceConfigID int) LockID {
34	return LockID{LockTypeResourceConfigChecking, resourceConfigID}
35}
36
37func NewTaskLockID(taskName string) LockID {
38	return LockID{LockTypeBatch, lockIDFromString(taskName)}
39}
40
41func NewVolumeCreatingLockID(volumeID int) LockID {
42	return LockID{LockTypeVolumeCreating, volumeID}
43}
44
45func NewDatabaseMigrationLockID() LockID {
46	return LockID{LockTypeDatabaseMigration}
47}
48
49func NewActiveTasksLockID() LockID {
50	return LockID{LockTypeActiveTasks}
51}
52
53func NewResourceScanningLockID() LockID {
54	return LockID{LockTypeResourceScanning}
55}
56
57func NewJobSchedulingLockID(jobID int) LockID {
58	return LockID{LockTypeJobScheduling, jobID}
59}
60
61//go:generate counterfeiter . LockFactory
62
63type LockFactory interface {
64	Acquire(logger lager.Logger, ids LockID) (Lock, bool, error)
65}
66
67type lockFactory struct {
68	db           LockDB
69	locks        lockRepo
70	acquireMutex *sync.Mutex
71
72	acquireFunc LogFunc
73	releaseFunc LogFunc
74}
75
76type LogFunc func(logger lager.Logger, id LockID)
77
78func NewLockFactory(
79	conn *sql.DB,
80	acquire LogFunc,
81	release LogFunc,
82) LockFactory {
83	return &lockFactory{
84		db: &lockDB{
85			conn:  conn,
86			mutex: &sync.Mutex{},
87		},
88		acquireFunc: acquire,
89		releaseFunc: release,
90		locks: lockRepo{
91			locks: map[string]bool{},
92			mutex: &sync.Mutex{},
93		},
94		acquireMutex: &sync.Mutex{},
95	}
96}
97
98func NewTestLockFactory(db LockDB) LockFactory {
99	return &lockFactory{
100		db: db,
101		locks: lockRepo{
102			locks: map[string]bool{},
103			mutex: &sync.Mutex{},
104		},
105		acquireMutex: &sync.Mutex{},
106		acquireFunc:  func(logger lager.Logger, id LockID) {},
107		releaseFunc:  func(logger lager.Logger, id LockID) {},
108	}
109}
110
111func (f *lockFactory) Acquire(logger lager.Logger, id LockID) (Lock, bool, error) {
112	l := &lock{
113		logger:       logger,
114		db:           f.db,
115		id:           id,
116		locks:        f.locks,
117		acquireMutex: f.acquireMutex,
118		acquired:     f.acquireFunc,
119		released:     f.releaseFunc,
120	}
121
122	acquired, err := l.Acquire()
123	if err != nil {
124		return nil, false, err
125	}
126
127	if !acquired {
128		return nil, false, nil
129	}
130
131	return l, true, nil
132}
133
//go:generate counterfeiter . Lock

// Lock is a lock that the caller currently holds.
type Lock interface {
	// Release gives the lock up. The implementation in this file returns
	// ErrLostLock when the database does not report the lock as released.
	Release() error
}
139
140//go:generate counterfeiter . LockDB
141
142type LockDB interface {
143	Acquire(id LockID) (bool, error)
144	Release(id LockID) (bool, error)
145}
146
147type lock struct {
148	id LockID
149
150	logger       lager.Logger
151	db           LockDB
152	locks        lockRepo
153	acquireMutex *sync.Mutex
154
155	acquired LogFunc
156	released LogFunc
157}
158
159func (l *lock) Acquire() (bool, error) {
160	l.acquireMutex.Lock()
161	defer l.acquireMutex.Unlock()
162
163	logger := l.logger.Session("acquire", lager.Data{"id": l.id})
164
165	if l.locks.IsRegistered(l.id) {
166		logger.Debug("not-acquired-already-held-locally")
167		return false, nil
168	}
169
170	acquired, err := l.db.Acquire(l.id)
171	if err != nil {
172		logger.Error("failed-to-register-in-db", err)
173		return false, err
174	}
175
176	if !acquired {
177		logger.Debug("not-acquired-already-held-in-db")
178		return false, nil
179	}
180
181	l.locks.Register(l.id)
182
183	l.acquired(logger, l.id)
184
185	return true, nil
186}
187
188func (l *lock) Release() error {
189	logger := l.logger.Session("release", lager.Data{"id": l.id})
190
191	released, err := l.db.Release(l.id)
192	if err != nil {
193		logger.Error("failed-to-release-in-db-but-continuing-anyway", err)
194	}
195
196	l.locks.Unregister(l.id)
197
198	if !released {
199		logger.Error("failed-to-release", ErrLostLock)
200		return ErrLostLock
201	}
202
203	l.released(logger, l.id)
204
205	return nil
206}
207
208type lockDB struct {
209	conn  *sql.DB
210	mutex *sync.Mutex
211}
212
213func (db *lockDB) Acquire(id LockID) (bool, error) {
214	db.mutex.Lock()
215	defer db.mutex.Unlock()
216
217	var acquired bool
218	err := db.conn.QueryRow(`SELECT pg_try_advisory_lock(`+id.toDBParams()+`)`, id.toDBArgs()...).Scan(&acquired)
219	if err != nil {
220		return false, err
221	}
222
223	return acquired, nil
224}
225
226func (db *lockDB) Release(id LockID) (bool, error) {
227	db.mutex.Lock()
228	defer db.mutex.Unlock()
229
230	var released bool
231	err := db.conn.QueryRow(`SELECT pg_advisory_unlock(`+id.toDBParams()+`)`, id.toDBArgs()...).Scan(&released)
232	if err != nil {
233		return false, err
234	}
235
236	return released, nil
237}
238
239type lockRepo struct {
240	locks map[string]bool
241	mutex *sync.Mutex
242}
243
244func (lr lockRepo) IsRegistered(id LockID) bool {
245	lr.mutex.Lock()
246	defer lr.mutex.Unlock()
247
248	if _, ok := lr.locks[id.toKey()]; ok {
249		return true
250	}
251	return false
252}
253
254func (lr lockRepo) Register(id LockID) {
255	lr.mutex.Lock()
256	defer lr.mutex.Unlock()
257
258	lr.locks[id.toKey()] = true
259}
260
261func (lr lockRepo) Unregister(id LockID) {
262	lr.mutex.Lock()
263	defer lr.mutex.Unlock()
264
265	delete(lr.locks, id.toKey())
266}
267
// LockID identifies a lock. The constructors in this package put a LockType*
// constant first, optionally followed by one key. Each element becomes one
// argument of the PostgreSQL advisory-lock function. A one-element ID
// therefore uses the single-key form and a two-element ID uses the two-key
// form.
//
// NOTE(review): PostgreSQL's two-key form takes int4 arguments, so keys
// outside the int32 range presumably fail at query time. Confirm this for
// large IDs.
type LockID []int

// toKey renders the ID as its elements joined by "+" (e.g. "1+42"). lockRepo
// uses it as a map key.
func (l LockID) toKey() string {
	parts := make([]string, len(l))
	for i, v := range l {
		parts[i] = strconv.Itoa(v)
	}
	return strings.Join(parts, "+")
}

// toDBParams returns one positional placeholder per element, comma-separated
// ("$1,$2,..."), for splicing into the advisory-lock query.
func (l LockID) toDBParams() string {
	placeholders := make([]string, 0, len(l))
	for n := 1; n <= len(l); n++ {
		placeholders = append(placeholders, fmt.Sprintf("$%d", n))
	}
	return strings.Join(placeholders, ",")
}

// toDBArgs returns the elements as query arguments, matching the
// placeholders from toDBParams. The result is never nil, even for an empty
// ID.
func (l LockID) toDBArgs() []interface{} {
	args := make([]interface{}, len(l))
	for i, v := range l {
		args[i] = v
	}
	return args
}
295
// lockIDFromString maps taskName to an int key. The key is the IEEE CRC-32 of
// the name's bytes, reinterpreted as a signed 32-bit value (presumably so it
// fits PostgreSQL's int4 advisory-lock key). Distinct names can map to the
// same key.
func lockIDFromString(taskName string) int {
	checksum := crc32.ChecksumIEEE([]byte(taskName))
	return int(int32(checksum))
}
299