1// Copyright 2018 The Go Cloud Development Kit Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package requestlog provides an http.Handler that logs information
16// about requests.
17package requestlog // import "gocloud.dev/server/requestlog"
18
19import (
20	"bufio"
21	"errors"
22	"io"
23	"io/ioutil"
24	"net"
25	"net/http"
26	"time"
27
28	"go.opencensus.io/trace"
29)
30
31// Logger wraps the Log method.  Log must be safe to call from multiple
32// goroutines.  Log must not hold onto an Entry after it returns.
33type Logger interface {
34	Log(*Entry)
35}
36
37// A Handler emits request information to a Logger.
38type Handler struct {
39	log Logger
40	h   http.Handler
41}
42
43// NewHandler returns a handler that emits information to log and calls
44// h.ServeHTTP.
45func NewHandler(log Logger, h http.Handler) *Handler {
46	return &Handler{
47		log: log,
48		h:   h,
49	}
50}
51
52// ServeHTTP calls its underlying handler's ServeHTTP method, then calls
53// Log after the handler returns.
54//
55// ServeHTTP will always consume the request body up to the first error,
56// even if the underlying handler does not.
57func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
58	start := time.Now()
59	sc := trace.FromContext(r.Context()).SpanContext()
60	ent := &Entry{
61		ReceivedTime:      start,
62		RequestMethod:     r.Method,
63		RequestURL:        r.URL.String(),
64		RequestHeaderSize: headerSize(r.Header),
65		UserAgent:         r.UserAgent(),
66		Referer:           r.Referer(),
67		Proto:             r.Proto,
68		RemoteIP:          ipFromHostPort(r.RemoteAddr),
69		TraceID:           sc.TraceID,
70		SpanID:            sc.SpanID,
71	}
72	if addr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
73		ent.ServerIP = ipFromHostPort(addr.String())
74	}
75	r2 := new(http.Request)
76	*r2 = *r
77	rcc := &readCounterCloser{r: r.Body}
78	r2.Body = rcc
79	w2 := &responseStats{w: w}
80
81	h.h.ServeHTTP(w2, r2)
82
83	ent.Latency = time.Since(start)
84	if rcc.err == nil && rcc.r != nil && !w2.hijacked {
85		// If the handler hasn't encountered an error in the Body (like EOF),
86		// then consume the rest of the Body to provide an accurate rcc.n.
87		io.Copy(ioutil.Discard, rcc)
88	}
89	ent.RequestBodySize = rcc.n
90	ent.Status = w2.code
91	if ent.Status == 0 {
92		ent.Status = http.StatusOK
93	}
94	ent.ResponseHeaderSize, ent.ResponseBodySize = w2.size()
95	h.log.Log(ent)
96}
97
// Entry records information about a completed HTTP request.
type Entry struct {
	// ReceivedTime is the time Handler.ServeHTTP started processing the
	// request.
	ReceivedTime time.Time
	// RequestMethod is the request's HTTP method (r.Method).
	RequestMethod string
	// RequestURL is r.URL rendered with URL.String.
	RequestURL string
	// RequestHeaderSize is the serialized size of the request header
	// (http.Header.Write output) plus 2 bytes for the terminating CRLF.
	RequestHeaderSize int64
	// RequestBodySize counts the body bytes read by the handler plus any
	// remainder that ServeHTTP drained after the handler returned.
	RequestBodySize int64
	// UserAgent, Referer and Proto are copied from the request.
	UserAgent string
	Referer   string
	Proto     string

	// RemoteIP is the host part of r.RemoteAddr, or "" when it could not be
	// extracted.
	RemoteIP string
	// ServerIP is the host part of the local address stored under
	// http.LocalAddrContextKey, or "" when that address is absent.
	ServerIP string

	// Status is the first status code the handler passed to WriteHeader,
	// or 200 when no status was recorded.
	Status int
	// ResponseHeaderSize is the serialized response header size measured
	// when the status was written, so headers added afterwards (such as
	// trailers) are not counted. It is measured at the end of the request
	// if no status was ever written.
	ResponseHeaderSize int64
	// ResponseBodySize counts the body bytes accepted by the underlying
	// ResponseWriter's Write.
	ResponseBodySize int64
	// Latency is the time from ReceivedTime until the handler returned.
	// It is measured before ServeHTTP drains any unread request body.
	Latency time.Duration
	// TraceID and SpanID come from the OpenCensus span in the request's
	// context (trace.FromContext).
	TraceID trace.TraceID
	SpanID  trace.SpanID
}
119
// ipFromHostPort returns the host portion of an address such as
// http.Request.RemoteAddr ("1.2.3.4:5678" or "[::1]:5678"). If hp has no
// port but is itself a literal IP address, hp is returned as-is. Any other
// unparseable input yields "".
//
// net.SplitHostPort already strips the square brackets from IPv6 literals,
// so the returned host never needs further unwrapping.
func ipFromHostPort(hp string) string {
	host, _, err := net.SplitHostPort(hp)
	if err == nil {
		return host
	}
	// RemoteAddr has no defined format. Keep a bare, portless address as
	// long as it really is an IP.
	if net.ParseIP(hp) != nil {
		return hp
	}
	return ""
}
130
// readCounterCloser wraps a request body and tallies the bytes read through
// it. The first error returned by the wrapped reader (EOF included) is kept.
// Every later Read returns that error again without touching the wrapped
// reader.
type readCounterCloser struct {
	r   io.ReadCloser
	n   int64
	err error
}

