1// This file and its contents are licensed under the Apache License 2.0.
2// Please see the included NOTICE for copyright information and
3// LICENSE for a copy of the license.
4
5package util
6
7import (
8	"context"
9	"fmt"
10	"strconv"
11	"sync"
12	"time"
13
14	pgx "github.com/jackc/pgx/v4"
15
16	"github.com/timescale/promscale/pkg/log"
17)
18
// defaultConnectionTimeout bounds connection setup when the parsed connection
// string carries no connect timeout of its own.
const defaultConnectionTimeout = time.Minute

// SharedLeaseFailure is returned by GetSharedLease when the lock function
// ran successfully but reported that the shared lock was not granted.
var SharedLeaseFailure = fmt.Errorf("failed to acquire shared lease")
22
23func GetSharedLease(ctx context.Context, conn *pgx.Conn, id int64) error {
24	gotten, err := runLockFunction(ctx, conn, "SELECT pg_try_advisory_lock_shared($1)", id)
25	if err != nil {
26		return fmt.Errorf("Unable to get shared schema lock. Please make sure that no operations requiring exclusive locking are running: %w", err)
27	}
28	if !gotten {
29		return SharedLeaseFailure
30	}
31	return nil
32}
33
// PgLeaderLock is an implementation of leader election based on PostgreSQL advisory locks. All adapters within an HA group
// try to obtain an advisory lock for their particular group. The one that holds the lock can write to the database. Because
// a Prometheus HA setup provides no consistency guarantees, this implementation is best effort with regard
// to the metrics that are written (some duplicates or data loss are possible during fail-over).
// The `leader-election-pg-advisory-lock-prometheus-timeout` config must be set when using PgLeaderLock. It
// triggers leader resignation (if the instance is a leader) and prevents an instance from becoming a leader if no requests
// arrive from Prometheus within the given timeout. Make sure to provide a reasonable value for the timeout (it should be
// correlated with the Prometheus scrape interval, e.g. 2x or 3x the scrape interval, to prevent leader flipping).
// The recommended architecture when using PgLeaderLock is to have one adapter instance per Prometheus instance.
type PgLeaderLock struct {
	// PgAdvisoryLock supplies the connection, the lock ID and the mutex.
	PgAdvisoryLock
	// obtained records whether this instance currently considers itself the
	// holder of the leader lock.
	obtained bool
	// numLockAttempts counts completed tryLock calls; zero means the first
	// attempt is still in progress, which selects the start-up log message.
	numLockAttempts int
}
48
// AfterConnectFunc is a hook run against each newly opened lock connection.
// A non-nil error makes getConn log the failure and try a new connection.
type AfterConnectFunc = func(ctx context.Context, conn *pgx.Conn) error
50
51func NewPgLeaderLock(groupLockID int64, connStr string, afterConnect AfterConnectFunc) (*PgLeaderLock, error) {
52	if afterConnect == nil {
53		afterConnect = checkConnection
54	}
55	lock := &PgLeaderLock{
56		PgAdvisoryLock{
57			connStr:      connStr,
58			groupLockID:  groupLockID,
59			afterConnect: afterConnect,
60		},
61		false,
62		0,
63	}
64	_, err := lock.tryLock()
65	if err != nil {
66		return nil, err
67	}
68	return lock, nil
69}
70
71// ID returns the group lock ID for this instance.
72func (l *PgLeaderLock) ID() string {
73	return strconv.FormatInt(int64(l.groupLockID), 10)
74}
75
// BecomeLeader tries to become the leader by acquiring the lock.
// It reports true when this instance holds the lock after the call; the error
// is whatever tryLock returned.
func (l *PgLeaderLock) BecomeLeader() (bool, error) {
	return l.tryLock()
}
80
// IsLeader returns the current leader status for this instance.
// It delegates to tryLock, so it is not a passive read: when the lock is
// already held the connection is re-verified, and when it is not held an
// acquisition attempt is made.
func (l *PgLeaderLock) IsLeader() (bool, error) {
	return l.tryLock()
}
85
86// Resign releases the leader status of this instance.
87func (l *PgLeaderLock) Resign() error {
88	log.Info("msg", "Resigning as the leader (will no longer be writing)", "component", "leader_election", "status", "follower", "group_id", l.groupLockID)
89	return l.release()
90}
91
92// tryLock tries to obtain the lock if its not already the leader. In the case
93// that it is the leader, it verifies the connection to make sure the lock hasn't
94// been already lost.
95func (l *PgLeaderLock) tryLock() (bool, error) {
96	l.mutex.Lock()
97	defer l.mutex.Unlock()
98	defer func() { l.numLockAttempts++ }()
99
100	if l.obtained && l.conn != nil {
101		// we already hold the lock verify the connection
102		err := l.conn.QueryRow(context.Background(), "SELECT").Scan()
103		if err != nil {
104			l.obtained = false
105			l.connCleanUp()
106			return false, err
107		}
108		return true, nil
109	}
110
111	gotLock, err := l.getAdvisoryLock()
112	if err != nil {
113		l.obtained = false
114		l.connCleanUp()
115		return false, err
116	}
117
118	if !gotLock {
119		if l.numLockAttempts == 0 {
120			log.Info("msg", "I am starting as a follower (will not be writing until I become the leader)", "component", "leader_election", "status", "follower", "group_id", l.groupLockID)
121		} else if l.obtained {
122			log.Info("msg", "I have lost the leader lock and am now a follower (will not be writing)", "component", "leader_election", "status", "follower", "group_id", l.groupLockID)
123		}
124		l.obtained = false
125		return false, nil
126	}
127
128	if !l.obtained {
129		l.obtained = true
130		log.Info("msg", "I have become the leader (starting to write incoming data)", "component", "leader_election", "status", "leader", "group_id", l.groupLockID)
131	}
132
133	return true, nil
134}
135
136// Locked returns if the instance was able to obtain the locks.
137// Does NOT verify the lock.
138func (l *PgLeaderLock) locked() bool {
139	l.mutex.RLock()
140	defer l.mutex.RUnlock()
141	return l.obtained
142}
143
144// Release releases the already obtained locks.
145func (l *PgLeaderLock) release() error {
146	l.mutex.Lock()
147	defer l.mutex.Unlock()
148
149	if !l.obtained {
150		return fmt.Errorf("can't release while not holding the lock")
151	}
152	defer func() { l.obtained = false }()
153
154	unlocked, err := l.unlock()
155	if err != nil {
156		l.connCleanUp()
157		return err
158	}
159	if !unlocked {
160		log.Debug("msg", fmt.Sprintf("release for a lock that was not held: group id %d", l.groupLockID))
161	}
162
163	return nil
164}
165
// PgAdvisoryLock runs PostgreSQL advisory lock functions for a single lock ID
// over a connection it opens on demand.
type PgAdvisoryLock struct {
	// conn is the lock connection; nil until ensureConnInit opens one and
	// again after connCleanUp.
	conn *pgx.Conn
	// connStr is the connection string used to open conn.
	connStr string
	// groupLockID is the key passed to the advisory lock functions.
	groupLockID int64
	// afterConnect is run on every newly opened connection.
	afterConnect AfterConnectFunc

	// mutex is taken by the exported lock, unlock and Close methods.
	mutex sync.RWMutex
}
174
// AdvisoryLock is the set of operations offered on an advisory lock:
// exclusive and shared acquisition, the matching unlocks, and Close.
// The boolean results carry the lock function's own answer.
type AdvisoryLock interface {
	// GetAdvisoryLock tries to take the exclusive lock.
	GetAdvisoryLock() (bool, error)
	// GetSharedAdvisoryLock tries to take the shared lock.
	GetSharedAdvisoryLock() (bool, error)
	// Unlock releases the exclusive lock.
	Unlock() (bool, error)
	// UnlockShared releases the shared lock.
	UnlockShared() (bool, error)
	// Close releases the resources behind the lock.
	Close()
}
182
// Compile-time check that *PgAdvisoryLock implements AdvisoryLock.
var _ AdvisoryLock = (*PgAdvisoryLock)(nil)
185
186// NewPgAdvisoryLock creates a new instance with specified lock ID, connection pool and lock timeout.
187func NewPgAdvisoryLock(groupLockID int64, connStr string) (*PgAdvisoryLock, error) {
188	lock := &PgAdvisoryLock{
189		connStr:      connStr,
190		groupLockID:  groupLockID,
191		afterConnect: checkConnection,
192	}
193	return lock, nil
194}
195
196func (l *PgAdvisoryLock) getConn(connStr string, cur, maxRetries int) (*pgx.Conn, error) {
197	if maxRetries == cur {
198		return nil, fmt.Errorf("max attempts reached. giving up on getting a db connection")
199	}
200
201	cfg, err := pgx.ParseConfig(connStr)
202	if err != nil {
203		return nil, fmt.Errorf("error parsing config connection: %w", err)
204	}
205
206	ctx := context.Background()
207	if cfg.ConnectTimeout.Seconds() == 0 {
208		// Set the defaultConnectionTimeout if the connection string does not contain the
209		// the connection timeout information.
210		cctx, cancel := context.WithTimeout(context.Background(), defaultConnectionTimeout)
211		defer cancel()
212		ctx = cctx
213	}
214
215	lockConn, err := pgx.ConnectConfig(ctx, cfg)
216	if err != nil {
217		return nil, fmt.Errorf("error getting DB connection: %w", err)
218	}
219
220	err = l.afterConnect(ctx, lockConn)
221	if err != nil {
222		log.Error("msg", "Lock connection initialization failed", "err", err)
223		return l.getConn(connStr, cur+1, maxRetries)
224	}
225
226	return lockConn, nil
227}
228
229func (l *PgAdvisoryLock) GetAdvisoryLock() (bool, error) {
230	l.mutex.Lock()
231	defer l.mutex.Unlock()
232	return l.getAdvisoryLock()
233}
234
// getAdvisoryLock runs pg_try_advisory_lock for this lock ID.
// It does not take the mutex itself.
func (l *PgAdvisoryLock) getAdvisoryLock() (bool, error) {
	return l.runLockFunction("SELECT pg_try_advisory_lock($1)")
}
238
239func (l *PgAdvisoryLock) GetSharedAdvisoryLock() (bool, error) {
240	l.mutex.Lock()
241	defer l.mutex.Unlock()
242	return l.getSharedAdvisoryLock()
243}
244
// getSharedAdvisoryLock runs pg_try_advisory_lock_shared for this lock ID.
// It does not take the mutex itself.
func (l *PgAdvisoryLock) getSharedAdvisoryLock() (bool, error) {
	return l.runLockFunction("SELECT pg_try_advisory_lock_shared($1)")
}
248
249func (l *PgAdvisoryLock) Conn() (*pgx.Conn, error) {
250	err := l.ensureConnInit()
251	if err != nil {
252		return nil, err
253	}
254	return l.conn, nil
255}
256
257func (l *PgAdvisoryLock) runLockFunction(query string) (bool, error) {
258	err := l.ensureConnInit()
259	if err != nil {
260		return false, err
261	}
262	return runLockFunction(context.Background(), l.conn, query, l.groupLockID)
263}
264
265func (l *PgAdvisoryLock) ensureConnInit() error {
266	if l.conn != nil {
267		return nil
268	}
269	conn, err := l.getConn(l.connStr, 0, 10)
270	l.conn = conn
271	return err
272}
273
274func (l *PgAdvisoryLock) Unlock() (bool, error) {
275	l.mutex.Lock()
276	defer l.mutex.Unlock()
277	return l.unlock()
278}
279
// unlock runs pg_advisory_unlock for this lock ID.
// It does not take the mutex itself.
func (l *PgAdvisoryLock) unlock() (bool, error) {
	return l.runLockFunction("SELECT pg_advisory_unlock($1)")
}
283
284func (l *PgAdvisoryLock) UnlockShared() (bool, error) {
285	l.mutex.Lock()
286	defer l.mutex.Unlock()
287	return l.unlockShared()
288}
289
// unlockShared runs pg_advisory_unlock_shared for this lock ID.
// It does not take the mutex itself.
func (l *PgAdvisoryLock) unlockShared() (bool, error) {
	return l.runLockFunction("SELECT pg_advisory_unlock_shared($1)")
}
293
294func runLockFunction(ctx context.Context, conn *pgx.Conn, query string, lockId int64) (result bool, err error) {
295	err = conn.QueryRow(ctx, query, lockId).Scan(&result)
296	if err != nil {
297		return false, fmt.Errorf("error while trying to read response rows from locking function: %w", err)
298	}
299	return result, nil
300}
301
// Close closes the lock connection, if any, while holding the mutex.
// A close error is logged by connCleanUp and not returned.
func (l *PgAdvisoryLock) Close() {
	l.mutex.Lock()
	defer l.mutex.Unlock()
	l.connCleanUp()
}
308
309func (l *PgAdvisoryLock) connCleanUp() {
310	if l.conn != nil {
311		if err := l.conn.Close(context.Background()); err != nil {
312			log.Error("err", err)
313		}
314	}
315	l.conn = nil
316}
317
318func checkConnection(ctx context.Context, conn *pgx.Conn) error {
319	_, err := conn.Exec(ctx, "SELECT 1")
320	if err != nil {
321		return fmt.Errorf("invalid connection: %w", err)
322	}
323	return nil
324}
325