1package connectionheader 2 3import ( 4 "net/http" 5 "net/http/httptest" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9) 10 11func TestRemover(t *testing.T) { 12 testCases := []struct { 13 desc string 14 reqHeaders map[string]string 15 expected http.Header 16 }{ 17 { 18 desc: "simple remove", 19 reqHeaders: map[string]string{ 20 "Foo": "bar", 21 connectionHeader: "foo", 22 }, 23 expected: http.Header{}, 24 }, 25 { 26 desc: "remove and Upgrade", 27 reqHeaders: map[string]string{ 28 upgradeHeader: "test", 29 "Foo": "bar", 30 connectionHeader: "Upgrade,foo", 31 }, 32 expected: http.Header{ 33 upgradeHeader: []string{"test"}, 34 connectionHeader: []string{"Upgrade"}, 35 }, 36 }, 37 { 38 desc: "no remove", 39 reqHeaders: map[string]string{ 40 "Foo": "bar", 41 connectionHeader: "fii", 42 }, 43 expected: http.Header{ 44 "Foo": []string{"bar"}, 45 }, 46 }, 47 } 48 49 for _, test := range testCases { 50 test := test 51 t.Run(test.desc, func(t *testing.T) { 52 t.Parallel() 53 54 next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}) 55 56 h := Remover(next) 57 58 req := httptest.NewRequest(http.MethodGet, "https://localhost", nil) 59 60 for k, v := range test.reqHeaders { 61 req.Header.Set(k, v) 62 } 63 64 rw := httptest.NewRecorder() 65 66 h.ServeHTTP(rw, req) 67 68 assert.Equal(t, test.expected, req.Header) 69 }) 70 } 71} 72