1package pgproto3_test
2
3import (
4	"io"
5	"testing"
6
7	"github.com/jackc/pgproto3/v2"
8	"github.com/stretchr/testify/assert"
9)
10
// interruptReader is an io.Reader test double that hands out pre-queued
// chunks, one chunk per Read call. When the queue is empty it reports
// io.EOF, yet more chunks may be pushed afterwards and reading resumes. This
// lets a test simulate input that arrives in separate bursts with a
// "nothing available" gap in between.
type interruptReader struct {
	chunks [][]byte // pending chunks, consumed from the front
}

// Compile-time check that *interruptReader satisfies io.Reader.
var _ io.Reader = (*interruptReader)(nil)

// Read copies the next queued chunk into p and removes it from the queue.
// It returns io.EOF when no chunks are queued. A queued chunk must fit into p
// in its entirety; a buffer that is too small is treated as a test-setup bug
// and panics.
func (ir *interruptReader) Read(p []byte) (n int, err error) {
	if len(ir.chunks) == 0 {
		return 0, io.EOF
	}

	// The io.Reader contract allows zero-length reads and expects (0, nil).
	// Without this guard the size check below would panic spuriously and
	// misreport it as a partial read.
	if len(p) == 0 {
		return 0, nil
	}

	chunk := ir.chunks[0]
	if len(p) < len(chunk) {
		panic("this test reader doesn't support partial reads of chunks")
	}

	n = copy(p, chunk)
	ir.chunks = ir.chunks[1:]

	return n, nil
}

// push queues p to be returned, whole, by a later Read call. The slice is
// stored as-is (not copied), so callers must not modify it afterwards.
func (ir *interruptReader) push(p []byte) {
	ir.chunks = append(ir.chunks, p)
}
33
34func TestFrontendReceiveInterrupted(t *testing.T) {
35	t.Parallel()
36
37	server := &interruptReader{}
38	server.push([]byte{'Z', 0, 0, 0, 5})
39
40	frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil)
41
42	msg, err := frontend.Receive()
43	if err == nil {
44		t.Fatal("expected err")
45	}
46	if msg != nil {
47		t.Fatalf("did not expect msg, but %v", msg)
48	}
49
50	server.push([]byte{'I'})
51
52	msg, err = frontend.Receive()
53	if err != nil {
54		t.Fatal(err)
55	}
56	if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' {
57		t.Fatalf("unexpected msg: %v", msg)
58	}
59}
60
61func TestFrontendReceiveUnexpectedEOF(t *testing.T) {
62	t.Parallel()
63
64	server := &interruptReader{}
65	server.push([]byte{'Z', 0, 0, 0, 5})
66
67	frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil)
68
69	msg, err := frontend.Receive()
70	if err == nil {
71		t.Fatal("expected err")
72	}
73	if msg != nil {
74		t.Fatalf("did not expect msg, but %v", msg)
75	}
76
77	msg, err = frontend.Receive()
78	assert.Nil(t, msg)
79	assert.Equal(t, io.ErrUnexpectedEOF, err)
80}
81