1package pgproto3_test
2
3import (
4	"testing"
5
6	"github.com/pkg/errors"
7
8	"github.com/jackc/pgx/pgproto3"
9)
10
// interruptReader is an io.Reader test double that returns queued chunks,
// exactly one chunk per Read call. When the queue is empty, Read fails. A
// test can use this to cut a stream off mid-message and later resume it by
// pushing more data.
type interruptReader struct {
	chunks [][]byte // pending chunks, oldest first
}

// Read copies the oldest queued chunk into p, removes it from the queue and
// reports its length. If nothing is queued it returns an error. If p is too
// small to hold the whole chunk it panics and leaves the chunk queued: this
// reader deliberately does not split chunks across Read calls.
func (ir *interruptReader) Read(p []byte) (n int, err error) {
	if len(ir.chunks) > 0 {
		head := ir.chunks[0]
		copied := copy(p, head)
		if copied != len(head) {
			panic("this test reader doesn't support partial reads of chunks")
		}
		ir.chunks = ir.chunks[1:]
		return copied, nil
	}

	return 0, errors.New("no data")
}

// push appends p to the queue so that a later Read returns it, after any
// chunks that are already queued.
func (ir *interruptReader) push(p []byte) {
	queue := append(ir.chunks, p)
	ir.chunks = queue
}
33
34func TestFrontendReceiveInterrupted(t *testing.T) {
35	t.Parallel()
36
37	server := &interruptReader{}
38	server.push([]byte{'Z', 0, 0, 0, 5})
39
40	frontend, err := pgproto3.NewFrontend(server, nil)
41	if err != nil {
42		t.Fatal(err)
43	}
44
45	msg, err := frontend.Receive()
46	if err == nil {
47		t.Fatal("expected err")
48	}
49	if msg != nil {
50		t.Fatalf("did not expect msg, but %v", msg)
51	}
52
53	server.push([]byte{'I'})
54
55	msg, err = frontend.Receive()
56	if err != nil {
57		t.Fatal(err)
58	}
59	if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' {
60		t.Fatalf("unexpected msg: %v", msg)
61	}
62}
63