1package log
2
3import (
4	"bytes"
5	"fmt"
6	"io/ioutil"
7	"net/http"
8	"net/http/httptest"
9	"net/url"
10	"reflect"
11	"regexp"
12	"testing"
13	"time"
14
15	"github.com/google/go-cmp/cmp"
16	"github.com/rs/zerolog"
17)
18
19func TestGenerateUUID(t *testing.T) {
20	prev := uuid()
21	for i := 0; i < 100; i++ {
22		id := uuid()
23		if id == "" {
24			t.Fatal("random pool failure")
25		}
26		if prev == id {
27			t.Fatalf("Should get a new ID!")
28		}
29		matched := regexp.MustCompile("[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}").MatchString(id)
30		if !matched {
31			t.Fatalf("expected match %s %v", id, matched)
32		}
33	}
34}
35
// decodeIfBinary converts captured log output into a string for comparison.
// In this file it simply returns out.String() unchanged. Presumably a
// binary-encoding build variant decodes here instead (TODO confirm).
func decodeIfBinary(out fmt.Stringer) string {
	text := out.String()
	return text
}
39
40func TestNewHandler(t *testing.T) {
41	log := zerolog.New(nil).With().
42		Str("foo", "bar").
43		Logger()
44	lh := NewHandler(log)
45	h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
46		l := FromRequest(r)
47		if !reflect.DeepEqual(*l, log) {
48			t.Fail()
49		}
50	}))
51	h.ServeHTTP(nil, &http.Request{})
52}
53
54func TestURLHandler(t *testing.T) {
55	out := &bytes.Buffer{}
56	r := &http.Request{
57		URL: &url.URL{Path: "/path", RawQuery: "foo=bar"},
58	}
59	h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
60		l := FromRequest(r)
61		l.Log().Msg("")
62	}))
63	h = NewHandler(zerolog.New(out))(h)
64	h.ServeHTTP(nil, r)
65	if want, got := `{"url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
66		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
67	}
68}
69
70func TestMethodHandler(t *testing.T) {
71	out := &bytes.Buffer{}
72	r := &http.Request{
73		Method: "POST",
74	}
75	h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
76		l := FromRequest(r)
77		l.Log().Msg("")
78	}))
79	h = NewHandler(zerolog.New(out))(h)
80	h.ServeHTTP(nil, r)
81	if want, got := `{"method":"POST"}`+"\n", decodeIfBinary(out); want != got {
82		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
83	}
84}
85
86func TestRequestHandler(t *testing.T) {
87	out := &bytes.Buffer{}
88	r := &http.Request{
89		Method: "POST",
90		URL:    &url.URL{Path: "/path", RawQuery: "foo=bar"},
91	}
92	h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
93		l := FromRequest(r)
94		l.Log().Msg("")
95	}))
96	h = NewHandler(zerolog.New(out))(h)
97	h.ServeHTTP(nil, r)
98	if want, got := `{"request":"POST /path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
99		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
100	}
101}
102
103func TestRemoteAddrHandler(t *testing.T) {
104	out := &bytes.Buffer{}
105	r := &http.Request{
106		RemoteAddr: "1.2.3.4:1234",
107	}
108	h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
109		l := FromRequest(r)
110		l.Log().Msg("")
111	}))
112	h = NewHandler(zerolog.New(out))(h)
113	h.ServeHTTP(nil, r)
114	if want, got := `{"ip":"1.2.3.4"}`+"\n", decodeIfBinary(out); want != got {
115		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
116	}
117}
118
119func TestRemoteAddrHandlerIPv6(t *testing.T) {
120	out := &bytes.Buffer{}
121	r := &http.Request{
122		RemoteAddr: "[2001:db8:a0b:12f0::1]:1234",
123	}
124	h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
125		l := FromRequest(r)
126		l.Log().Msg("")
127	}))
128	h = NewHandler(zerolog.New(out))(h)
129	h.ServeHTTP(nil, r)
130	if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", decodeIfBinary(out); want != got {
131		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
132	}
133}
134
135func TestUserAgentHandler(t *testing.T) {
136	out := &bytes.Buffer{}
137	r := &http.Request{
138		Header: http.Header{
139			"User-Agent": []string{"some user agent string"},
140		},
141	}
142	h := UserAgentHandler("ua")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
143		l := FromRequest(r)
144		l.Log().Msg("")
145	}))
146	h = NewHandler(zerolog.New(out))(h)
147	h.ServeHTTP(nil, r)
148	if want, got := `{"ua":"some user agent string"}`+"\n", decodeIfBinary(out); want != got {
149		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
150	}
151}
152
153func TestRefererHandler(t *testing.T) {
154	out := &bytes.Buffer{}
155	r := &http.Request{
156		Header: http.Header{
157			"Referer": []string{"http://foo.com/bar"},
158		},
159	}
160	h := RefererHandler("referer")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
161		l := FromRequest(r)
162		l.Log().Msg("")
163	}))
164	h = NewHandler(zerolog.New(out))(h)
165	h.ServeHTTP(nil, r)
166	if want, got := `{"referer":"http://foo.com/bar"}`+"\n", decodeIfBinary(out); want != got {
167		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
168	}
169}
170
171func TestRequestIDHandler(t *testing.T) {
172	out := &bytes.Buffer{}
173	r := &http.Request{
174		Header: http.Header{
175			"Referer": []string{"http://foo.com/bar"},
176		},
177	}
178	h := RequestIDHandler("id", "Request-Id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
179		id, ok := IDFromRequest(r)
180		if !ok {
181			t.Fatal("Missing id in request")
182		}
183		l := FromRequest(r)
184		l.Log().Msg("")
185		if want, got := fmt.Sprintf(`{"id":"%s"}`+"\n", id), decodeIfBinary(out); want != got {
186			t.Errorf("Invalid log output, got: %s, want: %s", got, want)
187		}
188	}))
189	h = NewHandler(zerolog.New(out))(h)
190	h.ServeHTTP(httptest.NewRecorder(), r)
191}
192
193func TestCombinedHandlers(t *testing.T) {
194	out := &bytes.Buffer{}
195	r := &http.Request{
196		Method: "POST",
197		URL:    &url.URL{Path: "/path", RawQuery: "foo=bar"},
198	}
199	h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
200		l := FromRequest(r)
201		l.Log().Msg("")
202	}))))
203	h = NewHandler(zerolog.New(out))(h)
204	h.ServeHTTP(nil, r)
205	if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", decodeIfBinary(out); want != got {
206		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
207	}
208}
209
210func BenchmarkHandlers(b *testing.B) {
211	r := &http.Request{
212		Method: "POST",
213		URL:    &url.URL{Path: "/path", RawQuery: "foo=bar"},
214	}
215	h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
216		l := FromRequest(r)
217		l.Log().Msg("")
218	}))
219	h2 := MethodHandler("method")(RequestHandler("request")(h1))
220	handlers := map[string]http.Handler{
221		"Single":           NewHandler(zerolog.New(ioutil.Discard))(h1),
222		"Combined":         NewHandler(zerolog.New(ioutil.Discard))(h2),
223		"SingleDisabled":   NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h1),
224		"CombinedDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h2),
225	}
226	for name := range handlers {
227		h := handlers[name]
228		b.Run(name, func(b *testing.B) {
229			for i := 0; i < b.N; i++ {
230				h.ServeHTTP(nil, r)
231			}
232		})
233	}
234}
235
236func BenchmarkDataRace(b *testing.B) {
237	log := zerolog.New(nil).With().
238		Str("foo", "bar").
239		Logger()
240	lh := NewHandler(log)
241	h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
242		l := FromRequest(r)
243		l.UpdateContext(func(c zerolog.Context) zerolog.Context {
244			return c.Str("bar", "baz")
245		})
246		l.Log().Msg("")
247	}))
248
249	b.RunParallel(func(pb *testing.PB) {
250		for pb.Next() {
251			h.ServeHTTP(nil, &http.Request{})
252		}
253	})
254}
255
256func TestLogHeadersHandler(t *testing.T) {
257	out := &bytes.Buffer{}
258
259	r := httptest.NewRequest(http.MethodGet, "/", nil)
260
261	r.Header.Set("X-Forwarded-For", "proxy1,proxy2,proxy3")
262
263	h := HeadersHandler([]string{"X-Forwarded-For"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
264		l := FromRequest(r)
265		l.Log().Msg("")
266	}))
267	h = NewHandler(zerolog.New(out))(h)
268	h.ServeHTTP(nil, r)
269	if want, got := `{"X-Forwarded-For":["proxy1,proxy2,proxy3"]}`+"\n", decodeIfBinary(out); want != got {
270		t.Errorf("Invalid log output, got: %s, want: %s", got, want)
271	}
272}
273func TestAccessHandler(t *testing.T) {
274	out := &bytes.Buffer{}
275
276	r := httptest.NewRequest(http.MethodGet, "/", nil)
277
278	h := AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
279		l := FromRequest(r)
280		l.Log().Int("status", status).Int("size", size).Msg("info")
281
282	})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
283		l := FromRequest(r)
284		l.Log().Msg("some inner logging")
285		w.Write([]byte("Add something to the request of non-zero size"))
286	}))
287	h = NewHandler(zerolog.New(out))(h)
288	w := httptest.NewRecorder()
289
290	h.ServeHTTP(w, r)
291	want := "{\"message\":\"some inner logging\"}\n{\"status\":200,\"size\":45,\"message\":\"info\"}\n"
292	got := decodeIfBinary(out)
293	if diff := cmp.Diff(want, got); diff != "" {
294		t.Errorf("TestAccessHandler: %s", diff)
295	}
296
297}
298