1package middleware
2
3import (
4	"bytes"
5	"net/http"
6	"net/http/httptest"
7	"testing"
8
9	"github.com/go-chi/chi/v5"
10)
11
12func TestContentType(t *testing.T) {
13	t.Parallel()
14
15	var tests = []struct {
16		name                string
17		inputValue          string
18		allowedContentTypes []string
19		want                int
20	}{
21		{
22			"should accept requests with a matching content type",
23			"application/json; charset=UTF-8",
24			[]string{"application/json"},
25			http.StatusOK,
26		},
27		{
28			"should accept requests with a matching content type no charset",
29			"application/json",
30			[]string{"application/json"},
31			http.StatusOK,
32		},
33		{
34			"should accept requests with a matching content-type with extra values",
35			"application/json; foo=bar; charset=UTF-8; spam=eggs",
36			[]string{"application/json"},
37			http.StatusOK,
38		},
39		{
40			"should accept requests with a matching content type when multiple content types are supported",
41			"text/xml; charset=UTF-8",
42			[]string{"application/json", "text/xml"},
43			http.StatusOK,
44		},
45		{
46			"should not accept requests with a mismatching content type",
47			"text/plain; charset=latin-1",
48			[]string{"application/json"},
49			http.StatusUnsupportedMediaType,
50		},
51		{
52			"should not accept requests with a mismatching content type even if multiple content types are allowed",
53			"text/plain; charset=Latin-1",
54			[]string{"application/json", "text/xml"},
55			http.StatusUnsupportedMediaType,
56		},
57	}
58
59	for _, tt := range tests {
60		var tt = tt
61		t.Run(tt.name, func(t *testing.T) {
62			t.Parallel()
63
64			recorder := httptest.NewRecorder()
65
66			r := chi.NewRouter()
67			r.Use(AllowContentType(tt.allowedContentTypes...))
68			r.Post("/", func(w http.ResponseWriter, r *http.Request) {})
69
70			body := []byte("This is my content. There are many like this but this one is mine")
71			req := httptest.NewRequest("POST", "/", bytes.NewReader(body))
72			req.Header.Set("Content-Type", tt.inputValue)
73
74			r.ServeHTTP(recorder, req)
75			res := recorder.Result()
76
77			if res.StatusCode != tt.want {
78				t.Errorf("response is incorrect, got %d, want %d", recorder.Code, tt.want)
79			}
80		})
81	}
82}
83