1package pgproto3
2
3import (
4	"encoding/binary"
5	"fmt"
6	"io"
7)
8
// Backend acts as a server for the PostgreSQL wire protocol version 3.
//
// A Backend reads frontend messages from a ChunkReader (ReceiveStartupMessage,
// Receive) and writes encoded backend messages to an io.Writer (Send). Receive
// assigns to the struct's fields without any locking.
// NOTE(review): this suggests a Backend must not be used from several
// goroutines at once — confirm with callers.
type Backend struct {
	// cr is the source of incoming frontend bytes; w receives whatever Send
	// encodes.
	cr ChunkReader
	w  io.Writer

	// Frontend message flyweights
	//
	// Receive and ReceiveStartupMessage decode into these values and return
	// pointers to them rather than allocating a message per call, which is why
	// a returned message is only valid until the next receive.
	bind            Bind
	cancelRequest   CancelRequest
	_close          Close
	copyFail        CopyFail
	describe        Describe
	execute         Execute
	flush           Flush
	gssEncRequest   GSSEncRequest
	parse           Parse
	passwordMessage PasswordMessage
	query           Query
	sslRequest      SSLRequest
	startupMessage  StartupMessage
	sync            Sync
	terminate       Terminate

	// State Receive carries between calls. After a 5-byte message header has
	// been read, msgType holds the type byte, bodyLen the body length (length
	// field minus the 4 bytes of the length itself), and partialMsg is true
	// until the body has been read. A call that fails while reading the body
	// can therefore be retried without re-reading the header.
	bodyLen    int
	msgType    byte
	partialMsg bool
}
35
36// NewBackend creates a new Backend.
37func NewBackend(cr ChunkReader, w io.Writer) *Backend {
38	return &Backend{cr: cr, w: w}
39}
40
41// Send sends a message to the frontend.
42func (b *Backend) Send(msg BackendMessage) error {
43	_, err := b.w.Write(msg.Encode(nil))
44	return err
45}
46
47// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
48// because the initial connection message is "special" and does not include the message type as the first byte. This
49// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
50func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
51	buf, err := b.cr.Next(4)
52	if err != nil {
53		return nil, err
54	}
55	msgSize := int(binary.BigEndian.Uint32(buf) - 4)
56
57	buf, err = b.cr.Next(msgSize)
58	if err != nil {
59		return nil, err
60	}
61
62	code := binary.BigEndian.Uint32(buf)
63
64	switch code {
65	case ProtocolVersionNumber:
66		err = b.startupMessage.Decode(buf)
67		if err != nil {
68			return nil, err
69		}
70		return &b.startupMessage, nil
71	case sslRequestNumber:
72		err = b.sslRequest.Decode(buf)
73		if err != nil {
74			return nil, err
75		}
76		return &b.sslRequest, nil
77	case cancelRequestCode:
78		err = b.cancelRequest.Decode(buf)
79		if err != nil {
80			return nil, err
81		}
82		return &b.cancelRequest, nil
83	case gssEncReqNumber:
84		err = b.gssEncRequest.Decode(buf)
85		if err != nil {
86			return nil, err
87		}
88		return &b.gssEncRequest, nil
89	default:
90		return nil, fmt.Errorf("unknown startup message code: %d", code)
91	}
92}
93
94// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
95func (b *Backend) Receive() (FrontendMessage, error) {
96	if !b.partialMsg {
97		header, err := b.cr.Next(5)
98		if err != nil {
99			return nil, err
100		}
101
102		b.msgType = header[0]
103		b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
104		b.partialMsg = true
105	}
106
107	var msg FrontendMessage
108	switch b.msgType {
109	case 'B':
110		msg = &b.bind
111	case 'C':
112		msg = &b._close
113	case 'D':
114		msg = &b.describe
115	case 'E':
116		msg = &b.execute
117	case 'f':
118		msg = &b.copyFail
119	case 'H':
120		msg = &b.flush
121	case 'P':
122		msg = &b.parse
123	case 'p':
124		msg = &b.passwordMessage
125	case 'Q':
126		msg = &b.query
127	case 'S':
128		msg = &b.sync
129	case 'X':
130		msg = &b.terminate
131	default:
132		return nil, fmt.Errorf("unknown message type: %c", b.msgType)
133	}
134
135	msgBody, err := b.cr.Next(b.bodyLen)
136	if err != nil {
137		return nil, err
138	}
139
140	b.partialMsg = false
141
142	err = msg.Decode(msgBody)
143	return msg, err
144}
145