1package pgproto3
2
3import (
4	"encoding/binary"
5	"io"
6
7	"github.com/jackc/pgx/chunkreader"
8	"github.com/pkg/errors"
9)
10
11type Backend struct {
12	cr *chunkreader.ChunkReader
13	w  io.Writer
14
15	// Frontend message flyweights
16	bind            Bind
17	_close          Close
18	describe        Describe
19	execute         Execute
20	flush           Flush
21	parse           Parse
22	passwordMessage PasswordMessage
23	query           Query
24	startupMessage  StartupMessage
25	sync            Sync
26	terminate       Terminate
27
28	bodyLen    int
29	msgType    byte
30	partialMsg bool
31}
32
// NewBackend returns a Backend that reads frontend messages from r (through
// a chunkreader.ChunkReader) and writes backend messages to w.
//
// The error result is currently always nil.
func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
	backend := &Backend{
		cr: chunkreader.NewChunkReader(r),
		w:  w,
	}
	return backend, nil
}
37
// Send encodes msg into a fresh buffer and writes it to the Backend's writer
// with a single Write call. Any error from the writer is returned as is.
func (b *Backend) Send(msg BackendMessage) error {
	encoded := msg.Encode(nil)
	if _, err := b.w.Write(encoded); err != nil {
		return err
	}
	return nil
}
42
43func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
44	buf, err := b.cr.Next(4)
45	if err != nil {
46		return nil, err
47	}
48	msgSize := int(binary.BigEndian.Uint32(buf) - 4)
49
50	buf, err = b.cr.Next(msgSize)
51	if err != nil {
52		return nil, err
53	}
54
55	err = b.startupMessage.Decode(buf)
56	if err != nil {
57		return nil, err
58	}
59
60	return &b.startupMessage, nil
61}
62
// Receive reads the next typed frontend message and returns its decoded form.
//
// Each message is a 1-byte type code followed by a 4-byte big-endian length.
// The length counts its own four bytes. The body comes after that. Once a
// header has been consumed, partialMsg stays set until the body has been
// read. So if reading the body fails, the next call retries the body read
// with the saved header instead of treating the body bytes as a new header.
// NOTE(review): this assumes the chunk reader keeps already-buffered bytes
// across a failed Next call.
//
// The returned message points at a flyweight field of b. It is reused by
// later calls that receive the same message type. An unknown type code or a
// malformed length yields an error.
func (b *Backend) Receive() (FrontendMessage, error) {
	if !b.partialMsg {
		header, err := b.cr.Next(5)
		if err != nil {
			return nil, err
		}

		msgLen := binary.BigEndian.Uint32(header[1:])
		bodyLen := int(msgLen - 4)
		// A declared length below 4 is malformed, and is checked in uint32 so
		// the subtraction above cannot wrap unnoticed. bodyLen < 0 can only
		// occur where int is 32 bits wide and the declared length is huge.
		// Either way, never pass a negative size to the chunk reader.
		if msgLen < 4 || bodyLen < 0 {
			return nil, errors.Errorf("invalid message length: %d", msgLen)
		}

		b.msgType = header[0]
		b.bodyLen = bodyLen
		b.partialMsg = true
	}

	var msg FrontendMessage
	switch b.msgType {
	case 'B':
		msg = &b.bind
	case 'C':
		msg = &b._close
	case 'D':
		msg = &b.describe
	case 'E':
		msg = &b.execute
	case 'H':
		msg = &b.flush
	case 'P':
		msg = &b.parse
	case 'p':
		msg = &b.passwordMessage
	case 'Q':
		msg = &b.query
	case 'S':
		msg = &b.sync
	case 'X':
		msg = &b.terminate
	default:
		return nil, errors.Errorf("unknown message type: %c", b.msgType)
	}

	msgBody, err := b.cr.Next(b.bodyLen)
	if err != nil {
		return nil, err
	}

	// The header and body are both consumed, so the next call starts a new message.
	b.partialMsg = false

	err = msg.Decode(msgBody)
	return msg, err
}
111