/*
** Zabbix
** Copyright (C) 2001-2021 Zabbix SIA
**
** This program is free software; you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation; either version 2 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program; if not, write to the Free Software
** Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
**/

package postgres

import (
	"context"
	"crypto/tls"
	"database/sql"
	"fmt"
	"net"
	"net/url"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgx/v4/pgxpool"
	"github.com/jackc/pgx/v4/stdlib"
	"github.com/omeid/go-yarn"
	"zabbix.com/pkg/log"
	"zabbix.com/pkg/tlsconfig"
	"zabbix.com/pkg/uri"
	"zabbix.com/pkg/zbxerr"
)

// MinSupportedPGVersion is the lowest server version, expressed in PostgreSQL's
// server_version_num format, that the plugin accepts when opening a connection
// (100000 corresponds to PostgreSQL 10). Older servers are rejected on connect.
const MinSupportedPGVersion = 100000

// PostgresClient abstracts query execution against a PostgreSQL server.
// *PGConn implements it.
type PostgresClient interface {
	// Query runs query with args and returns the resulting rows.
	Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error)
	// QueryByName runs the query registered under queryName (for *PGConn, looked
	// up in its query storage) and returns the resulting rows.
	QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error)
	// QueryRow runs query with args and returns a single-row result.
	QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error)
	// QueryRowByName runs the query registered under queryName and returns a
	// single-row result.
	QueryRowByName(ctx context.Context, queryName string, args ...interface{}) (row *sql.Row, err error)
	// PostgresVersion returns the server version as a server_version_num integer.
	PostgresVersion() int
}

