1package pgxpool
2
3import (
4	"context"
5	"runtime"
6	"strconv"
7	"time"
8
9	"github.com/jackc/pgconn"
10	"github.com/jackc/pgx/v4"
11	"github.com/jackc/puddle"
12	errors "golang.org/x/xerrors"
13)
14
// Baseline pool settings. ParseConfig falls back to these when the connection string does not supply the
// corresponding pool_* parameter.
var (
	defaultMaxConns          int32         = 4
	defaultMinConns          int32         = 0
	defaultMaxConnLifetime   time.Duration = time.Hour
	defaultMaxConnIdleTime   time.Duration = 30 * time.Minute
	defaultHealthCheckPeriod time.Duration = time.Minute
)
20
// connResource is the value stored in each puddle resource created by the pool. It bundles the underlying
// pgx connection with slabs of wrapper values that getConn, getPoolRow, and getPoolRows hand out one at a
// time from the end of the respective slice.
type connResource struct {
	// conn is the underlying database connection.
	conn      *pgx.Conn
	// conns holds the not-yet-handed-out Conn wrappers.
	conns     []Conn
	// poolRows holds the not-yet-handed-out poolRow wrappers.
	poolRows  []poolRow
	// poolRowss holds the not-yet-handed-out poolRows wrappers.
	poolRowss []poolRows
}
27
28func (cr *connResource) getConn(p *Pool, res *puddle.Resource) *Conn {
29	if len(cr.conns) == 0 {
30		cr.conns = make([]Conn, 128)
31	}
32
33	c := &cr.conns[len(cr.conns)-1]
34	cr.conns = cr.conns[0 : len(cr.conns)-1]
35
36	c.res = res
37	c.p = p
38
39	return c
40}
41
42func (cr *connResource) getPoolRow(c *Conn, r pgx.Row) *poolRow {
43	if len(cr.poolRows) == 0 {
44		cr.poolRows = make([]poolRow, 128)
45	}
46
47	pr := &cr.poolRows[len(cr.poolRows)-1]
48	cr.poolRows = cr.poolRows[0 : len(cr.poolRows)-1]
49
50	pr.c = c
51	pr.r = r
52
53	return pr
54}
55
56func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows {
57	if len(cr.poolRowss) == 0 {
58		cr.poolRowss = make([]poolRows, 128)
59	}
60
61	pr := &cr.poolRowss[len(cr.poolRowss)-1]
62	cr.poolRowss = cr.poolRowss[0 : len(cr.poolRowss)-1]
63
64	pr.c = c
65	pr.r = r
66
67	return pr
68}
69
// Pool is a connection pool built on top of a puddle.Pool. Its settings are copied from the Config passed to
// ConnectConfig when the pool is created.
type Pool struct {
	// p is the underlying puddle pool whose resources are *connResource values.
	p                 *puddle.Pool
	// config is the Config the pool was created with; Config() returns a copy of it.
	config            *Config
	// afterConnect, if non-nil, runs on every newly established connection before it joins the pool.
	afterConnect      func(context.Context, *pgx.Conn) error
	// beforeAcquire, if non-nil, must return true for a connection to be handed out; otherwise it is destroyed.
	beforeAcquire     func(context.Context, *pgx.Conn) bool
	// afterRelease is copied from Config.AfterRelease; it is not invoked in this part of the file.
	afterRelease      func(*pgx.Conn) bool
	// minConns is the connection count the health check tries to restore.
	minConns          int32
	// maxConnLifetime is the age after which the health check destroys an idle connection.
	maxConnLifetime   time.Duration
	// maxConnIdleTime is the idle duration after which the health check destroys a connection.
	maxConnIdleTime   time.Duration
	// healthCheckPeriod is the interval of the background health check ticker.
	healthCheckPeriod time.Duration
	// closeChan is closed by Close to stop the background health check goroutine.
	closeChan         chan struct{}
}
82
// Config is the configuration struct for creating a pool. It must be created by ParseConfig and then it can be
// modified. A manually initialized Config will cause ConnectConfig to panic.
type Config struct {
	ConnConfig *pgx.ConnConfig

	// AfterConnect is called after a connection is established, but before it is added to the pool.
	AfterConnect func(context.Context, *pgx.Conn) error

	// BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the
	// acquisition or false to indicate that the connection should be destroyed and a different connection should be
	// acquired.
	BeforeAcquire func(context.Context, *pgx.Conn) bool

	// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
	// return the connection to the pool or false to destroy the connection.
	AfterRelease func(*pgx.Conn) bool

	// MaxConnLifetime is the duration since creation after which a connection will be automatically closed.
	MaxConnLifetime time.Duration

	// MaxConnIdleTime is the duration after which an idle connection will be automatically closed by the health check.
	MaxConnIdleTime time.Duration

	// MaxConns is the maximum size of the pool.
	MaxConns int32

	// MinConns is the minimum size of the pool. The health check will increase the number of connections to this
	// amount if it had dropped below.
	MinConns int32

	// HealthCheckPeriod is the duration between checks of the health of idle connections.
	HealthCheckPeriod time.Duration

	// If set to true, pool doesn't do any I/O operation on initialization.
	// And connects to the server only when the pool starts to be used.
	// The default is false.
	LazyConnect bool

	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
123
124// Copy returns a deep copy of the config that is safe to use and modify.
125// The only exception is the tls.Config:
126// according to the tls.Config docs it must not be modified after creation.
127func (c *Config) Copy() *Config {
128	newConfig := new(Config)
129	*newConfig = *c
130	newConfig.ConnConfig = c.ConnConfig.Copy()
131	return newConfig
132}
133
134func (c *Config) ConnString() string { return c.ConnConfig.ConnString() }
135
136// Connect creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial
137// connection. See ParseConfig for information on connString format.
138func Connect(ctx context.Context, connString string) (*Pool, error) {
139	config, err := ParseConfig(connString)
140	if err != nil {
141		return nil, err
142	}
143
144	return ConnectConfig(ctx, config)
145}
146
147// ConnectConfig creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial
148// connection. config must have been created by ParseConfig.
149func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
150	// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
151	// zero values.
152	if !config.createdByParseConfig {
153		panic("config must be created by ParseConfig")
154	}
155
156	p := &Pool{
157		config:            config,
158		afterConnect:      config.AfterConnect,
159		beforeAcquire:     config.BeforeAcquire,
160		afterRelease:      config.AfterRelease,
161		minConns:          config.MinConns,
162		maxConnLifetime:   config.MaxConnLifetime,
163		maxConnIdleTime:   config.MaxConnIdleTime,
164		healthCheckPeriod: config.HealthCheckPeriod,
165		closeChan:         make(chan struct{}),
166	}
167
168	p.p = puddle.NewPool(
169		func(ctx context.Context) (interface{}, error) {
170			conn, err := pgx.ConnectConfig(ctx, config.ConnConfig)
171			if err != nil {
172				return nil, err
173			}
174
175			if p.afterConnect != nil {
176				err = p.afterConnect(ctx, conn)
177				if err != nil {
178					conn.Close(ctx)
179					return nil, err
180				}
181			}
182
183			cr := &connResource{
184				conn:      conn,
185				conns:     make([]Conn, 64),
186				poolRows:  make([]poolRow, 64),
187				poolRowss: make([]poolRows, 64),
188			}
189
190			return cr, nil
191		},
192		func(value interface{}) {
193			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
194			value.(*connResource).conn.Close(ctx)
195			cancel()
196		},
197		config.MaxConns,
198	)
199
200	go p.backgroundHealthCheck()
201
202	if !config.LazyConnect {
203		// Initially establish one connection
204		res, err := p.p.Acquire(ctx)
205		if err != nil {
206			p.Close()
207			return nil, err
208		}
209		res.Release()
210	}
211
212	return p, nil
213}
214
215// ParseConfig builds a Config from connString. It parses connString with the same behavior as pgx.ParseConfig with the
216// addition of the following variables:
217//
218// pool_max_conns: integer greater than 0
219// pool_min_conns: integer 0 or greater
220// pool_max_conn_lifetime: duration string
221// pool_max_conn_idle_time: duration string
222// pool_health_check_period: duration string
223//
224// See Config for definitions of these arguments.
225//
226//   # Example DSN
227//   user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10
228//
229//   # Example URL
230//   postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10
231func ParseConfig(connString string) (*Config, error) {
232	connConfig, err := pgx.ParseConfig(connString)
233	if err != nil {
234		return nil, err
235	}
236
237	config := &Config{
238		ConnConfig:           connConfig,
239		createdByParseConfig: true,
240	}
241
242	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conns"]; ok {
243		delete(connConfig.Config.RuntimeParams, "pool_max_conns")
244		n, err := strconv.ParseInt(s, 10, 32)
245		if err != nil {
246			return nil, errors.Errorf("cannot parse pool_max_conns: %w", err)
247		}
248		if n < 1 {
249			return nil, errors.Errorf("pool_max_conns too small: %d", n)
250		}
251		config.MaxConns = int32(n)
252	} else {
253		config.MaxConns = defaultMaxConns
254		if numCPU := int32(runtime.NumCPU()); numCPU > config.MaxConns {
255			config.MaxConns = numCPU
256		}
257	}
258
259	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_conns"]; ok {
260		delete(connConfig.Config.RuntimeParams, "pool_min_conns")
261		n, err := strconv.ParseInt(s, 10, 32)
262		if err != nil {
263			return nil, errors.Errorf("cannot parse pool_min_conns: %w", err)
264		}
265		config.MinConns = int32(n)
266	} else {
267		config.MinConns = defaultMinConns
268	}
269
270	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok {
271		delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
272		d, err := time.ParseDuration(s)
273		if err != nil {
274			return nil, errors.Errorf("invalid pool_max_conn_lifetime: %w", err)
275		}
276		config.MaxConnLifetime = d
277	} else {
278		config.MaxConnLifetime = defaultMaxConnLifetime
279	}
280
281	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_idle_time"]; ok {
282		delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time")
283		d, err := time.ParseDuration(s)
284		if err != nil {
285			return nil, errors.Errorf("invalid pool_max_conn_idle_time: %w", err)
286		}
287		config.MaxConnIdleTime = d
288	} else {
289		config.MaxConnIdleTime = defaultMaxConnIdleTime
290	}
291
292	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_health_check_period"]; ok {
293		delete(connConfig.Config.RuntimeParams, "pool_health_check_period")
294		d, err := time.ParseDuration(s)
295		if err != nil {
296			return nil, errors.Errorf("invalid pool_health_check_period: %w", err)
297		}
298		config.HealthCheckPeriod = d
299	} else {
300		config.HealthCheckPeriod = defaultHealthCheckPeriod
301	}
302
303	return config, nil
304}
305
// Close closes all connections in the pool and rejects future Acquire calls. Blocks until all connections are returned
// to pool and closed.
//
// Close must be called at most once: closeChan is closed unconditionally, and closing an already closed
// channel panics.
func (p *Pool) Close() {
	// Signal the background health check goroutine to stop before shutting down the underlying pool.
	close(p.closeChan)
	p.p.Close()
}
312
313func (p *Pool) backgroundHealthCheck() {
314	ticker := time.NewTicker(p.healthCheckPeriod)
315
316	for {
317		select {
318		case <-p.closeChan:
319			ticker.Stop()
320			return
321		case <-ticker.C:
322			p.checkIdleConnsHealth()
323			p.checkMinConns()
324		}
325	}
326}
327
328func (p *Pool) checkIdleConnsHealth() {
329	resources := p.p.AcquireAllIdle()
330
331	now := time.Now()
332	for _, res := range resources {
333		if now.Sub(res.CreationTime()) > p.maxConnLifetime {
334			res.Destroy()
335		} else if res.IdleDuration() > p.maxConnIdleTime {
336			res.Destroy()
337		} else {
338			res.ReleaseUnused()
339		}
340	}
341}
342
343func (p *Pool) checkMinConns() {
344	for i := p.minConns - p.Stat().TotalConns(); i > 0; i-- {
345		go func() {
346			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
347			defer cancel()
348			p.p.CreateResource(ctx)
349		}()
350	}
351}
352
353func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
354	for {
355		res, err := p.p.Acquire(ctx)
356		if err != nil {
357			return nil, err
358		}
359
360		cr := res.Value().(*connResource)
361		if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
362			return cr.getConn(p, res), nil
363		}
364
365		res.Destroy()
366	}
367}
368
369// AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and
370// keep-alive functionality. It does not update pool statistics.
371func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn {
372	resources := p.p.AcquireAllIdle()
373	conns := make([]*Conn, 0, len(resources))
374	for _, res := range resources {
375		cr := res.Value().(*connResource)
376		if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
377			conns = append(conns, cr.getConn(p, res))
378		} else {
379			res.Destroy()
380		}
381	}
382
383	return conns
384}
385
386// Config returns a copy of config that was used to initialize this pool.
387func (p *Pool) Config() *Config { return p.config.Copy() }
388
389func (p *Pool) Stat() *Stat {
390	return &Stat{s: p.p.Stat()}
391}
392
393func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
394	c, err := p.Acquire(ctx)
395	if err != nil {
396		return nil, err
397	}
398	defer c.Release()
399
400	return c.Exec(ctx, sql, arguments...)
401}
402
403func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
404	c, err := p.Acquire(ctx)
405	if err != nil {
406		return errRows{err: err}, err
407	}
408
409	rows, err := c.Query(ctx, sql, args...)
410	if err != nil {
411		c.Release()
412		return errRows{err: err}, err
413	}
414
415	return c.getPoolRows(rows), nil
416}
417
418func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
419	c, err := p.Acquire(ctx)
420	if err != nil {
421		return errRow{err: err}
422	}
423
424	row := c.QueryRow(ctx, sql, args...)
425	return c.getPoolRow(row)
426}
427
428func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
429	c, err := p.Acquire(ctx)
430	if err != nil {
431		return errBatchResults{err: err}
432	}
433
434	br := c.SendBatch(ctx, b)
435	return &poolBatchResults{br: br, c: c}
436}
437
438func (p *Pool) Begin(ctx context.Context) (pgx.Tx, error) {
439	return p.BeginTx(ctx, pgx.TxOptions{})
440}
441func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
442	c, err := p.Acquire(ctx)
443	if err != nil {
444		return nil, err
445	}
446
447	t, err := c.BeginTx(ctx, txOptions)
448	if err != nil {
449		c.Release()
450		return nil, err
451	}
452
453	return &Tx{t: t, c: c}, err
454}
455
456func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
457	c, err := p.Acquire(ctx)
458	if err != nil {
459		return 0, err
460	}
461	defer c.Release()
462
463	return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
464}
465