1package pgx
2
3import (
4	"context"
5	"encoding/binary"
6	"fmt"
7	"strings"
8	"time"
9
10	"github.com/pkg/errors"
11
12	"github.com/jackc/pgx/pgio"
13	"github.com/jackc/pgx/pgproto3"
14	"github.com/jackc/pgx/pgtype"
15)
16
// Single-byte message type identifiers used by the replication code in this
// file. walData and senderKeepalive are matched against the first byte of a
// received CopyData payload; standbyStatusUpdate is written as the first
// byte of the payload sent by SendStandbyStatus.
const (
	copyBothResponse    = 'W'
	walData             = 'w'
	senderKeepalive     = 'k'
	standbyStatusUpdate = 'r'
)

// initialReplicationResponseTimeout bounds how long StartReplication waits
// for the first message after issuing the START_REPLICATION command.
const initialReplicationResponseTimeout = 5 * time.Second
24
// epochNano is 2000-01-01 00:00:00 UTC expressed as nanoseconds since the
// Unix epoch. Replication timestamps in this file are microsecond offsets
// from that instant and are converted using this value.
var epochNano = time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC).UnixNano()
30
// FormatLSN renders the given 64bit LSN value in the XXX/XXX format reported
// by postgres: the upper 32 bits and the lower 32 bits, each printed as
// unpadded uppercase hexadecimal and separated by a slash.
func FormatLSN(lsn uint64) string {
	high := uint32(lsn >> 32)
	low := uint32(lsn & 0xFFFFFFFF)
	return fmt.Sprintf("%X/%X", high, low)
}
36
// ParseLSN parses an LSN in the XXX/XXX format reported by postgres into the
// 64 bit integer used internally by the wire protocols.
//
// Each side of the slash is a hexadecimal number that must fit in 32 bits:
// the left side becomes the upper 32 bits of the result and the right side
// the lower 32 bits. An error is returned when the input does not begin with
// two such numbers separated by a slash. Any text following the second
// number is ignored.
func ParseLSN(lsn string) (outputLsn uint64, err error) {
	// Scan into 32-bit targets so that an oversized half is rejected by the
	// scanner instead of silently overflowing out of, or into, the other half.
	var upperHalf, lowerHalf uint32

	// Sscanf reports an error whenever fewer values than arguments were
	// parsed, so no separate count check is required.
	if _, err = fmt.Sscanf(lsn, "%X/%X", &upperHalf, &lowerHalf); err != nil {
		return 0, fmt.Errorf("failed to parse LSN %q: %v", lsn, err)
	}

	return uint64(upperHalf)<<32 | uint64(lowerHalf), nil
}
56
// WalMessage contains WAL payload entry data.
type WalMessage struct {
	// WalStart is the WAL start position of this data. This
	// is the WAL position we need to track.
	WalStart uint64
	// ServerWalEnd and ServerTime are documented to track the
	// end position and current time of the server, both of
	// which appear to be unimplemented in pg 9.5. The Time
	// method interprets ServerTime as microseconds since
	// 2000-01-01 UTC.
	ServerWalEnd uint64
	ServerTime   uint64
	// WalData is the raw unparsed binary WAL entry.
	// The contents of this are determined by the output
	// logical encoding plugin.
	WalData []byte
}
73
74func (w *WalMessage) Time() time.Time {
75	return time.Unix(0, (int64(w.ServerTime)*1000)+epochNano)
76}
77
78func (w *WalMessage) ByteLag() uint64 {
79	return (w.ServerWalEnd - w.WalStart)
80}
81
82func (w *WalMessage) String() string {
83	return fmt.Sprintf("Wal: %s Time: %s Lag: %d", FormatLSN(w.WalStart), w.Time(), w.ByteLag())
84}
85
// ServerHeartbeat is sent periodically from the server,
// including server status, and a reply request field.
type ServerHeartbeat struct {
	// ServerWalEnd is the current max wal position on the server,
	// used for lag tracking.
	ServerWalEnd uint64
	// ServerTime is the server time, in microseconds since jan 1 2000.
	ServerTime uint64
	// ReplyRequested, if 1, means the server is requesting a standby
	// status message to be sent immediately.
	ReplyRequested byte
}
98
99func (s *ServerHeartbeat) Time() time.Time {
100	return time.Unix(0, (int64(s.ServerTime)*1000)+epochNano)
101}
102
103func (s *ServerHeartbeat) String() string {
104	return fmt.Sprintf("WalEnd: %s ReplyRequested: %d T: %s", FormatLSN(s.ServerWalEnd), s.ReplyRequested, s.Time())
105}
106
// ReplicationMessage wraps all possible messages from the
// server received during replication. At most one of the wal message
// or server heartbeat will be non-nil.
type ReplicationMessage struct {
	// WalMessage is set when the server sent WAL payload data.
	WalMessage *WalMessage
	// ServerHeartbeat is set when the server sent a keepalive.
	ServerHeartbeat *ServerHeartbeat
}
114
// StandbyStatus is the client side heartbeat sent to the postgresql
// server to track the client wal positions. For practical purposes,
// all wal positions are typically set to the same value.
type StandbyStatus struct {
	// WalWritePosition is the WAL position that's been locally written.
	WalWritePosition uint64
	// WalFlushPosition is the WAL position that's been locally flushed.
	WalFlushPosition uint64
	// WalApplyPosition is the WAL position that's been locally applied.
	WalApplyPosition uint64
	// ClientTime is the client time in microseconds since jan 1 2000.
	ClientTime uint64
	// ReplyRequested, if 1, requests the server to immediately send a
	// server heartbeat.
	ReplyRequested byte
}
131
132// Create a standby status struct, which sets all the WAL positions
133// to the given wal position, and the client time to the current time.
134// The wal positions are, in order:
135// WalFlushPosition
136// WalApplyPosition
137// WalWritePosition
138//
139// If only one position is provided, it will be used as the value for all 3
140// status fields. Note you must provide either 1 wal position, or all 3
141// in order to initialize the standby status.
142func NewStandbyStatus(walPositions ...uint64) (status *StandbyStatus, err error) {
143	if len(walPositions) == 1 {
144		status = new(StandbyStatus)
145		status.WalFlushPosition = walPositions[0]
146		status.WalApplyPosition = walPositions[0]
147		status.WalWritePosition = walPositions[0]
148	} else if len(walPositions) == 3 {
149		status = new(StandbyStatus)
150		status.WalFlushPosition = walPositions[0]
151		status.WalApplyPosition = walPositions[1]
152		status.WalWritePosition = walPositions[2]
153	} else {
154		err = errors.New(fmt.Sprintf("Invalid number of wal positions provided, need 1 or 3, got %d", len(walPositions)))
155		return
156	}
157	status.ClientTime = uint64((time.Now().UnixNano() - epochNano) / 1000)
158	return
159}
160
161func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) {
162	if config.RuntimeParams == nil {
163		config.RuntimeParams = make(map[string]string)
164	}
165	config.RuntimeParams["replication"] = "database"
166
167	c, err := Connect(config)
168	if err != nil {
169		return
170	}
171	return &ReplicationConn{c: c}, nil
172}
173
// ReplicationConn wraps a Conn that was opened by ReplicationConnect and
// exposes the replication-specific operations defined in this file.
type ReplicationConn struct {
	// c is the underlying connection every method delegates to.
	c *Conn
}
177
178// Send standby status to the server, which both acts as a keepalive
179// message to the server, as well as carries the WAL position of the
180// client, which then updates the server's replication slot position.
181func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
182	buf := rc.c.wbuf
183	buf = append(buf, copyData)
184	sp := len(buf)
185	buf = pgio.AppendInt32(buf, -1)
186
187	buf = append(buf, standbyStatusUpdate)
188	buf = pgio.AppendInt64(buf, int64(k.WalWritePosition))
189	buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition))
190	buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition))
191	buf = pgio.AppendInt64(buf, int64(k.ClientTime))
192	buf = append(buf, k.ReplyRequested)
193
194	pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
195
196	_, err = rc.c.conn.Write(buf)
197	if err != nil {
198		rc.c.die(err)
199	}
200
201	return
202}
203
// Close closes the underlying connection and returns the error, if any,
// reported by its Close method.
func (rc *ReplicationConn) Close() error {
	return rc.c.Close()
}
207
// IsAlive reports whether the underlying connection considers itself alive.
func (rc *ReplicationConn) IsAlive() bool {
	return rc.c.IsAlive()
}
211
// CauseOfDeath returns whatever the underlying connection reports as its
// cause of death.
func (rc *ReplicationConn) CauseOfDeath() error {
	return rc.c.CauseOfDeath()
}
215
// GetConnInfo returns the ConnInfo held by the underlying connection.
func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo {
	return rc.c.ConnInfo
}
219
220func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) {
221	msg, err := rc.c.rxMsg()
222	if err != nil {
223		return
224	}
225
226	switch msg := msg.(type) {
227	case *pgproto3.NoticeResponse:
228		pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg))
229		if rc.c.shouldLog(LogLevelInfo) {
230			rc.c.log(LogLevelInfo, pgError.Error(), nil)
231		}
232	case *pgproto3.ErrorResponse:
233		err = rc.c.rxErrorResponse(msg)
234		if rc.c.shouldLog(LogLevelError) {
235			rc.c.log(LogLevelError, err.Error(), nil)
236		}
237		return
238	case *pgproto3.CopyBothResponse:
239		// This is the tail end of the replication process start,
240		// and can be safely ignored
241		return
242	case *pgproto3.CopyData:
243		msgType := msg.Data[0]
244		rp := 1
245
246		switch msgType {
247		case walData:
248			walStart := binary.BigEndian.Uint64(msg.Data[rp:])
249			rp += 8
250			serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:])
251			rp += 8
252			serverTime := binary.BigEndian.Uint64(msg.Data[rp:])
253			rp += 8
254			walData := msg.Data[rp:]
255			walMessage := WalMessage{WalStart: walStart,
256				ServerWalEnd: serverWalEnd,
257				ServerTime:   serverTime,
258				WalData:      walData,
259			}
260
261			return &ReplicationMessage{WalMessage: &walMessage}, nil
262		case senderKeepalive:
263			serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:])
264			rp += 8
265			serverTime := binary.BigEndian.Uint64(msg.Data[rp:])
266			rp += 8
267			replyNow := msg.Data[rp]
268			rp += 1
269			h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
270			return &ReplicationMessage{ServerHeartbeat: h}, nil
271		default:
272			if rc.c.shouldLog(LogLevelError) {
273				rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType})
274			}
275		}
276	default:
277		if rc.c.shouldLog(LogLevelError) {
278			rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg})
279		}
280	}
281	return
282}
283
// WaitForReplicationMessage waits for a single replication message.
//
// Properly using this requires some knowledge of the postgres replication mechanisms,
// as the client can receive both WAL data (the ultimate payload) and server heartbeat
// updates. The caller also must send standby status updates in order to keep the connection
// alive and working.
//
// This returns the context error when there is no replication message before
// the context is canceled. The returned message may be nil even when the
// error is nil, because some server messages are consumed without producing
// a ReplicationMessage.
func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*ReplicationMessage, error) {
	// Fail fast when the context is already done; nothing is read in that case.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	// Watcher goroutine: it finishes either when the context is done or when
	// the read below completes and hands off through doneChan.
	go func() {
		select {
		case <-ctx.Done():
			// Move the connection deadline to "now", which is intended to
			// interrupt the read in progress.
			if err := rc.c.conn.SetDeadline(time.Now()); err != nil {
				rc.Close() // Close connection if unable to set deadline
				// NOTE(review): this path returns without sending on
				// closedChan; confirm the select after the read cannot
				// block forever when this happens.
				return
			}
			rc.c.closedChan <- ctx.Err()
		case <-rc.c.doneChan:
		}
	}()

	r, opErr := rc.readReplicationMessage()

	var err error
	select {
	case err = <-rc.c.closedChan:
		// The watcher saw the context end and changed the deadline; clear it
		// again so later operations on the connection are not affected.
		if err := rc.c.conn.SetDeadline(time.Time{}); err != nil {
			rc.Close() // Close connection if unable to disable deadline
			return nil, err
		}

		// The read finished successfully despite the cancellation, so the
		// message is returned without the context error.
		if opErr == nil {
			err = nil
		}
	case rc.c.doneChan <- struct{}{}:
		// The watcher is still waiting: release it and report the read result.
		err = opErr
	}

	return r, err
}
331
332func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
333	rc.c.lastActivityTime = time.Now()
334
335	rows := rc.c.getRows(sql, nil)
336
337	if err := rc.c.lock(); err != nil {
338		rows.fatal(err)
339		return rows, err
340	}
341	rows.unlockConn = true
342
343	err := rc.c.sendSimpleQuery(sql)
344	if err != nil {
345		rows.fatal(err)
346	}
347
348	msg, err := rc.c.rxMsg()
349	if err != nil {
350		return nil, err
351	}
352
353	switch msg := msg.(type) {
354	case *pgproto3.RowDescription:
355		rows.fields = rc.c.rxRowDescription(msg)
356		// We don't have c.PgTypes here because we're a replication
357		// connection. This means the field descriptions will have
358		// only OIDs. Not much we can do about this.
359	default:
360		if e := rc.c.processContextFreeMsg(msg); e != nil {
361			rows.fatal(e)
362			return rows, e
363		}
364	}
365
366	return rows, rows.err
367}
368
// IdentifySystem executes the "IDENTIFY_SYSTEM" command as documented here:
// https://www.postgresql.org/docs/9.5/static/protocol-replication.html
//
// This will return (if successful) a result set that has a single row
// that contains the systemid, current timeline, xlogpos and database
// name.
//
// NOTE: Because this is a replication mode connection, we don't have
// type names, so the field descriptions in the result will have only
// OIDs and no DataTypeName values.
func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) {
	return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM")
}
382
// TimelineHistory executes the "TIMELINE_HISTORY" command for the given
// timeline, as documented here:
// https://www.postgresql.org/docs/9.5/static/protocol-replication.html
//
// This will return (if successful) a result set that has a single row
// that contains the filename of the history file and the content
// of the history file. If called for timeline 1, typically this will
// generate an error that the timeline history file does not exist.
//
// NOTE: Because this is a replication mode connection, we don't have
// type names, so the field descriptions in the result will have only
// OIDs and no DataTypeName values.
func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) {
	return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline))
}
397
398// Start a replication connection, sending WAL data to the given replication
399// receiver. This function wraps a START_REPLICATION command as documented
400// here:
401// https://www.postgresql.org/docs/9.5/static/protocol-replication.html
402//
403// Once started, the client needs to invoke WaitForReplicationMessage() in order
404// to fetch the WAL and standby status. Also, it is the responsibility of the caller
405// to periodically send StandbyStatus messages to update the replication slot position.
406//
407// This function assumes that slotName has already been created. In order to omit the timeline argument
408// pass a -1 for the timeline to get the server default behavior.
409func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, timeline int64, pluginArguments ...string) (err error) {
410	queryString := fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s", slotName, FormatLSN(startLsn))
411	if timeline >= 0 {
412		timelineOption := fmt.Sprintf("TIMELINE %d", timeline)
413		pluginArguments = append(pluginArguments, timelineOption)
414	}
415
416	if len(pluginArguments) > 0 {
417		queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", "))
418	}
419
420	if err = rc.c.sendQuery(queryString); err != nil {
421		return
422	}
423
424	ctx, cancelFn := context.WithTimeout(context.Background(), initialReplicationResponseTimeout)
425	defer cancelFn()
426
427	// The first replication message that comes back here will be (in a success case)
428	// a empty CopyBoth that is (apparently) sent as the confirmation that the replication has
429	// started. This call will either return nil, nil or if it returns an error
430	// that indicates the start replication command failed
431	var r *ReplicationMessage
432	r, err = rc.WaitForReplicationMessage(ctx)
433	if err != nil && r != nil {
434		if rc.c.shouldLog(LogLevelError) {
435			rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err})
436		}
437	}
438
439	return
440}
441
442// Create the replication slot, using the given name and output plugin.
443func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) {
444	_, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin))
445	return
446}
447
448// Create the replication slot, using the given name and output plugin, and return the consistent_point and snapshot_name values.
449func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string) (consistentPoint string, snapshotName string, err error) {
450	var dummy string
451	var rows *Rows
452	rows, err = rc.sendReplicationModeQuery(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s", slotName, outputPlugin))
453	defer rows.Close()
454	for rows.Next() {
455		rows.Scan(&dummy, &consistentPoint, &snapshotName, &dummy)
456	}
457	return
458}
459
460// Drop the replication slot for the given name
461func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) {
462	_, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName))
463	return
464}
465