// Read forwards to the wrapped reader unless an earlier Read or Close has
// already recorded an error.
func (rcc *readCounterCloser) Read(p []byte) (int, error) {
	if prev := rcc.err; prev != nil {
		return 0, prev
	}
	got, err := rcc.r.Read(p)
	rcc.n += int64(got)
	rcc.err = err
	return got, err
}

// Close poisons subsequent Reads and then closes the wrapped reader.
func (rcc *readCounterCloser) Close() error {
	rcc.err = errors.New("read from closed reader")
	return rcc.r.Close()
}
150
// writeCounter is an io.Writer that discards its input and only keeps a
// running total of the bytes it was given. It never returns an error.
type writeCounter int64

func (wc *writeCounter) Write(p []byte) (n int, err error) {
	n = len(p)
	*wc += writeCounter(n)
	return n, nil
}

// headerSize reports how many bytes h occupies when serialized with
// http.Header.Write, plus 2 bytes for the terminating CRLF.
func headerSize(h http.Header) int64 {
	const crlf = 2
	total := new(writeCounter)
	_ = h.Write(total) // writeCounter.Write cannot fail
	return int64(*total) + crlf
}
163
164type responseStats struct {
165	w        http.ResponseWriter
166	hsize    int64
167	wc       writeCounter
168	code     int
169	hijacked bool
170}
171
172func (r *responseStats) Header() http.Header {
173	return r.w.Header()
174}
175
176func (r *responseStats) WriteHeader(statusCode int) {
177	if r.code != 0 {
178		return
179	}
180	r.hsize = headerSize(r.w.Header())
181	r.w.WriteHeader(statusCode)
182	r.code = statusCode
183}
184
185func (r *responseStats) Write(p []byte) (n int, err error) {
186	if r.code == 0 {
187		r.WriteHeader(http.StatusOK)
188	}
189	n, err = r.w.Write(p)
190	r.wc.Write(p[:n])
191	return
192}
193
194func (r *responseStats) size() (hdr, body int64) {
195	if r.code == 0 {
196		return headerSize(r.w.Header()), 0
197	}
198	// Use the header size from the time WriteHeader was called.
199	// The Header map can be mutated after the call to add HTTP Trailers,
200	// which we don't want to count.
201	return r.hsize, int64(r.wc)
202}
203
204func (r *responseStats) Hijack() (_ net.Conn, _ *bufio.ReadWriter, err error) {
205	defer func() {
206		if err == nil {
207			r.hijacked = true
208		}
209	}()
210	if hj, ok := r.w.(http.Hijacker); ok {
211		return hj.Hijack()
212	}
213	return nil, nil, errors.New("underlying ResponseWriter does not support hijacking")
214}
215