1// Package glsql (Gitaly SQL) is a helper package to work with plain SQL queries.
2package glsql
3
4import (
5	"context"
6	"database/sql"
7
8	// Blank import to enable integration of github.com/lib/pq into database/sql
9	_ "github.com/lib/pq"
10	migrate "github.com/rubenv/sql-migrate"
11	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/config"
12	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/datastore/migrations"
13)
14
15// OpenDB returns connection pool to the database.
16func OpenDB(conf config.DB) (*sql.DB, error) {
17	db, err := sql.Open("postgres", conf.ToPQString(false))
18	if err != nil {
19		return nil, err
20	}
21
22	if err := db.Ping(); err != nil {
23		db.Close()
24		return nil, err
25	}
26
27	return db, nil
28}
29
30// Migrate will apply all pending SQL migrations.
31func Migrate(db *sql.DB, ignoreUnknown bool) (int, error) {
32	migrationSource := &migrate.MemoryMigrationSource{Migrations: migrations.All()}
33	migrate.SetIgnoreUnknown(ignoreUnknown)
34	return migrate.Exec(db, "postgres", migrationSource, migrate.Up)
35}
36
// Querier is the method set shared by *sql.DB and *sql.Tx. Code written
// against it can run statements without knowing whether it is inside a
// transaction.
type Querier interface {
	// ExecContext executes a statement that does not return rows.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	// QueryContext executes a query and returns the resulting rows.
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	// QueryRowContext executes a query that is expected to return at most one row.
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
43
// Notification represents a single notification received from the database.
type Notification struct {
	Channel string // Channel is the name of the channel that received the notification.
	Payload string // Payload is the data carried by the notification.
}
51
52// ListenHandler contains a set of methods that would be called on corresponding notifications received.
53type ListenHandler interface {
54	// Notification would be triggered once a new notification received.
55	Notification(Notification)
56	// Disconnect would be triggered once a connection to remote service is lost.
57	// Passed in error will never be nil and will contain cause of the disconnection.
58	Disconnect(error)
59	// Connected would be triggered once a connection to remote service is established.
60	Connected()
61}
62
// DestProvider supplies the pointers that a single row is scanned into.
type DestProvider interface {
	// To returns the scan destinations for one row. Calling it is not
	// idempotent: every call produces a new list of destinations.
	To() []interface{}
}
69
70// ScanAll reads all data from 'rows' into holders provided by 'in'.
71func ScanAll(rows *sql.Rows, in DestProvider) (err error) {
72	for rows.Next() {
73		if err = rows.Scan(in.To()...); err != nil {
74			return err
75		}
76	}
77
78	return nil
79}
80
// Uint64Provider collects a single uint64 column. Pass a *Uint64Provider to
// ScanAll, then call Values to get the scanned column as a slice.
type Uint64Provider []*uint64

// Values dereferences every scanned destination and returns the values in
// the order they were scanned. It returns nil when nothing has been scanned.
func (p *Uint64Provider) Values() []uint64 {
	ptrs := *p
	if len(ptrs) == 0 {
		return nil
	}

	out := make([]uint64, 0, len(ptrs))
	for _, ptr := range ptrs {
		out = append(out, *ptr)
	}
	return out
}

// To allocates a new uint64, records a pointer to it, and returns that
// pointer as the only scan destination for one row.
func (p *Uint64Provider) To() []interface{} {
	dst := new(uint64)
	*p = append(*p, dst)
	return []interface{}{dst}
}
103
// StringProvider collects a single string column. Pass a *StringProvider to
// ScanAll, then call Values to get the scanned column as a slice.
type StringProvider []*string

// Values dereferences every scanned destination and returns the strings in
// the order they were scanned. It returns nil when nothing has been scanned.
func (p *StringProvider) Values() []string {
	ptrs := *p
	if len(ptrs) == 0 {
		return nil
	}

	out := make([]string, 0, len(ptrs))
	for _, ptr := range ptrs {
		out = append(out, *ptr)
	}
	return out
}

// To allocates a new string, records a pointer to it, and returns that
// pointer as the only scan destination for one row.
func (p *StringProvider) To() []interface{} {
	dst := new(string)
	*p = append(*p, dst)
	return []interface{}{dst}
}
126