1package middleware
2
3import (
4	"net/http"
5	"net/http/httptest"
6	"testing"
7
8	"github.com/go-chi/chi"
9)
10
11func TestContentCharset(t *testing.T) {
12	t.Parallel()
13
14	var tests = []struct {
15		name                string
16		inputValue          string
17		inputContentCharset []string
18		want                int
19	}{
20		{
21			"should accept requests with a matching charset",
22			"application/json; charset=UTF-8",
23			[]string{"UTF-8"},
24			http.StatusOK,
25		},
26		{
27			"should be case-insensitive",
28			"application/json; charset=utf-8",
29			[]string{"UTF-8"},
30			http.StatusOK,
31		},
32		{
33			"should accept requests with a matching charset with extra values",
34			"application/json; foo=bar; charset=UTF-8; spam=eggs",
35			[]string{"UTF-8"},
36			http.StatusOK,
37		},
38		{
39			"should accept requests with a matching charset when multiple charsets are supported",
40			"text/xml; charset=UTF-8",
41			[]string{"UTF-8", "Latin-1"},
42			http.StatusOK,
43		},
44		{
45			"should accept requests with no charset if empty charset headers are allowed",
46			"text/xml",
47			[]string{"UTF-8", ""},
48			http.StatusOK,
49		},
50		{
51			"should not accept requests with no charset if empty charset headers are not allowed",
52			"text/xml",
53			[]string{"UTF-8"},
54			http.StatusUnsupportedMediaType,
55		},
56		{
57			"should not accept requests with a mismatching charset",
58			"text/plain; charset=Latin-1",
59			[]string{"UTF-8"},
60			http.StatusUnsupportedMediaType,
61		},
62		{
63			"should not accept requests with a mismatching charset even if empty charsets are allowed",
64			"text/plain; charset=Latin-1",
65			[]string{"UTF-8", ""},
66			http.StatusUnsupportedMediaType,
67		},
68	}
69
70	for _, tt := range tests {
71		var tt = tt
72		t.Run(tt.name, func(t *testing.T) {
73			t.Parallel()
74
75			var recorder = httptest.NewRecorder()
76
77			var r = chi.NewRouter()
78			r.Use(ContentCharset(tt.inputContentCharset...))
79			r.Get("/", func(w http.ResponseWriter, r *http.Request) {})
80
81			var req, _ = http.NewRequest("GET", "/", nil)
82			req.Header.Set("Content-Type", tt.inputValue)
83
84			r.ServeHTTP(recorder, req)
85			var res = recorder.Result()
86
87			if res.StatusCode != tt.want {
88				t.Errorf("response is incorrect, got %d, want %d", recorder.Code, tt.want)
89			}
90		})
91	}
92}
93
94func TestSplit(t *testing.T) {
95	t.Parallel()
96
97	var s1, s2 = split("  type1;type2  ", ";")
98
99	if s1 != "type1" || s2 != "type2" {
100		t.Errorf("Want type1, type2 got %s, %s", s1, s2)
101	}
102
103	s1, s2 = split("type1  ", ";")
104
105	if s1 != "type1" {
106		t.Errorf("Want \"type1\" got \"%s\"", s1)
107	}
108	if s2 != "" {
109		t.Errorf("Want empty string got \"%s\"", s2)
110	}
111}
112
113func TestContentEncoding(t *testing.T) {
114	t.Parallel()
115
116	if !contentEncoding("application/json; foo=bar; charset=utf-8; spam=eggs", []string{"utf-8"}...) {
117		t.Error("Want true, got false")
118	}
119
120	if contentEncoding("text/plain; charset=latin-1", []string{"utf-8"}...) {
121		t.Error("Want false, got true")
122	}
123
124	if !contentEncoding("text/xml; charset=UTF-8", []string{"latin-1", "utf-8"}...) {
125		t.Error("Want true, got false")
126	}
127}
128