1package pgproto3 2 3import ( 4 "encoding/binary" 5 "io" 6 7 "github.com/jackc/pgx/chunkreader" 8 "github.com/pkg/errors" 9) 10 11type Backend struct { 12 cr *chunkreader.ChunkReader 13 w io.Writer 14 15 // Frontend message flyweights 16 bind Bind 17 _close Close 18 describe Describe 19 execute Execute 20 flush Flush 21 parse Parse 22 passwordMessage PasswordMessage 23 query Query 24 startupMessage StartupMessage 25 sync Sync 26 terminate Terminate 27 28 bodyLen int 29 msgType byte 30 partialMsg bool 31} 32 33func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { 34 cr := chunkreader.NewChunkReader(r) 35 return &Backend{cr: cr, w: w}, nil 36} 37 38func (b *Backend) Send(msg BackendMessage) error { 39 _, err := b.w.Write(msg.Encode(nil)) 40 return err 41} 42 43func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { 44 buf, err := b.cr.Next(4) 45 if err != nil { 46 return nil, err 47 } 48 msgSize := int(binary.BigEndian.Uint32(buf) - 4) 49 50 buf, err = b.cr.Next(msgSize) 51 if err != nil { 52 return nil, err 53 } 54 55 err = b.startupMessage.Decode(buf) 56 if err != nil { 57 return nil, err 58 } 59 60 return &b.startupMessage, nil 61} 62 63func (b *Backend) Receive() (FrontendMessage, error) { 64 if !b.partialMsg { 65 header, err := b.cr.Next(5) 66 if err != nil { 67 return nil, err 68 } 69 70 b.msgType = header[0] 71 b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 72 b.partialMsg = true 73 } 74 75 var msg FrontendMessage 76 switch b.msgType { 77 case 'B': 78 msg = &b.bind 79 case 'C': 80 msg = &b._close 81 case 'D': 82 msg = &b.describe 83 case 'E': 84 msg = &b.execute 85 case 'H': 86 msg = &b.flush 87 case 'P': 88 msg = &b.parse 89 case 'p': 90 msg = &b.passwordMessage 91 case 'Q': 92 msg = &b.query 93 case 'S': 94 msg = &b.sync 95 case 'X': 96 msg = &b.terminate 97 default: 98 return nil, errors.Errorf("unknown message type: %c", b.msgType) 99 } 100 101 msgBody, err := b.cr.Next(b.bodyLen) 102 if err != nil { 103 return nil, err 104 } 105 106 b.partialMsg = false 107 108 err = msg.Decode(msgBody) 109 return msg, err 110} 111