1package lock 2 3import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "hash/crc32" 8 "strconv" 9 "strings" 10 "sync" 11 12 "code.cloudfoundry.org/lager" 13) 14 15const ( 16 LockTypeResourceConfigChecking = iota 17 LockTypeBuildTracking 18 LockTypeBatch 19 LockTypeVolumeCreating 20 LockTypeContainerCreating 21 LockTypeDatabaseMigration 22 LockTypeActiveTasks 23 LockTypeResourceScanning 24 LockTypeJobScheduling 25) 26 27var ErrLostLock = errors.New("lock was lost while held, possibly due to connection breakage") 28 29func NewBuildTrackingLockID(buildID int) LockID { 30 return LockID{LockTypeBuildTracking, buildID} 31} 32 33func NewResourceConfigCheckingLockID(resourceConfigID int) LockID { 34 return LockID{LockTypeResourceConfigChecking, resourceConfigID} 35} 36 37func NewTaskLockID(taskName string) LockID { 38 return LockID{LockTypeBatch, lockIDFromString(taskName)} 39} 40 41func NewVolumeCreatingLockID(volumeID int) LockID { 42 return LockID{LockTypeVolumeCreating, volumeID} 43} 44 45func NewDatabaseMigrationLockID() LockID { 46 return LockID{LockTypeDatabaseMigration} 47} 48 49func NewActiveTasksLockID() LockID { 50 return LockID{LockTypeActiveTasks} 51} 52 53func NewResourceScanningLockID() LockID { 54 return LockID{LockTypeResourceScanning} 55} 56 57func NewJobSchedulingLockID(jobID int) LockID { 58 return LockID{LockTypeJobScheduling, jobID} 59} 60 61//go:generate counterfeiter . LockFactory 62 63type LockFactory interface { 64 Acquire(logger lager.Logger, ids LockID) (Lock, bool, error) 65} 66 67type lockFactory struct { 68 db LockDB 69 locks lockRepo 70 acquireMutex *sync.Mutex 71 72 acquireFunc LogFunc 73 releaseFunc LogFunc 74} 75 76type LogFunc func(logger lager.Logger, id LockID) 77 78func NewLockFactory( 79 conn *sql.DB, 80 acquire LogFunc, 81 release LogFunc, 82) LockFactory { 83 return &lockFactory{ 84 db: &lockDB{ 85 conn: conn, 86 mutex: &sync.Mutex{}, 87 }, 88 acquireFunc: acquire, 89 releaseFunc: release, 90 locks: lockRepo{ 91 locks: map[string]bool{}, 92 mutex: &sync.Mutex{}, 93 }, 94 acquireMutex: &sync.Mutex{}, 95 } 96} 97 98func NewTestLockFactory(db LockDB) LockFactory { 99 return &lockFactory{ 100 db: db, 101 locks: lockRepo{ 102 locks: map[string]bool{}, 103 mutex: &sync.Mutex{}, 104 }, 105 acquireMutex: &sync.Mutex{}, 106 acquireFunc: func(logger lager.Logger, id LockID) {}, 107 releaseFunc: func(logger lager.Logger, id LockID) {}, 108 } 109} 110 111func (f *lockFactory) Acquire(logger lager.Logger, id LockID) (Lock, bool, error) { 112 l := &lock{ 113 logger: logger, 114 db: f.db, 115 id: id, 116 locks: f.locks, 117 acquireMutex: f.acquireMutex, 118 acquired: f.acquireFunc, 119 released: f.releaseFunc, 120 } 121 122 acquired, err := l.Acquire() 123 if err != nil { 124 return nil, false, err 125 } 126 127 if !acquired { 128 return nil, false, nil 129 } 130 131 return l, true, nil 132} 133 134//go:generate counterfeiter . Lock 135 136type Lock interface { 137 Release() error 138} 139 140//go:generate counterfeiter . LockDB 141 142type LockDB interface { 143 Acquire(id LockID) (bool, error) 144 Release(id LockID) (bool, error) 145} 146 147type lock struct { 148 id LockID 149 150 logger lager.Logger 151 db LockDB 152 locks lockRepo 153 acquireMutex *sync.Mutex 154 155 acquired LogFunc 156 released LogFunc 157} 158 159func (l *lock) Acquire() (bool, error) { 160 l.acquireMutex.Lock() 161 defer l.acquireMutex.Unlock() 162 163 logger := l.logger.Session("acquire", lager.Data{"id": l.id}) 164 165 if l.locks.IsRegistered(l.id) { 166 logger.Debug("not-acquired-already-held-locally") 167 return false, nil 168 } 169 170 acquired, err := l.db.Acquire(l.id) 171 if err != nil { 172 logger.Error("failed-to-register-in-db", err) 173 return false, err 174 } 175 176 if !acquired { 177 logger.Debug("not-acquired-already-held-in-db") 178 return false, nil 179 } 180 181 l.locks.Register(l.id) 182 183 l.acquired(logger, l.id) 184 185 return true, nil 186} 187 188func (l *lock) Release() error { 189 logger := l.logger.Session("release", lager.Data{"id": l.id}) 190 191 released, err := l.db.Release(l.id) 192 if err != nil { 193 logger.Error("failed-to-release-in-db-but-continuing-anyway", err) 194 } 195 196 l.locks.Unregister(l.id) 197 198 if !released { 199 logger.Error("failed-to-release", ErrLostLock) 200 return ErrLostLock 201 } 202 203 l.released(logger, l.id) 204 205 return nil 206} 207 208type lockDB struct { 209 conn *sql.DB 210 mutex *sync.Mutex 211} 212 213func (db *lockDB) Acquire(id LockID) (bool, error) { 214 db.mutex.Lock() 215 defer db.mutex.Unlock() 216 217 var acquired bool 218 err := db.conn.QueryRow(`SELECT pg_try_advisory_lock(`+id.toDBParams()+`)`, id.toDBArgs()...).Scan(&acquired) 219 if err != nil { 220 return false, err 221 } 222 223 return acquired, nil 224} 225 226func (db *lockDB) Release(id LockID) (bool, error) { 227 db.mutex.Lock() 228 defer db.mutex.Unlock() 229 230 var released bool 231 err := db.conn.QueryRow(`SELECT pg_advisory_unlock(`+id.toDBParams()+`)`, id.toDBArgs()...).Scan(&released) 232 if err != nil { 233 return false, err 234 } 235 236 return released, nil 237} 238 239type lockRepo struct { 240 locks map[string]bool 241 mutex *sync.Mutex 242} 243 244func (lr lockRepo) IsRegistered(id LockID) bool { 245 lr.mutex.Lock() 246 defer lr.mutex.Unlock() 247 248 if _, ok := lr.locks[id.toKey()]; ok { 249 return true 250 } 251 return false 252} 253 254func (lr lockRepo) Register(id LockID) { 255 lr.mutex.Lock() 256 defer lr.mutex.Unlock() 257 258 lr.locks[id.toKey()] = true 259} 260 261func (lr lockRepo) Unregister(id LockID) { 262 lr.mutex.Lock() 263 defer lr.mutex.Unlock() 264 265 delete(lr.locks, id.toKey()) 266} 267 268type LockID []int 269 270func (l LockID) toKey() string { 271 s := []string{} 272 for i := range l { 273 s = append(s, strconv.Itoa(l[i])) 274 } 275 return strings.Join(s, "+") 276} 277 278func (l LockID) toDBParams() string { 279 s := []string{} 280 for i := range l { 281 s = append(s, fmt.Sprintf("$%d", i+1)) 282 } 283 284 return strings.Join(s, ",") 285} 286 287func (l LockID) toDBArgs() []interface{} { 288 result := []interface{}{} 289 for i := range l { 290 result = append(result, l[i]) 291 } 292 293 return result 294} 295 296func lockIDFromString(taskName string) int { 297 return int(int32(crc32.ChecksumIEEE([]byte(taskName)))) 298} 299