1package connectionheader
2
3import (
4	"net/http"
5	"net/http/httptest"
6	"testing"
7
8	"github.com/stretchr/testify/assert"
9)
10
11func TestRemover(t *testing.T) {
12	testCases := []struct {
13		desc       string
14		reqHeaders map[string]string
15		expected   http.Header
16	}{
17		{
18			desc: "simple remove",
19			reqHeaders: map[string]string{
20				"Foo":            "bar",
21				connectionHeader: "foo",
22			},
23			expected: http.Header{},
24		},
25		{
26			desc: "remove and Upgrade",
27			reqHeaders: map[string]string{
28				upgradeHeader:    "test",
29				"Foo":            "bar",
30				connectionHeader: "Upgrade,foo",
31			},
32			expected: http.Header{
33				upgradeHeader:    []string{"test"},
34				connectionHeader: []string{"Upgrade"},
35			},
36		},
37		{
38			desc: "no remove",
39			reqHeaders: map[string]string{
40				"Foo":            "bar",
41				connectionHeader: "fii",
42			},
43			expected: http.Header{
44				"Foo": []string{"bar"},
45			},
46		},
47	}
48
49	for _, test := range testCases {
50		test := test
51		t.Run(test.desc, func(t *testing.T) {
52			t.Parallel()
53
54			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
55
56			h := Remover(next)
57
58			req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)
59
60			for k, v := range test.reqHeaders {
61				req.Header.Set(k, v)
62			}
63
64			rw := httptest.NewRecorder()
65
66			h.ServeHTTP(rw, req)
67
68			assert.Equal(t, test.expected, req.Header)
69		})
70	}
71}
72