1package pgproto3 2 3import ( 4 "encoding/binary" 5 "fmt" 6 "io" 7) 8 9// Backend acts as a server for the PostgreSQL wire protocol version 3. 10type Backend struct { 11 cr ChunkReader 12 w io.Writer 13 14 // Frontend message flyweights 15 bind Bind 16 cancelRequest CancelRequest 17 _close Close 18 copyFail CopyFail 19 describe Describe 20 execute Execute 21 flush Flush 22 gssEncRequest GSSEncRequest 23 parse Parse 24 passwordMessage PasswordMessage 25 query Query 26 sslRequest SSLRequest 27 startupMessage StartupMessage 28 sync Sync 29 terminate Terminate 30 31 bodyLen int 32 msgType byte 33 partialMsg bool 34} 35 36// NewBackend creates a new Backend. 37func NewBackend(cr ChunkReader, w io.Writer) *Backend { 38 return &Backend{cr: cr, w: w} 39} 40 41// Send sends a message to the frontend. 42func (b *Backend) Send(msg BackendMessage) error { 43 _, err := b.w.Write(msg.Encode(nil)) 44 return err 45} 46 47// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method 48// because the initial connection message is "special" and does not include the message type as the first byte. This 49// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. 50func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { 51 buf, err := b.cr.Next(4) 52 if err != nil { 53 return nil, err 54 } 55 msgSize := int(binary.BigEndian.Uint32(buf) - 4) 56 57 buf, err = b.cr.Next(msgSize) 58 if err != nil { 59 return nil, err 60 } 61 62 code := binary.BigEndian.Uint32(buf) 63 64 switch code { 65 case ProtocolVersionNumber: 66 err = b.startupMessage.Decode(buf) 67 if err != nil { 68 return nil, err 69 } 70 return &b.startupMessage, nil 71 case sslRequestNumber: 72 err = b.sslRequest.Decode(buf) 73 if err != nil { 74 return nil, err 75 } 76 return &b.sslRequest, nil 77 case cancelRequestCode: 78 err = b.cancelRequest.Decode(buf) 79 if err != nil { 80 return nil, err 81 } 82 return &b.cancelRequest, nil 83 case gssEncReqNumber: 84 err = b.gssEncRequest.Decode(buf) 85 if err != nil { 86 return nil, err 87 } 88 return &b.gssEncRequest, nil 89 default: 90 return nil, fmt.Errorf("unknown startup message code: %d", code) 91 } 92} 93 94// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive. 95func (b *Backend) Receive() (FrontendMessage, error) { 96 if !b.partialMsg { 97 header, err := b.cr.Next(5) 98 if err != nil { 99 return nil, err 100 } 101 102 b.msgType = header[0] 103 b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 104 b.partialMsg = true 105 } 106 107 var msg FrontendMessage 108 switch b.msgType { 109 case 'B': 110 msg = &b.bind 111 case 'C': 112 msg = &b._close 113 case 'D': 114 msg = &b.describe 115 case 'E': 116 msg = &b.execute 117 case 'f': 118 msg = &b.copyFail 119 case 'H': 120 msg = &b.flush 121 case 'P': 122 msg = &b.parse 123 case 'p': 124 msg = &b.passwordMessage 125 case 'Q': 126 msg = &b.query 127 case 'S': 128 msg = &b.sync 129 case 'X': 130 msg = &b.terminate 131 default: 132 return nil, fmt.Errorf("unknown message type: %c", b.msgType) 133 } 134 135 msgBody, err := b.cr.Next(b.bodyLen) 136 if err != nil { 137 return nil, err 138 } 139 140 b.partialMsg = false 141 142 err = msg.Decode(msgBody) 143 return msg, err 144} 145