1package pgproto3_test 2 3import ( 4 "io" 5 "testing" 6 7 "github.com/jackc/pgproto3/v2" 8 "github.com/stretchr/testify/assert" 9) 10 11type interruptReader struct { 12 chunks [][]byte 13} 14 15func (ir *interruptReader) Read(p []byte) (n int, err error) { 16 if len(ir.chunks) == 0 { 17 return 0, io.EOF 18 } 19 20 n = copy(p, ir.chunks[0]) 21 if n != len(ir.chunks[0]) { 22 panic("this test reader doesn't support partial reads of chunks") 23 } 24 25 ir.chunks = ir.chunks[1:] 26 27 return n, nil 28} 29 30func (ir *interruptReader) push(p []byte) { 31 ir.chunks = append(ir.chunks, p) 32} 33 34func TestFrontendReceiveInterrupted(t *testing.T) { 35 t.Parallel() 36 37 server := &interruptReader{} 38 server.push([]byte{'Z', 0, 0, 0, 5}) 39 40 frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) 41 42 msg, err := frontend.Receive() 43 if err == nil { 44 t.Fatal("expected err") 45 } 46 if msg != nil { 47 t.Fatalf("did not expect msg, but %v", msg) 48 } 49 50 server.push([]byte{'I'}) 51 52 msg, err = frontend.Receive() 53 if err != nil { 54 t.Fatal(err) 55 } 56 if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' { 57 t.Fatalf("unexpected msg: %v", msg) 58 } 59} 60 61func TestFrontendReceiveUnexpectedEOF(t *testing.T) { 62 t.Parallel() 63 64 server := &interruptReader{} 65 server.push([]byte{'Z', 0, 0, 0, 5}) 66 67 frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(server), nil) 68 69 msg, err := frontend.Receive() 70 if err == nil { 71 t.Fatal("expected err") 72 } 73 if msg != nil { 74 t.Fatalf("did not expect msg, but %v", msg) 75 } 76 77 msg, err = frontend.Receive() 78 assert.Nil(t, msg) 79 assert.Equal(t, io.ErrUnexpectedEOF, err) 80} 81