1/*
2Copyright 2014 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package httplog
18
19import (
20	"net/http"
21	"net/http/httptest"
22	"reflect"
23	"testing"
24)
25
26func TestDefaultStacktracePred(t *testing.T) {
27	for _, x := range []int{101, 200, 204, 302, 400, 404} {
28		if DefaultStacktracePred(x) {
29			t.Fatalf("should not log on %v by default", x)
30		}
31	}
32
33	for _, x := range []int{500, 100} {
34		if !DefaultStacktracePred(x) {
35			t.Fatalf("should log on %v by default", x)
36		}
37	}
38}
39
40func TestStatusIsNot(t *testing.T) {
41	statusTestTable := []struct {
42		status   int
43		statuses []int
44		want     bool
45	}{
46		{http.StatusOK, []int{}, true},
47		{http.StatusOK, []int{http.StatusOK}, false},
48		{http.StatusCreated, []int{http.StatusOK, http.StatusAccepted}, true},
49	}
50	for _, tt := range statusTestTable {
51		sp := StatusIsNot(tt.statuses...)
52		got := sp(tt.status)
53		if got != tt.want {
54			t.Errorf("Expected %v, got %v", tt.want, got)
55		}
56	}
57}
58
59func TestWithLogging(t *testing.T) {
60	req, err := http.NewRequest("GET", "http://example.com", nil)
61	if err != nil {
62		t.Errorf("Unexpected error: %v", err)
63	}
64	var handler http.Handler
65	handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
66	handler = WithLogging(WithLogging(handler, DefaultStacktracePred), DefaultStacktracePred)
67
68	func() {
69		defer func() {
70			if r := recover(); r == nil {
71				t.Errorf("Expected newLogged to panic")
72			}
73		}()
74		w := httptest.NewRecorder()
75		handler.ServeHTTP(w, req)
76	}()
77}
78
79func TestLogOf(t *testing.T) {
80	logOfTests := []bool{true, false}
81	for _, makeLogger := range logOfTests {
82		req, err := http.NewRequest("GET", "http://example.com", nil)
83		if err != nil {
84			t.Errorf("Unexpected error: %v", err)
85		}
86		var want string
87		var handler http.Handler
88		handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
89			got := reflect.TypeOf(LogOf(r, w)).String()
90			if want != got {
91				t.Errorf("Expected %v, got %v", want, got)
92			}
93		})
94		if makeLogger {
95			handler = WithLogging(handler, DefaultStacktracePred)
96			want = "*httplog.respLogger"
97		} else {
98			want = "*httplog.passthroughLogger"
99		}
100
101		w := httptest.NewRecorder()
102		handler.ServeHTTP(w, req)
103	}
104}
105
106func TestUnlogged(t *testing.T) {
107	unloggedTests := []bool{true, false}
108	for _, makeLogger := range unloggedTests {
109		req, err := http.NewRequest("GET", "http://example.com", nil)
110		if err != nil {
111			t.Errorf("Unexpected error: %v", err)
112		}
113
114		origWriter := httptest.NewRecorder()
115		var handler http.Handler
116		handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
117			got := Unlogged(r, w)
118			if origWriter != got {
119				t.Errorf("Expected origin writer, got %#v", got)
120			}
121		})
122		if makeLogger {
123			handler = WithLogging(handler, DefaultStacktracePred)
124		}
125
126		handler.ServeHTTP(origWriter, req)
127	}
128}
129
// testResponseWriter is a do-nothing http.ResponseWriter. It exposes no
// header map, reports zero bytes written with no error, and ignores status
// codes.
type testResponseWriter struct{}

// Header always returns a nil header map.
func (rw *testResponseWriter) Header() http.Header {
	return nil
}

// Write discards p and reports that nothing was written, with no error.
func (rw *testResponseWriter) Write(p []byte) (int, error) {
	return 0, nil
}

// WriteHeader ignores the supplied status code.
func (rw *testResponseWriter) WriteHeader(code int) {}
135
136func TestLoggedStatus(t *testing.T) {
137	req, err := http.NewRequest("GET", "http://example.com", nil)
138	if err != nil {
139		t.Errorf("unexpected error: %v", err)
140	}
141
142	var tw http.ResponseWriter = new(testResponseWriter)
143	logger := newLogged(req, tw)
144	logger.Write(nil)
145
146	if logger.status != http.StatusOK {
147		t.Errorf("expected status after write to be %v, got %v", http.StatusOK, logger.status)
148	}
149
150	tw = new(testResponseWriter)
151	logger = newLogged(req, tw)
152	logger.WriteHeader(http.StatusForbidden)
153	logger.Write(nil)
154
155	if logger.status != http.StatusForbidden {
156		t.Errorf("expected status after write to remain %v, got %v", http.StatusForbidden, logger.status)
157	}
158}
159