1package pgproto3
2
3import (
4	"encoding/binary"
5	"io"
6
7	"github.com/jackc/pgx/chunkreader"
8	"github.com/pkg/errors"
9)
10
// Frontend is the client (frontend) side of the PostgreSQL wire protocol:
// Send writes FrontendMessages to the server and Receive reads and decodes
// BackendMessages from it.
//
// NOTE(review): nothing in this type synchronizes access, so a Frontend
// presumably must not be used from multiple goroutines at once — confirm
// with callers.
type Frontend struct {
	// cr wraps the reader given to NewFrontend; Receive pulls message headers
	// and bodies from it via Next. w receives the encoded bytes of every
	// message passed to Send.
	cr *chunkreader.ChunkReader
	w  io.Writer

	// Backend message flyweights: one preallocated value per backend message
	// type. Receive decodes into the matching field and returns a pointer to
	// it, so a returned message is reused by later Receive calls that decode
	// the same message type; callers that keep one across calls should copy it.
	authentication       Authentication
	backendKeyData       BackendKeyData
	bindComplete         BindComplete
	closeComplete        CloseComplete
	commandComplete      CommandComplete
	copyBothResponse     CopyBothResponse
	copyData             CopyData
	copyInResponse       CopyInResponse
	copyOutResponse      CopyOutResponse
	copyDone             CopyDone
	dataRow              DataRow
	emptyQueryResponse   EmptyQueryResponse
	errorResponse        ErrorResponse
	functionCallResponse FunctionCallResponse
	noData               NoData
	noticeResponse       NoticeResponse
	notificationResponse NotificationResponse
	parameterDescription ParameterDescription
	parameterStatus      ParameterStatus
	parseComplete        ParseComplete
	readyForQuery        ReadyForQuery
	rowDescription       RowDescription

	// State of the message currently being read. msgType and bodyLen come from
	// the 5-byte message header (bodyLen excludes the 4-byte length field
	// itself). partialMsg is true once the header has been consumed but the
	// body has not yet been read successfully, so a Receive that failed while
	// reading the body resumes with the body on the next call.
	bodyLen    int
	msgType    byte
	partialMsg bool
}
43
44func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
45	cr := chunkreader.NewChunkReader(r)
46	return &Frontend{cr: cr, w: w}, nil
47}
48
49func (b *Frontend) Send(msg FrontendMessage) error {
50	_, err := b.w.Write(msg.Encode(nil))
51	return err
52}
53
54func (b *Frontend) Receive() (BackendMessage, error) {
55	if !b.partialMsg {
56		header, err := b.cr.Next(5)
57		if err != nil {
58			return nil, err
59		}
60
61		b.msgType = header[0]
62		b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
63		b.partialMsg = true
64	}
65
66	var msg BackendMessage
67	switch b.msgType {
68	case '1':
69		msg = &b.parseComplete
70	case '2':
71		msg = &b.bindComplete
72	case '3':
73		msg = &b.closeComplete
74	case 'A':
75		msg = &b.notificationResponse
76	case 'c':
77		msg = &b.copyDone
78	case 'C':
79		msg = &b.commandComplete
80	case 'd':
81		msg = &b.copyData
82	case 'D':
83		msg = &b.dataRow
84	case 'E':
85		msg = &b.errorResponse
86	case 'G':
87		msg = &b.copyInResponse
88	case 'H':
89		msg = &b.copyOutResponse
90	case 'I':
91		msg = &b.emptyQueryResponse
92	case 'K':
93		msg = &b.backendKeyData
94	case 'n':
95		msg = &b.noData
96	case 'N':
97		msg = &b.noticeResponse
98	case 'R':
99		msg = &b.authentication
100	case 'S':
101		msg = &b.parameterStatus
102	case 't':
103		msg = &b.parameterDescription
104	case 'T':
105		msg = &b.rowDescription
106	case 'V':
107		msg = &b.functionCallResponse
108	case 'W':
109		msg = &b.copyBothResponse
110	case 'Z':
111		msg = &b.readyForQuery
112	default:
113		return nil, errors.Errorf("unknown message type: %c", b.msgType)
114	}
115
116	msgBody, err := b.cr.Next(b.bodyLen)
117	if err != nil {
118		return nil, err
119	}
120
121	b.partialMsg = false
122
123	err = msg.Decode(msgBody)
124	return msg, err
125}
126