// PGConn is a handle to a single PostgreSQL instance backed by a database/sql
// pool (*sql.DB), together with the named-query storage and the server version
// detected when the connection was created.
54type PGConn struct { 55 client *sql.DB 56 callTimeout time.Duration 57 ctx context.Context 58 lastTimeAccess time.Time 59 version int 60 queryStorage *yarn.Yarn 61} 62 63var errorQueryNotFound = "query %q not found" 64 65// Query wraps pgxpool.Query. 66func (conn *PGConn) Query(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { 67 rows, err = conn.client.QueryContext(ctx, query, args...) 68 69 if ctxErr := ctx.Err(); ctxErr != nil { 70 err = ctxErr 71 } 72 73 return 74} 75 76// QueryByName executes a query from queryStorage by its name and returns a singe row. 77func (conn *PGConn) QueryByName(ctx context.Context, queryName string, args ...interface{}) (rows *sql.Rows, err error) { 78 if sql, ok := (*conn.queryStorage).Get(queryName + sqlExt); ok { 79 normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";") 80 81 return conn.Query(ctx, normalizedSQL, args...) 82 } 83 84 return nil, fmt.Errorf(errorQueryNotFound, queryName) 85} 86 87// QueryRow wraps pgxpool.QueryRow. 88func (conn *PGConn) QueryRow(ctx context.Context, query string, args ...interface{}) (row *sql.Row, err error) { 89 row = conn.client.QueryRowContext(ctx, query, args...) 90 91 if ctxErr := ctx.Err(); ctxErr != nil { 92 err = ctxErr 93 } 94 95 return 96} 97 98// QueryRowByName executes a query from queryStorage by its name and returns a singe row. 99func (conn *PGConn) QueryRowByName(ctx context.Context, queryName string, args ...interface{}) (row *sql.Row, err error) { 100 if sql, ok := (*conn.queryStorage).Get(queryName + sqlExt); ok { 101 normalizedSQL := strings.TrimRight(strings.TrimSpace(sql), ";") 102 103 return conn.QueryRow(ctx, normalizedSQL, args...) 104 } 105 106 return nil, fmt.Errorf(errorQueryNotFound, queryName) 107} 108 109// GetPostgresVersion exec SQL query to retrieve the version of PostgreSQL server we are currently connected to. 
110func getPostgresVersion(ctx context.Context, conn *sql.DB) (version int, err error) { 111 err = conn.QueryRowContext(ctx, `select current_setting('server_version_num');`).Scan(&version) 112 113 return 114} 115 116// PostgresVersion returns the version of PostgreSQL server we are currently connected to. 117func (conn *PGConn) PostgresVersion() int { 118 return conn.version 119} 120 121// updateAccessTime updates the last time a connection was accessed. 122func (conn *PGConn) updateAccessTime() { 123 conn.lastTimeAccess = time.Now() 124} 125 126// ConnManager is a thread-safe structure for manage connections. 127type ConnManager struct { 128 sync.Mutex 129 connMutex sync.Mutex 130 connections map[uri.URI]*PGConn 131 keepAlive time.Duration 132 connectTimeout time.Duration 133 callTimeout time.Duration 134 Destroy context.CancelFunc 135 queryStorage yarn.Yarn 136} 137 138// NewConnManager initializes connManager structure and runs Go Routine that watches for unused connections. 139func NewConnManager(keepAlive, connectTimeout, callTimeout, 140 hkInterval time.Duration, queryStorage yarn.Yarn) *ConnManager { 141 ctx, cancel := context.WithCancel(context.Background()) 142 143 connMgr := &ConnManager{ 144 connections: make(map[uri.URI]*PGConn), 145 keepAlive: keepAlive, 146 connectTimeout: connectTimeout, 147 callTimeout: callTimeout, 148 Destroy: cancel, // Destroy stops originated goroutines and closes connections. 149 queryStorage: queryStorage, 150 } 151 152 go connMgr.housekeeper(ctx, hkInterval) 153 154 return connMgr 155} 156 157// closeUnused closes each connection that has not been accessed at least within the keepalive interval. 
158func (c *ConnManager) closeUnused() { 159 c.connMutex.Lock() 160 defer c.connMutex.Unlock() 161 162 for uri, conn := range c.connections { 163 if time.Since(conn.lastTimeAccess) > c.keepAlive { 164 conn.client.Close() 165 delete(c.connections, uri) 166 log.Debugf("[%s] Closed unused connection: %s", pluginName, uri.Addr()) 167 } 168 } 169} 170 171// closeAll closes all existed connections. 172func (c *ConnManager) closeAll() { 173 c.connMutex.Lock() 174 for uri, conn := range c.connections { 175 conn.client.Close() 176 delete(c.connections, uri) 177 } 178 c.connMutex.Unlock() 179} 180 181// housekeeper repeatedly checks for unused connections and closes them. 182func (c *ConnManager) housekeeper(ctx context.Context, interval time.Duration) { 183 ticker := time.NewTicker(interval) 184 185 for { 186 select { 187 case <-ctx.Done(): 188 ticker.Stop() 189 c.closeAll() 190 191 return 192 case <-ticker.C: 193 c.closeUnused() 194 } 195 } 196} 197 198// create creates a new connection with given credentials. 199func (c *ConnManager) create(uri uri.URI, details tlsconfig.Details) (*PGConn, error) { 200 c.connMutex.Lock() 201 defer c.connMutex.Unlock() 202 203 if _, ok := c.connections[uri]; ok { 204 // Should never happen. 
205 panic("connection already exists") 206 } 207 208 ctx := context.Background() 209 210 host := uri.Host() 211 port := uri.Port() 212 213 if uri.Scheme() == "unix" { 214 socket := uri.Addr() 215 host = filepath.Dir(socket) 216 217 ext := filepath.Ext(filepath.Base(socket)) 218 if len(ext) <= 1 { 219 return nil, fmt.Errorf("incorrect socket: %q", socket) 220 } 221 222 port = ext[1:] 223 } 224 225 dbname, err := url.QueryUnescape(uri.GetParam("dbname")) 226 if err != nil { 227 return nil, err 228 } 229 230 dsn := fmt.Sprintf("host=%s port=%s dbname=%s user=%s", 231 host, port, dbname, uri.User()) 232 233 if uri.Password() != "" { 234 dsn += " password=" + uri.Password() 235 } 236 237 client, err := createTLSClient(dsn, c.connectTimeout, details) 238 if err != nil { 239 return nil, err 240 } 241 242 serverVersion, err := getPostgresVersion(ctx, client) 243 if err != nil { 244 return nil, err 245 } 246 247 if serverVersion < MinSupportedPGVersion { 248 return nil, fmt.Errorf("postgres version %d is not supported", serverVersion) 249 } 250 251 c.connections[uri] = &PGConn{ 252 client: client, 253 callTimeout: c.callTimeout, 254 version: serverVersion, 255 lastTimeAccess: time.Now(), 256 ctx: ctx, 257 queryStorage: &c.queryStorage, 258 } 259 260 log.Debugf("[%s] Created new connection: %s", pluginName, uri.Addr()) 261 262 return c.connections[uri], nil 263} 264 265func createTLSClient(dsn string, timeout time.Duration, details tlsconfig.Details) (*sql.DB, error) { 266 config, err := pgxpool.ParseConfig(dsn) 267 if err != nil { 268 return nil, err 269 } 270 271 config.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { 272 d := net.Dialer{} 273 ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout) 274 defer cancel() 275 276 conn, err := d.DialContext(ctxTimeout, network, addr) 277 278 return conn, err 279 } 280 281 config.ConnConfig.TLSConfig, err = getTLSConfig(details) 282 if err != nil { 283 return nil, err 284 
} 285 286 return stdlib.OpenDB(*config.ConnConfig), nil 287} 288 289func getTLSConfig(details tlsconfig.Details) (*tls.Config, error) { 290 switch details.TlsConnect { 291 case "required": 292 return &tls.Config{InsecureSkipVerify: true}, nil 293 case "verify_ca": 294 return tlsconfig.CreateConfig(details, true) 295 case "verify_full": 296 return tlsconfig.CreateConfig(details, false) 297 } 298 299 return nil, nil 300} 301 302// get returns a connection with given uri if it exists and also updates lastTimeAccess, otherwise returns nil. 303func (c *ConnManager) get(uri uri.URI) *PGConn { 304 c.connMutex.Lock() 305 defer c.connMutex.Unlock() 306 307 if conn, ok := c.connections[uri]; ok { 308 conn.updateAccessTime() 309 return conn 310 } 311 312 return nil 313} 314 315// GetConnection returns an existing connection or creates a new one. 316func (c *ConnManager) GetConnection(uri uri.URI, details tlsconfig.Details) (conn *PGConn, err error) { 317 c.Lock() 318 defer c.Unlock() 319 320 conn = c.get(uri) 321 322 if conn == nil { 323 conn, err = c.create(uri, details) 324 } 325 326 if err != nil { 327 err = zbxerr.ErrorConnectionFailed.Wrap(err) 328 } 329 330 return 331} 332