1package pgproto3
2
3import (
4	"encoding/binary"
5	"fmt"
6	"io"
7)
8
// Backend acts as a server for the PostgreSQL wire protocol version 3.
//
// Frontend messages are read from cr (see Receive and ReceiveStartupMessage)
// and backend messages are written to w (see Send).
type Backend struct {
	// cr supplies the raw bytes of incoming frontend messages.
	cr ChunkReader
	// w receives the encoded form of every message passed to Send.
	w  io.Writer

	// Frontend message flyweights
	//
	// Receive and ReceiveStartupMessage decode into these fields and return
	// pointers to them instead of allocating a new message per call.
	// NOTE(review): this implies a returned message is only valid until the
	// next message of the same type is decoded — confirm against the Decode
	// implementations.
	bind            Bind
	cancelRequest   CancelRequest
	_close          Close
	copyFail        CopyFail
	describe        Describe
	execute         Execute
	flush           Flush
	parse           Parse
	passwordMessage PasswordMessage
	query           Query
	sslRequest      SSLRequest
	startupMessage  StartupMessage
	sync            Sync
	terminate       Terminate

	// bodyLen is the body size of the message currently being received: the
	// length field from the header minus the 4 bytes of the length field itself.
	bodyLen    int
	// msgType is the type byte taken from the header of the message currently
	// being received.
	msgType    byte
	// partialMsg is true once Receive has read a header and stays true until
	// the matching body has been read, so a later Receive call does not read
	// the header a second time.
	partialMsg bool
}
34
35// NewBackend creates a new Backend.
36func NewBackend(cr ChunkReader, w io.Writer) *Backend {
37	return &Backend{cr: cr, w: w}
38}
39
40// Send sends a message to the frontend.
41func (b *Backend) Send(msg BackendMessage) error {
42	_, err := b.w.Write(msg.Encode(nil))
43	return err
44}
45
46// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
47// because the initial connection message is "special" and does not include the message type as the first byte. This
48// will return either a StartupMessage, SSLRequest, or CancelRequest.
49func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
50	buf, err := b.cr.Next(4)
51	if err != nil {
52		return nil, err
53	}
54	msgSize := int(binary.BigEndian.Uint32(buf) - 4)
55
56	buf, err = b.cr.Next(msgSize)
57	if err != nil {
58		return nil, err
59	}
60
61	code := binary.BigEndian.Uint32(buf)
62
63	switch code {
64	case ProtocolVersionNumber:
65		err = b.startupMessage.Decode(buf)
66		if err != nil {
67			return nil, err
68		}
69		return &b.startupMessage, nil
70	case sslRequestNumber:
71		err = b.sslRequest.Decode(buf)
72		if err != nil {
73			return nil, err
74		}
75		return &b.sslRequest, nil
76	case cancelRequestCode:
77		err = b.cancelRequest.Decode(buf)
78		if err != nil {
79			return nil, err
80		}
81		return &b.cancelRequest, nil
82	default:
83		return nil, fmt.Errorf("unknown startup message code: %d", code)
84	}
85}
86
87// Receive receives a message from the frontend.
88func (b *Backend) Receive() (FrontendMessage, error) {
89	if !b.partialMsg {
90		header, err := b.cr.Next(5)
91		if err != nil {
92			return nil, err
93		}
94
95		b.msgType = header[0]
96		b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
97		b.partialMsg = true
98	}
99
100	var msg FrontendMessage
101	switch b.msgType {
102	case 'B':
103		msg = &b.bind
104	case 'C':
105		msg = &b._close
106	case 'D':
107		msg = &b.describe
108	case 'E':
109		msg = &b.execute
110	case 'f':
111		msg = &b.copyFail
112	case 'H':
113		msg = &b.flush
114	case 'P':
115		msg = &b.parse
116	case 'p':
117		msg = &b.passwordMessage
118	case 'Q':
119		msg = &b.query
120	case 'S':
121		msg = &b.sync
122	case 'X':
123		msg = &b.terminate
124	default:
125		return nil, fmt.Errorf("unknown message type: %c", b.msgType)
126	}
127
128	msgBody, err := b.cr.Next(b.bodyLen)
129	if err != nil {
130		return nil, err
131	}
132
133	b.partialMsg = false
134
135	err = msg.Decode(msgBody)
136	return msg, err
137}
138