1package middleware 2 3import ( 4 "bytes" 5 "net/http" 6 "net/http/httptest" 7 "testing" 8 9 "github.com/go-chi/chi/v5" 10) 11 12func TestContentType(t *testing.T) { 13 t.Parallel() 14 15 var tests = []struct { 16 name string 17 inputValue string 18 allowedContentTypes []string 19 want int 20 }{ 21 { 22 "should accept requests with a matching content type", 23 "application/json; charset=UTF-8", 24 []string{"application/json"}, 25 http.StatusOK, 26 }, 27 { 28 "should accept requests with a matching content type no charset", 29 "application/json", 30 []string{"application/json"}, 31 http.StatusOK, 32 }, 33 { 34 "should accept requests with a matching content-type with extra values", 35 "application/json; foo=bar; charset=UTF-8; spam=eggs", 36 []string{"application/json"}, 37 http.StatusOK, 38 }, 39 { 40 "should accept requests with a matching content type when multiple content types are supported", 41 "text/xml; charset=UTF-8", 42 []string{"application/json", "text/xml"}, 43 http.StatusOK, 44 }, 45 { 46 "should not accept requests with a mismatching content type", 47 "text/plain; charset=latin-1", 48 []string{"application/json"}, 49 http.StatusUnsupportedMediaType, 50 }, 51 { 52 "should not accept requests with a mismatching content type even if multiple content types are allowed", 53 "text/plain; charset=Latin-1", 54 []string{"application/json", "text/xml"}, 55 http.StatusUnsupportedMediaType, 56 }, 57 } 58 59 for _, tt := range tests { 60 var tt = tt 61 t.Run(tt.name, func(t *testing.T) { 62 t.Parallel() 63 64 recorder := httptest.NewRecorder() 65 66 r := chi.NewRouter() 67 r.Use(AllowContentType(tt.allowedContentTypes...)) 68 r.Post("/", func(w http.ResponseWriter, r *http.Request) {}) 69 70 body := []byte("This is my content. There are many like this but this one is mine") 71 req := httptest.NewRequest("POST", "/", bytes.NewReader(body)) 72 req.Header.Set("Content-Type", tt.inputValue) 73 74 r.ServeHTTP(recorder, req) 75 res := recorder.Result() 76 77 if res.StatusCode != tt.want { 78 t.Errorf("response is incorrect, got %d, want %d", recorder.Code, tt.want) 79 } 80 }) 81 } 82} 83