1package pgproto3_test 2 3import ( 4 "io" 5 "testing" 6 7 "github.com/jackc/pgio" 8 "github.com/jackc/pgproto3/v2" 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11) 12 13func TestBackendReceiveInterrupted(t *testing.T) { 14 t.Parallel() 15 16 server := &interruptReader{} 17 server.push([]byte{'Q', 0, 0, 0, 6}) 18 19 backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) 20 21 msg, err := backend.Receive() 22 if err == nil { 23 t.Fatal("expected err") 24 } 25 if msg != nil { 26 t.Fatalf("did not expect msg, but %v", msg) 27 } 28 29 server.push([]byte{'I', 0}) 30 31 msg, err = backend.Receive() 32 if err != nil { 33 t.Fatal(err) 34 } 35 if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" { 36 t.Fatalf("unexpected msg: %v", msg) 37 } 38} 39 40func TestBackendReceiveUnexpectedEOF(t *testing.T) { 41 t.Parallel() 42 43 server := &interruptReader{} 44 server.push([]byte{'Q', 0, 0, 0, 6}) 45 46 backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) 47 48 // Receive regular msg 49 msg, err := backend.Receive() 50 assert.Nil(t, msg) 51 assert.Equal(t, io.ErrUnexpectedEOF, err) 52 53 // Receive StartupMessage msg 54 dst := []byte{} 55 dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read 56 dst = pgio.AppendUint32(dst, 1) // only send 1 byte 57 server.push(dst) 58 59 msg, err = backend.ReceiveStartupMessage() 60 assert.Nil(t, msg) 61 assert.Equal(t, io.ErrUnexpectedEOF, err) 62} 63 64func TestStartupMessage(t *testing.T) { 65 t.Parallel() 66 67 t.Run("valid StartupMessage", func(t *testing.T) { 68 want := &pgproto3.StartupMessage{ 69 ProtocolVersion: pgproto3.ProtocolVersionNumber, 70 Parameters: map[string]string{ 71 "username": "tester", 72 }, 73 } 74 dst := []byte{} 75 dst = want.Encode(dst) 76 77 server := &interruptReader{} 78 server.push(dst) 79 80 backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) 81 82 msg, err := backend.ReceiveStartupMessage() 83 require.NoError(t, err) 84 require.Equal(t, want, msg) 85 }) 86 87 t.Run("invalid packet length", func(t *testing.T) { 88 wantErr := "invalid length of startup packet" 89 tests := []struct { 90 name string 91 packetLen uint32 92 }{ 93 { 94 name: "large packet length", 95 // Since the StartupMessage contains the "Length of message contents 96 // in bytes, including self", the max startup packet length is actually 97 // 10000+4. Therefore, let's go past the limit with 10005 98 packetLen: 10005, 99 }, 100 { 101 name: "short packet length", 102 packetLen: 3, 103 }, 104 } 105 for _, tt := range tests { 106 t.Run(tt.name, func(t *testing.T) { 107 server := &interruptReader{} 108 dst := []byte{} 109 dst = pgio.AppendUint32(dst, tt.packetLen) 110 dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) 111 server.push(dst) 112 113 backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) 114 115 msg, err := backend.ReceiveStartupMessage() 116 require.Error(t, err) 117 require.Nil(t, msg) 118 require.Contains(t, err.Error(), wantErr) 119 }) 120 } 121 }) 122} 123