1package pgx 2 3import ( 4 "context" 5 "encoding/binary" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/pkg/errors" 11 12 "github.com/jackc/pgx/pgio" 13 "github.com/jackc/pgx/pgproto3" 14 "github.com/jackc/pgx/pgtype" 15) 16 17const ( 18 copyBothResponse = 'W' 19 walData = 'w' 20 senderKeepalive = 'k' 21 standbyStatusUpdate = 'r' 22 initialReplicationResponseTimeout = 5 * time.Second 23) 24 25var epochNano int64 26 27func init() { 28 epochNano = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixNano() 29} 30 31// Format the given 64bit LSN value into the XXX/XXX format, 32// which is the format reported by postgres. 33func FormatLSN(lsn uint64) string { 34 return fmt.Sprintf("%X/%X", uint32(lsn>>32), uint32(lsn)) 35} 36 37// Parse the given XXX/XXX format LSN as reported by postgres, 38// into a 64 bit integer as used internally by the wire procotols 39func ParseLSN(lsn string) (outputLsn uint64, err error) { 40 var upperHalf uint64 41 var lowerHalf uint64 42 var nparsed int 43 nparsed, err = fmt.Sscanf(lsn, "%X/%X", &upperHalf, &lowerHalf) 44 if err != nil { 45 return 46 } 47 48 if nparsed != 2 { 49 err = errors.New(fmt.Sprintf("Failed to parsed LSN: %s", lsn)) 50 return 51 } 52 53 outputLsn = (upperHalf << 32) + lowerHalf 54 return 55} 56 57// The WAL message contains WAL payload entry data 58type WalMessage struct { 59 // The WAL start position of this data. This 60 // is the WAL position we need to track. 61 WalStart uint64 62 // The server wal end and server time are 63 // documented to track the end position and current 64 // time of the server, both of which appear to be 65 // unimplemented in pg 9.5. 66 ServerWalEnd uint64 67 ServerTime uint64 68 // The WAL data is the raw unparsed binary WAL entry. 69 // The contents of this are determined by the output 70 // logical encoding plugin. 71 WalData []byte 72} 73 74func (w *WalMessage) Time() time.Time { 75 return time.Unix(0, (int64(w.ServerTime)*1000)+epochNano) 76} 77 78func (w *WalMessage) ByteLag() uint64 { 79 return (w.ServerWalEnd - w.WalStart) 80} 81 82func (w *WalMessage) String() string { 83 return fmt.Sprintf("Wal: %s Time: %s Lag: %d", FormatLSN(w.WalStart), w.Time(), w.ByteLag()) 84} 85 86// The server heartbeat is sent periodically from the server, 87// including server status, and a reply request field 88type ServerHeartbeat struct { 89 // The current max wal position on the server, 90 // used for lag tracking 91 ServerWalEnd uint64 92 // The server time, in microseconds since jan 1 2000 93 ServerTime uint64 94 // If 1, the server is requesting a standby status message 95 // to be sent immediately. 96 ReplyRequested byte 97} 98 99func (s *ServerHeartbeat) Time() time.Time { 100 return time.Unix(0, (int64(s.ServerTime)*1000)+epochNano) 101} 102 103func (s *ServerHeartbeat) String() string { 104 return fmt.Sprintf("WalEnd: %s ReplyRequested: %d T: %s", FormatLSN(s.ServerWalEnd), s.ReplyRequested, s.Time()) 105} 106 107// The replication message wraps all possible messages from the 108// server received during replication. At most one of the wal message 109// or server heartbeat will be non-nil 110type ReplicationMessage struct { 111 WalMessage *WalMessage 112 ServerHeartbeat *ServerHeartbeat 113} 114 115// The standby status is the client side heartbeat sent to the postgresql 116// server to track the client wal positions. For practical purposes, 117// all wal positions are typically set to the same value. 118type StandbyStatus struct { 119 // The WAL position that's been locally written 120 WalWritePosition uint64 121 // The WAL position that's been locally flushed 122 WalFlushPosition uint64 123 // The WAL position that's been locally applied 124 WalApplyPosition uint64 125 // The client time in microseconds since jan 1 2000 126 ClientTime uint64 127 // If 1, requests the server to immediately send a 128 // server heartbeat 129 ReplyRequested byte 130} 131 132// Create a standby status struct, which sets all the WAL positions 133// to the given wal position, and the client time to the current time. 134// The wal positions are, in order: 135// WalFlushPosition 136// WalApplyPosition 137// WalWritePosition 138// 139// If only one position is provided, it will be used as the value for all 3 140// status fields. Note you must provide either 1 wal position, or all 3 141// in order to initialize the standby status. 142func NewStandbyStatus(walPositions ...uint64) (status *StandbyStatus, err error) { 143 if len(walPositions) == 1 { 144 status = new(StandbyStatus) 145 status.WalFlushPosition = walPositions[0] 146 status.WalApplyPosition = walPositions[0] 147 status.WalWritePosition = walPositions[0] 148 } else if len(walPositions) == 3 { 149 status = new(StandbyStatus) 150 status.WalFlushPosition = walPositions[0] 151 status.WalApplyPosition = walPositions[1] 152 status.WalWritePosition = walPositions[2] 153 } else { 154 err = errors.New(fmt.Sprintf("Invalid number of wal positions provided, need 1 or 3, got %d", len(walPositions))) 155 return 156 } 157 status.ClientTime = uint64((time.Now().UnixNano() - epochNano) / 1000) 158 return 159} 160 161func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { 162 if config.RuntimeParams == nil { 163 config.RuntimeParams = make(map[string]string) 164 } 165 config.RuntimeParams["replication"] = "database" 166 167 c, err := Connect(config) 168 if err != nil { 169 return 170 } 171 return &ReplicationConn{c: c}, nil 172} 173 174type ReplicationConn struct { 175 c *Conn 176} 177 178// Send standby status to the server, which both acts as a keepalive 179// message to the server, as well as carries the WAL position of the 180// client, which then updates the server's replication slot position. 181func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { 182 buf := rc.c.wbuf 183 buf = append(buf, copyData) 184 sp := len(buf) 185 buf = pgio.AppendInt32(buf, -1) 186 187 buf = append(buf, standbyStatusUpdate) 188 buf = pgio.AppendInt64(buf, int64(k.WalWritePosition)) 189 buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition)) 190 buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition)) 191 buf = pgio.AppendInt64(buf, int64(k.ClientTime)) 192 buf = append(buf, k.ReplyRequested) 193 194 pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) 195 196 _, err = rc.c.conn.Write(buf) 197 if err != nil { 198 rc.c.die(err) 199 } 200 201 return 202} 203 204func (rc *ReplicationConn) Close() error { 205 return rc.c.Close() 206} 207 208func (rc *ReplicationConn) IsAlive() bool { 209 return rc.c.IsAlive() 210} 211 212func (rc *ReplicationConn) CauseOfDeath() error { 213 return rc.c.CauseOfDeath() 214} 215 216func (rc *ReplicationConn) GetConnInfo() *pgtype.ConnInfo { 217 return rc.c.ConnInfo 218} 219 220func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { 221 msg, err := rc.c.rxMsg() 222 if err != nil { 223 return 224 } 225 226 switch msg := msg.(type) { 227 case *pgproto3.NoticeResponse: 228 pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) 229 if rc.c.shouldLog(LogLevelInfo) { 230 rc.c.log(LogLevelInfo, pgError.Error(), nil) 231 } 232 case *pgproto3.ErrorResponse: 233 err = rc.c.rxErrorResponse(msg) 234 if rc.c.shouldLog(LogLevelError) { 235 rc.c.log(LogLevelError, err.Error(), nil) 236 } 237 return 238 case *pgproto3.CopyBothResponse: 239 // This is the tail end of the replication process start, 240 // and can be safely ignored 241 return 242 case *pgproto3.CopyData: 243 msgType := msg.Data[0] 244 rp := 1 245 246 switch msgType { 247 case walData: 248 walStart := binary.BigEndian.Uint64(msg.Data[rp:]) 249 rp += 8 250 serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) 251 rp += 8 252 serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) 253 rp += 8 254 walData := msg.Data[rp:] 255 walMessage := WalMessage{WalStart: walStart, 256 ServerWalEnd: serverWalEnd, 257 ServerTime: serverTime, 258 WalData: walData, 259 } 260 261 return &ReplicationMessage{WalMessage: &walMessage}, nil 262 case senderKeepalive: 263 serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) 264 rp += 8 265 serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) 266 rp += 8 267 replyNow := msg.Data[rp] 268 rp += 1 269 h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow} 270 return &ReplicationMessage{ServerHeartbeat: h}, nil 271 default: 272 if rc.c.shouldLog(LogLevelError) { 273 rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType}) 274 } 275 } 276 default: 277 if rc.c.shouldLog(LogLevelError) { 278 rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg}) 279 } 280 } 281 return 282} 283 284// Wait for a single replication message. 285// 286// Properly using this requires some knowledge of the postgres replication mechanisms, 287// as the client can receive both WAL data (the ultimate payload) and server heartbeat 288// updates. The caller also must send standby status updates in order to keep the connection 289// alive and working. 290// 291// This returns the context error when there is no replication message before 292// the context is canceled. 293func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*ReplicationMessage, error) { 294 select { 295 case <-ctx.Done(): 296 return nil, ctx.Err() 297 default: 298 } 299 300 go func() { 301 select { 302 case <-ctx.Done(): 303 if err := rc.c.conn.SetDeadline(time.Now()); err != nil { 304 rc.Close() // Close connection if unable to set deadline 305 return 306 } 307 rc.c.closedChan <- ctx.Err() 308 case <-rc.c.doneChan: 309 } 310 }() 311 312 r, opErr := rc.readReplicationMessage() 313 314 var err error 315 select { 316 case err = <-rc.c.closedChan: 317 if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { 318 rc.Close() // Close connection if unable to disable deadline 319 return nil, err 320 } 321 322 if opErr == nil { 323 err = nil 324 } 325 case rc.c.doneChan <- struct{}{}: 326 err = opErr 327 } 328 329 return r, err 330} 331 332func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { 333 rc.c.lastActivityTime = time.Now() 334 335 rows := rc.c.getRows(sql, nil) 336 337 if err := rc.c.lock(); err != nil { 338 rows.fatal(err) 339 return rows, err 340 } 341 rows.unlockConn = true 342 343 err := rc.c.sendSimpleQuery(sql) 344 if err != nil { 345 rows.fatal(err) 346 } 347 348 msg, err := rc.c.rxMsg() 349 if err != nil { 350 return nil, err 351 } 352 353 switch msg := msg.(type) { 354 case *pgproto3.RowDescription: 355 rows.fields = rc.c.rxRowDescription(msg) 356 // We don't have c.PgTypes here because we're a replication 357 // connection. This means the field descriptions will have 358 // only OIDs. Not much we can do about this. 359 default: 360 if e := rc.c.processContextFreeMsg(msg); e != nil { 361 rows.fatal(e) 362 return rows, e 363 } 364 } 365 366 return rows, rows.err 367} 368 369// Execute the "IDENTIFY_SYSTEM" command as documented here: 370// https://www.postgresql.org/docs/9.5/static/protocol-replication.html 371// 372// This will return (if successful) a result set that has a single row 373// that contains the systemid, current timeline, xlogpos and database 374// name. 375// 376// NOTE: Because this is a replication mode connection, we don't have 377// type names, so the field descriptions in the result will have only 378// OIDs and no DataTypeName values 379func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { 380 return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM") 381} 382 383// Execute the "TIMELINE_HISTORY" command as documented here: 384// https://www.postgresql.org/docs/9.5/static/protocol-replication.html 385// 386// This will return (if successful) a result set that has a single row 387// that contains the filename of the history file and the content 388// of the history file. If called for timeline 1, typically this will 389// generate an error that the timeline history file does not exist. 390// 391// NOTE: Because this is a replication mode connection, we don't have 392// type names, so the field descriptions in the result will have only 393// OIDs and no DataTypeName values 394func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) { 395 return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline)) 396} 397 398// Start a replication connection, sending WAL data to the given replication 399// receiver. This function wraps a START_REPLICATION command as documented 400// here: 401// https://www.postgresql.org/docs/9.5/static/protocol-replication.html 402// 403// Once started, the client needs to invoke WaitForReplicationMessage() in order 404// to fetch the WAL and standby status. Also, it is the responsibility of the caller 405// to periodically send StandbyStatus messages to update the replication slot position. 406// 407// This function assumes that slotName has already been created. In order to omit the timeline argument 408// pass a -1 for the timeline to get the server default behavior. 409func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, timeline int64, pluginArguments ...string) (err error) { 410 queryString := fmt.Sprintf("START_REPLICATION SLOT %s LOGICAL %s", slotName, FormatLSN(startLsn)) 411 if timeline >= 0 { 412 timelineOption := fmt.Sprintf("TIMELINE %d", timeline) 413 pluginArguments = append(pluginArguments, timelineOption) 414 } 415 416 if len(pluginArguments) > 0 { 417 queryString += fmt.Sprintf(" ( %s )", strings.Join(pluginArguments, ", ")) 418 } 419 420 if err = rc.c.sendQuery(queryString); err != nil { 421 return 422 } 423 424 ctx, cancelFn := context.WithTimeout(context.Background(), initialReplicationResponseTimeout) 425 defer cancelFn() 426 427 // The first replication message that comes back here will be (in a success case) 428 // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has 429 // started. This call will either return nil, nil or if it returns an error 430 // that indicates the start replication command failed 431 var r *ReplicationMessage 432 r, err = rc.WaitForReplicationMessage(ctx) 433 if err != nil && r != nil { 434 if rc.c.shouldLog(LogLevelError) { 435 rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err}) 436 } 437 } 438 439 return 440} 441 442// Create the replication slot, using the given name and output plugin. 443func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) (err error) { 444 _, err = rc.c.Exec(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s NOEXPORT_SNAPSHOT", slotName, outputPlugin)) 445 return 446} 447 448// Create the replication slot, using the given name and output plugin, and return the consistent_point and snapshot_name values. 449func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string) (consistentPoint string, snapshotName string, err error) { 450 var dummy string 451 var rows *Rows 452 rows, err = rc.sendReplicationModeQuery(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s", slotName, outputPlugin)) 453 defer rows.Close() 454 for rows.Next() { 455 rows.Scan(&dummy, &consistentPoint, &snapshotName, &dummy) 456 } 457 return 458} 459 460// Drop the replication slot for the given name 461func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { 462 _, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) 463 return 464} 465