1package pgproto3_test 2 3import ( 4 "testing" 5 6 "github.com/pkg/errors" 7 8 "github.com/jackc/pgx/pgproto3" 9) 10 11type interruptReader struct { 12 chunks [][]byte 13} 14 15func (ir *interruptReader) Read(p []byte) (n int, err error) { 16 if len(ir.chunks) == 0 { 17 return 0, errors.New("no data") 18 } 19 20 n = copy(p, ir.chunks[0]) 21 if n != len(ir.chunks[0]) { 22 panic("this test reader doesn't support partial reads of chunks") 23 } 24 25 ir.chunks = ir.chunks[1:] 26 27 return n, nil 28} 29 30func (ir *interruptReader) push(p []byte) { 31 ir.chunks = append(ir.chunks, p) 32} 33 34func TestFrontendReceiveInterrupted(t *testing.T) { 35 t.Parallel() 36 37 server := &interruptReader{} 38 server.push([]byte{'Z', 0, 0, 0, 5}) 39 40 frontend, err := pgproto3.NewFrontend(server, nil) 41 if err != nil { 42 t.Fatal(err) 43 } 44 45 msg, err := frontend.Receive() 46 if err == nil { 47 t.Fatal("expected err") 48 } 49 if msg != nil { 50 t.Fatalf("did not expect msg, but %v", msg) 51 } 52 53 server.push([]byte{'I'}) 54 55 msg, err = frontend.Receive() 56 if err != nil { 57 t.Fatal(err) 58 } 59 if msg, ok := msg.(*pgproto3.ReadyForQuery); !ok || msg.TxStatus != 'I' { 60 t.Fatalf("unexpected msg: %v", msg) 61 } 62} 63