1package log
2
3import (
4	"bufio"
5	"net"
6	"net/http"
7	"time"
8
9	"github.com/sebest/xff"
10	"github.com/sirupsen/logrus"
11	"gitlab.com/gitlab-org/labkit/correlation"
12	"gitlab.com/gitlab-org/labkit/mask"
13)
14
// AccessLogger wraps h in a middleware that emits one access log entry per
// request served. The options configure which fields are logged and how.
func AccessLogger(h http.Handler, opts ...AccessLoggerOption) http.Handler {
	cfg := applyAccessLoggerOptions(opts)

	handle := func(w http.ResponseWriter, r *http.Request) {
		writer := newLoggingResponseWriter(w, &cfg)
		// Deferred so the entry is still logged if the handler panics; in
		// that case setStatus below is not reached.
		defer writer.requestFinished(r)

		h.ServeHTTP(writer, r)
		writer.setStatus()
	}

	return http.HandlerFunc(handle)
}
27
// newLoggingResponseWriter wraps rw in a writer that records the response for
// the access log. The start time used for duration and TTFB is taken here.
//
// The returned writer implements http.Hijacker exactly when rw does.
// NOTE(review): as far as this file shows, other optional interfaces of rw
// (e.g. http.Flusher) are not exposed by the returned writer.
func newLoggingResponseWriter(rw http.ResponseWriter, config *accessLoggerConfig) notifiableResponseWriter {
	base := loggingResponseWriter{
		rw:      rw,
		started: time.Now(),
		config:  config,
	}

	// Only advertise http.Hijacker when the wrapped writer can really hijack,
	// so callers type-asserting for it get an accurate answer.
	if _, canHijack := rw.(http.Hijacker); !canHijack {
		return &base
	}

	return &hijackingResponseWriter{base}
}
43
44// notifiableResponseWriter is a response writer that can be notified when the request is complete,
45// via the requestFinished method.
46type notifiableResponseWriter interface {
47	http.ResponseWriter
48
49	// requestFinished is called by the middleware when the request has completed
50	requestFinished(r *http.Request)
51	setStatus()
52}
53
// loggingResponseWriter wraps an http.ResponseWriter and records what the
// handler wrote through it (status, byte count, Content-Type, timing) so an
// access log entry can be produced when the request finishes.
type loggingResponseWriter struct {
	// rw is the wrapped writer that Header, Write and WriteHeader forward to.
	rw          http.ResponseWriter
	// status is the code passed to WriteHeader, or 200 set by setStatus when
	// no header was written.
	status      int
	// wroteHeader is set on the first WriteHeader call; later calls are ignored.
	wroteHeader bool
	// written counts the bytes the wrapped writer reported as written.
	written     int64
	// started is when the wrapper was created; duration and TTFB are measured from it.
	started     time.Time
	// ttfb is the time from started until WriteHeader ran (zero if it never ran).
	ttfb        time.Duration
	// config holds the access logger settings (logger, field mask, XFF policy).
	config      *accessLoggerConfig
	// contentType is the Content-Type header value captured at WriteHeader time.
	contentType string
}
64
// Header returns the header map of the wrapped writer, so headers set through
// this wrapper are the ones the underlying writer sends.
func (l *loggingResponseWriter) Header() http.Header {
	inner := l.rw
	return inner.Header()
}
68
// Write forwards data to the wrapped writer and adds the number of bytes it
// reports as written to the logged response size.
func (l *loggingResponseWriter) Write(data []byte) (n int, err error) {
	// net/http sends an implicit "200 OK" on the first Write when no status
	// was set; route it through our WriteHeader so it gets recorded.
	if !l.wroteHeader {
		l.WriteHeader(http.StatusOK)
	}

	n, err = l.rw.Write(data)
	l.written += int64(n) // count partial writes too, even when err != nil

	return n, err
}
78
// WriteHeader records the final response status, the response Content-Type
// and the time to first byte, then forwards the status to the wrapped writer.
// Calls made after the final status has been written are ignored.
//
// Informational 1xx statuses other than 101 Switching Protocols are not final.
// net/http sends them immediately and still expects a 2xx-5xx status
// afterwards (e.g. 103 Early Hints followed by 200). They are therefore
// forwarded without being recorded. Otherwise the wrapper would mark the
// header as written and silently swallow the real status that follows.
// 101 ends the HTTP exchange, so it is treated as final.
func (l *loggingResponseWriter) WriteHeader(status int) {
	if l.wroteHeader {
		return
	}

	if status >= http.StatusContinue && status < http.StatusOK && status != http.StatusSwitchingProtocols {
		l.rw.WriteHeader(status)
		return
	}

	l.wroteHeader = true
	l.status = status
	// Capture Content-Type now: this is the header set that is sent.
	l.contentType = l.Header().Get("Content-Type")
	l.ttfb = time.Since(l.started)

	l.rw.WriteHeader(status)
}
90
// setStatus settles the status code to log after the handler has returned.
// If no header was ever written, net/http answers the client with 200, so
// 200 is recorded here as well. An explicitly written status is left as is.
func (l *loggingResponseWriter) setStatus() {
	if l.wroteHeader {
		return
	}

	l.status = http.StatusOK
}
99
// accessLogFields builds the structured fields for one access log entry.
//
// It starts from the map returned by l.config.extraFields(r). It then adds
// each built-in field whose bit is set in l.config.fields. Both durations
// (total request duration and TTFB) are reported as whole milliseconds.
// TTFB is only emitted when it is non-zero; it is set solely by WriteHeader.
//
// NOTE(review): the map from extraFields is mutated in place. This assumes it
// returns a fresh map per call — confirm.
//
//nolint:cyclop
func (l *loggingResponseWriter) accessLogFields(r *http.Request) logrus.Fields {
	// Measure first so that building the fields is not counted in the duration.
	duration := time.Since(l.started)

	fields := l.config.extraFields(r)
	fieldsBitMask := l.config.fields

	// Optionally add built in fields
	if fieldsBitMask&CorrelationID != 0 {
		fields[correlation.FieldName] = correlation.ExtractFromContext(r.Context())
	}

	if fieldsBitMask&HTTPHost != 0 {
		fields[httpHostField] = r.Host
	}

	if fieldsBitMask&HTTPRemoteIP != 0 {
		fields[httpRemoteIPField] = l.getRemoteIP(r)
	}

	if fieldsBitMask&HTTPRemoteAddr != 0 {
		fields[httpRemoteAddrField] = r.RemoteAddr
	}

	if fieldsBitMask&HTTPRequestMethod != 0 {
		fields[httpRequestMethodField] = r.Method
	}

	// URLs go through mask.URL (presumably to redact sensitive query values).
	if fieldsBitMask&HTTPURI != 0 {
		fields[httpURIField] = mask.URL(r.RequestURI)
	}

	if fieldsBitMask&HTTPProto != 0 {
		fields[httpProtoField] = r.Proto
	}

	if fieldsBitMask&HTTPResponseStatusCode != 0 {
		fields[httpResponseStatusCodeField] = l.status
	}

	if fieldsBitMask&HTTPResponseSize != 0 {
		fields[httpResponseSizeField] = l.written
	}

	if fieldsBitMask&HTTPRequestReferrer != 0 {
		fields[httpRequestReferrerField] = mask.URL(r.Referer())
	}

	if fieldsBitMask&HTTPUserAgent != 0 {
		fields[httpUserAgentField] = r.UserAgent()
	}

	if fieldsBitMask&RequestDuration != 0 {
		fields[requestDurationField] = duration.Milliseconds()
	}

	if fieldsBitMask&RequestTTFB != 0 && l.ttfb > 0 {
		fields[requestTTFBField] = l.ttfb.Milliseconds()
	}

	if fieldsBitMask&System != 0 {
		fields[systemField] = "http"
	}

	if fieldsBitMask&HTTPResponseContentType != 0 {
		fields[httpResponseContentTypeField] = l.contentType
	}

	return fields
}
170
// requestFinished emits the access log entry for r at Info level with the
// message "access", carrying the fields built by accessLogFields.
func (l *loggingResponseWriter) requestFinished(r *http.Request) {
	fields := l.accessLogFields(r)
	entry := l.config.logger.WithFields(fields)
	entry.Info("access")
}
174
// getRemoteIP returns the client IP (host part only) for the access log.
// X-Forwarded-For is honoured only as permitted by l.config.xffAllowed. If
// the resolved address is not in host:port form, r.RemoteAddr is returned
// unchanged.
func (l *loggingResponseWriter) getRemoteIP(r *http.Request) string {
	addr := xff.GetRemoteAddrIfAllowed(r, l.config.xffAllowed)

	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}

	return r.RemoteAddr
}
184
// hijackingResponseWriter is like a loggingResponseWriter that supports the http.Hijacker interface.
//
// newLoggingResponseWriter returns this type instead of a plain
// *loggingResponseWriter only when the wrapped writer implements
// http.Hijacker. Because loggingResponseWriter is embedded by value, its
// pointer-receiver methods are promoted to *hijackingResponseWriter.
type hijackingResponseWriter struct {
	loggingResponseWriter
}
189
// Hijack lets the handler take over the underlying connection. It delegates
// to the wrapped writer's http.Hijacker implementation.
//
// newLoggingResponseWriter only builds a hijackingResponseWriter when the
// wrapped writer implements http.Hijacker, so the assertion below is expected
// to succeed. It is still checked, so a misconstructed value reports
// http.ErrNotSupported instead of panicking.
//
// Bytes written directly to the hijacked connection bypass this wrapper and
// are not included in the logged response size.
func (l *hijackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := l.rw.(http.Hijacker)
	if !ok {
		return nil, nil, http.ErrNotSupported
	}

	return hijacker.Hijack()
}
196