1package pgproto3_test
2
3import (
4	"io"
5	"testing"
6
7	"github.com/jackc/pgio"
8	"github.com/jackc/pgproto3/v2"
9	"github.com/stretchr/testify/assert"
10	"github.com/stretchr/testify/require"
11)
12
13func TestBackendReceiveInterrupted(t *testing.T) {
14	t.Parallel()
15
16	server := &interruptReader{}
17	server.push([]byte{'Q', 0, 0, 0, 6})
18
19	backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
20
21	msg, err := backend.Receive()
22	if err == nil {
23		t.Fatal("expected err")
24	}
25	if msg != nil {
26		t.Fatalf("did not expect msg, but %v", msg)
27	}
28
29	server.push([]byte{'I', 0})
30
31	msg, err = backend.Receive()
32	if err != nil {
33		t.Fatal(err)
34	}
35	if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
36		t.Fatalf("unexpected msg: %v", msg)
37	}
38}
39
40func TestBackendReceiveUnexpectedEOF(t *testing.T) {
41	t.Parallel()
42
43	server := &interruptReader{}
44	server.push([]byte{'Q', 0, 0, 0, 6})
45
46	backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
47
48	// Receive regular msg
49	msg, err := backend.Receive()
50	assert.Nil(t, msg)
51	assert.Equal(t, io.ErrUnexpectedEOF, err)
52
53	// Receive StartupMessage msg
54	dst := []byte{}
55	dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read
56	dst = pgio.AppendUint32(dst, 1)    // only send 1 byte
57	server.push(dst)
58
59	msg, err = backend.ReceiveStartupMessage()
60	assert.Nil(t, msg)
61	assert.Equal(t, io.ErrUnexpectedEOF, err)
62}
63
64func TestStartupMessage(t *testing.T) {
65	t.Parallel()
66
67	t.Run("valid StartupMessage", func(t *testing.T) {
68		want := &pgproto3.StartupMessage{
69			ProtocolVersion: pgproto3.ProtocolVersionNumber,
70			Parameters: map[string]string{
71				"username": "tester",
72			},
73		}
74		dst := []byte{}
75		dst = want.Encode(dst)
76
77		server := &interruptReader{}
78		server.push(dst)
79
80		backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
81
82		msg, err := backend.ReceiveStartupMessage()
83		require.NoError(t, err)
84		require.Equal(t, want, msg)
85	})
86
87	t.Run("invalid packet length", func(t *testing.T) {
88		wantErr := "invalid length of startup packet"
89		tests := []struct {
90			name      string
91			packetLen uint32
92		}{
93			{
94				name: "large packet length",
95				// Since the StartupMessage contains the "Length of message contents
96				//  in bytes, including self", the max startup packet length is actually
97				//  10000+4. Therefore, let's go past the limit with 10005
98				packetLen: 10005,
99			},
100			{
101				name:      "short packet length",
102				packetLen: 3,
103			},
104		}
105		for _, tt := range tests {
106			t.Run(tt.name, func(t *testing.T) {
107				server := &interruptReader{}
108				dst := []byte{}
109				dst = pgio.AppendUint32(dst, tt.packetLen)
110				dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
111				server.push(dst)
112
113				backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
114
115				msg, err := backend.ReceiveStartupMessage()
116				require.Error(t, err)
117				require.Nil(t, msg)
118				require.Contains(t, err.Error(), wantErr)
119			})
120		}
121	})
122}
123