1/*
2** Zabbix
3** Copyright (C) 2001-2021 Zabbix SIA
4**
5** This program is free software; you can redistribute it and/or modify
6** it under the terms of the GNU General Public License as published by
7** the Free Software Foundation; either version 2 of the License, or
8** (at your option) any later version.
9**
10** This program is distributed in the hope that it will be useful,
11** but WITHOUT ANY WARRANTY; without even the implied warranty of
12** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13** GNU General Public License for more details.
14**
15** You should have received a copy of the GNU General Public License
16** along with this program; if not, write to the Free Software
17** Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
18**/
19
20package postgres
21
22import (
23	"context"
24	"crypto/tls"
25	"database/sql"
26	"fmt"
27	"net"
28	"net/url"
29	"path/filepath"
30	"strings"
31	"sync"
32	"time"
33
34	"github.com/jackc/pgx/v4/pgxpool"
35	"github.com/jackc/pgx/v4/stdlib"
36	"github.com/omeid/go-yarn"
37	"zabbix.com/pkg/log"
38	"zabbix.com/pkg/tlsconfig"
39	"zabbix.com/pkg/uri"
40	"zabbix.com/pkg/zbxerr"
41)
42
// MinSupportedPGVersion is the lowest server_version_num accepted when a new
// connection is created; connections to servers reporting a smaller value are
// rejected. 100000 is the server_version_num of PostgreSQL 10.0.
const MinSupportedPGVersion = 100000
44
// PostgresClient is the set of query operations available on a PostgreSQL
// connection; *PGConn provides all of these methods.
type PostgresClient interface {
	// Query executes query and returns the resulting rows.
	Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error)
	// QueryByName executes the stored query registered under queryName and returns its rows.
	QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error)
	// QueryRow executes query and returns a single row.
	QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error)
	// QueryRowByName executes the stored query registered under queryName and returns a single row.
	QueryRowByName(ctx context.Context, queryName string, args ...interface{}) (row *sql.Row, err error)
	// PostgresVersion returns the server_version_num of the connected server.
	PostgresVersion() int
}
52
// PGConn holds the *sql.DB client for a single PostgreSQL target together with
// the metadata the plugin keeps about it.
type PGConn struct {
	// client is the database handle every query of this connection runs on.
	client         *sql.DB
	// callTimeout is copied from ConnManager.callTimeout when the connection is created.
	// NOTE(review): not read by any code visible in this file.
	callTimeout    time.Duration
	// ctx is the context the connection was created with (context.Background() in create).
	ctx            context.Context
	// lastTimeAccess is refreshed on every cache hit and compared against
	// ConnManager.keepAlive to decide when the connection is closed as unused.
	lastTimeAccess time.Time
	// version is the server_version_num value read when the connection was created.
	version        int
	// queryStorage points to the ConnManager's storage of named SQL queries.
	queryStorage   *yarn.Yarn
}
62
// errorQueryNotFound is the fmt format used when a named query is missing from
// the query storage; its single verb receives the query name.
const errorQueryNotFound = "query %q not found"
64
65// Query wraps pgxpool.Query.
66func (conn *PGConn) Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
67	rows, err = conn.client.QueryContext(ctx, query, args...)
68
69	if ctxErr := ctx.Err(); ctxErr != nil {
70		err = ctxErr
71	}
72
73	return
74}
75
76// QueryByName executes a query from queryStorage by its name and returns a singe row.
77func (conn *PGConn) QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error) {
78	if sql, ok := (*conn.queryStorage).Get(queryName + sqlExt); ok {
79		normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";")
80
81		return conn.Query(ctx, normalizedSQL, args...)
82	}
83
84	return nil, fmt.Errorf(errorQueryNotFound, queryName)
85}
86
87// QueryRow wraps pgxpool.QueryRow.
88func (conn *PGConn) QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error) {
89	row = conn.client.QueryRowContext(ctx, query, args...)
90
91	if ctxErr := ctx.Err(); ctxErr != nil {
92		err = ctxErr
93	}
94
95	return
96}
97
98// QueryRowByName executes a query from queryStorage by its name and returns a singe row.
99func (conn *PGConn) QueryRowByName(ctx context.Context, queryName string, args ...interface{}) (row *sql.Row, err error) {
100	if sql, ok := (*conn.queryStorage).Get(queryName + sqlExt); ok {
101		normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";")
102
103		return conn.QueryRow(ctx, normalizedSQL, args...)
104	}
105
106	return nil, fmt.Errorf(errorQueryNotFound, queryName)
107}
108
// getPostgresVersion reads the server_version_num setting of the PostgreSQL
// server reachable through conn and returns it as an integer.
func getPostgresVersion(ctx context.Context, conn *sql.DB) (version int, err error) {
	const versionQuery = `select current_setting('server_version_num');`

	row := conn.QueryRowContext(ctx, versionQuery)
	err = row.Scan(&version)

	return version, err
}
115
116// PostgresVersion returns the version of PostgreSQL server we are currently connected to.
117func (conn *PGConn) PostgresVersion() int {
118	return conn.version
119}
120
121// updateAccessTime updates the last time a connection was accessed.
122func (conn *PGConn) updateAccessTime() {
123	conn.lastTimeAccess = time.Now()
124}
125
// ConnManager is a thread-safe structure for managing connections.
// It caches one *PGConn per uri.URI and closes connections that have been idle
// for longer than keepAlive.
type ConnManager struct {
	// The embedded mutex is held by GetConnection across the whole lookup-or-create
	// sequence. NOTE(review): being embedded, Lock/Unlock are exported on ConnManager.
	sync.Mutex
	// connMutex guards the connections map.
	connMutex      sync.Mutex
	connections    map[uri.URI]*PGConn
	// keepAlive is the idle period after which a connection is closed as unused.
	keepAlive      time.Duration
	// connectTimeout is passed to createTLSClient to bound each dial.
	connectTimeout time.Duration
	// callTimeout is copied into every PGConn created by this manager.
	callTimeout    time.Duration
	// Destroy cancels the housekeeper's context; the housekeeper then closes all connections.
	Destroy        context.CancelFunc
	// queryStorage holds the named SQL queries; every PGConn points to this same value.
	queryStorage   yarn.Yarn
}
137
138// NewConnManager initializes connManager structure and runs Go Routine that watches for unused connections.
139func NewConnManager(keepAlive, connectTimeout, callTimeout,
140	hkInterval time.Duration, queryStorage yarn.Yarn) *ConnManager {
141	ctx, cancel := context.WithCancel(context.Background())
142
143	connMgr := &ConnManager{
144		connections:    make(map[uri.URI]*PGConn),
145		keepAlive:      keepAlive,
146		connectTimeout: connectTimeout,
147		callTimeout:    callTimeout,
148		Destroy:        cancel, // Destroy stops originated goroutines and closes connections.
149		queryStorage:   queryStorage,
150	}
151
152	go connMgr.housekeeper(ctx, hkInterval)
153
154	return connMgr
155}
156
157// closeUnused closes each connection that has not been accessed at least within the keepalive interval.
158func (c *ConnManager) closeUnused() {
159	c.connMutex.Lock()
160	defer c.connMutex.Unlock()
161
162	for uri, conn := range c.connections {
163		if time.Since(conn.lastTimeAccess) > c.keepAlive {
164			conn.client.Close()
165			delete(c.connections, uri)
166			log.Debugf("[%s] Closed unused connection: %s", pluginName, uri.Addr())
167		}
168	}
169}
170
171// closeAll closes all existed connections.
172func (c *ConnManager) closeAll() {
173	c.connMutex.Lock()
174	for uri, conn := range c.connections {
175		conn.client.Close()
176		delete(c.connections, uri)
177	}
178	c.connMutex.Unlock()
179}
180
181// housekeeper repeatedly checks for unused connections and closes them.
182func (c *ConnManager) housekeeper(ctx context.Context, interval time.Duration) {
183	ticker := time.NewTicker(interval)
184
185	for {
186		select {
187		case <-ctx.Done():
188			ticker.Stop()
189			c.closeAll()
190
191			return
192		case <-ticker.C:
193			c.closeUnused()
194		}
195	}
196}
197
198// create creates a new connection with given credentials.
199func (c *ConnManager) create(uri uri.URI, details tlsconfig.Details) (*PGConn, error) {
200	c.connMutex.Lock()
201	defer c.connMutex.Unlock()
202
203	if _, ok := c.connections[uri]; ok {
204		// Should never happen.
205		panic("connection already exists")
206	}
207
208	ctx := context.Background()
209
210	host := uri.Host()
211	port := uri.Port()
212
213	if uri.Scheme() == "unix" {
214		socket := uri.Addr()
215		host = filepath.Dir(socket)
216
217		ext := filepath.Ext(filepath.Base(socket))
218		if len(ext) <= 1 {
219			return nil, fmt.Errorf("incorrect socket: %q", socket)
220		}
221
222		port = ext[1:]
223	}
224
225	dbname, err := url.QueryUnescape(uri.GetParam("dbname"))
226	if err != nil {
227		return nil, err
228	}
229
230	dsn := fmt.Sprintf("host=%s port=%s dbname=%s user=%s",
231		host, port, dbname, uri.User())
232
233	if uri.Password() != "" {
234		dsn += " password=" + uri.Password()
235	}
236
237	client, err := createTLSClient(dsn, c.connectTimeout, details)
238	if err != nil {
239		return nil, err
240	}
241
242	serverVersion, err := getPostgresVersion(ctx, client)
243	if err != nil {
244		return nil, err
245	}
246
247	if serverVersion < MinSupportedPGVersion {
248		return nil, fmt.Errorf("postgres version %d is not supported", serverVersion)
249	}
250
251	c.connections[uri] = &PGConn{
252		client:         client,
253		callTimeout:    c.callTimeout,
254		version:        serverVersion,
255		lastTimeAccess: time.Now(),
256		ctx:            ctx,
257		queryStorage:   &c.queryStorage,
258	}
259
260	log.Debugf("[%s] Created new connection: %s", pluginName, uri.Addr())
261
262	return c.connections[uri], nil
263}
264
265func createTLSClient(dsn string, timeout time.Duration, details tlsconfig.Details) (*sql.DB, error) {
266	config, err := pgxpool.ParseConfig(dsn)
267	if err != nil {
268		return nil, err
269	}
270
271	config.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
272		d := net.Dialer{}
273		ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout)
274		defer cancel()
275
276		conn, err := d.DialContext(ctxTimeout, network, addr)
277
278		return conn, err
279	}
280
281	config.ConnConfig.TLSConfig, err = getTLSConfig(details)
282	if err != nil {
283		return nil, err
284	}
285
286	return stdlib.OpenDB(*config.ConnConfig), nil
287}
288
289func getTLSConfig(details tlsconfig.Details) (*tls.Config, error) {
290	switch details.TlsConnect {
291	case "required":
292		return &tls.Config{InsecureSkipVerify: true}, nil
293	case "verify_ca":
294		return tlsconfig.CreateConfig(details, true)
295	case "verify_full":
296		return tlsconfig.CreateConfig(details, false)
297	}
298
299	return nil, nil
300}
301
302// get returns a connection with given uri if it exists and also updates lastTimeAccess, otherwise returns nil.
303func (c *ConnManager) get(uri uri.URI) *PGConn {
304	c.connMutex.Lock()
305	defer c.connMutex.Unlock()
306
307	if conn, ok := c.connections[uri]; ok {
308		conn.updateAccessTime()
309		return conn
310	}
311
312	return nil
313}
314
315// GetConnection returns an existing connection or creates a new one.
316func (c *ConnManager) GetConnection(uri uri.URI, details tlsconfig.Details) (conn *PGConn, err error) {
317	c.Lock()
318	defer c.Unlock()
319
320	conn = c.get(uri)
321
322	if conn == nil {
323		conn, err = c.create(uri, details)
324	}
325
326	if err != nil {
327		err = zbxerr.ErrorConnectionFailed.Wrap(err)
328	}
329
330	return
331}
332