1// Copyright 2015 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package messageview provides no-op snapshots for HTTP requests and
16// responses.
17package messageview
18
19import (
20	"bytes"
21	"compress/flate"
22	"compress/gzip"
23	"fmt"
24	"io"
25	"io/ioutil"
26	"net/http"
27	"net/http/httputil"
28	"strings"
29)
30
// MessageView is a static view of an HTTP request or response. The
// serialized message is kept in memory together with the offsets that split
// it into its header, body and trailer sections.
type MessageView struct {
	// message is the serialized message captured by the last snapshot.
	message []byte

	// cts lists the Content-Type prefixes whose bodies are still read when
	// skipBody is set.
	cts []string

	// chunked records whether the last transfer encoding of the message was
	// "chunked".
	chunked bool

	// skipBody disables reading of the body unless the Content-Type matches
	// an entry in cts.
	skipBody bool

	// compress holds the value of the Content-Encoding header.
	compress string

	// bodyoffset is the offset in message at which the body section starts.
	bodyoffset int64

	// traileroffset is the offset in message at which the trailer section
	// starts.
	traileroffset int64
}
41
// config collects the settings that Options can switch on when a reader is
// requested from a MessageView.
type config struct {
	// decode requests that the body be decoded before it is returned.
	decode bool
}

// Option is a configuration option for a MessageView.
type Option func(*config)

// Decode returns an Option that turns on decoding of the message body, which
// is intended for logging purposes.
func Decode() Option {
	enable := func(cfg *config) {
		cfg.decode = true
	}
	return enable
}
55
56// New returns a new MessageView.
57func New() *MessageView {
58	return &MessageView{}
59}
60
61// SkipBody will skip reading the body when the view is loaded with a request
62// or response.
63func (mv *MessageView) SkipBody(skipBody bool) {
64	mv.skipBody = skipBody
65}
66
67// SkipBodyUnlessContentType will skip reading the body unless the
68// Content-Type matches one in cts.
69func (mv *MessageView) SkipBodyUnlessContentType(cts ...string) {
70	mv.skipBody = true
71	mv.cts = cts
72}
73
74// SnapshotRequest reads the request into the MessageView. If mv.skipBody is false
75// it will also read the body into memory and replace the existing body with
76// the in-memory copy. This method is semantically a no-op.
77func (mv *MessageView) SnapshotRequest(req *http.Request) error {
78	buf := new(bytes.Buffer)
79
80	fmt.Fprintf(buf, "%s %s HTTP/%d.%d\r\n", req.Method,
81		req.URL, req.ProtoMajor, req.ProtoMinor)
82
83	if req.Host != "" {
84		fmt.Fprintf(buf, "Host: %s\r\n", req.Host)
85	}
86
87	if tec := len(req.TransferEncoding); tec > 0 {
88		mv.chunked = req.TransferEncoding[tec-1] == "chunked"
89		fmt.Fprintf(buf, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ", "))
90	}
91	if !mv.chunked && req.ContentLength >= 0 {
92		fmt.Fprintf(buf, "Content-Length: %d\r\n", req.ContentLength)
93	}
94
95	mv.compress = req.Header.Get("Content-Encoding")
96
97	req.Header.WriteSubset(buf, map[string]bool{
98		"Host":              true,
99		"Content-Length":    true,
100		"Transfer-Encoding": true,
101	})
102
103	fmt.Fprint(buf, "\r\n")
104
105	mv.bodyoffset = int64(buf.Len())
106	mv.traileroffset = int64(buf.Len())
107
108	ct := req.Header.Get("Content-Type")
109	if mv.skipBody && !mv.matchContentType(ct) || req.Body == nil {
110		mv.message = buf.Bytes()
111		return nil
112	}
113
114	data, err := ioutil.ReadAll(req.Body)
115	if err != nil {
116		return err
117	}
118	req.Body.Close()
119
120	if mv.chunked {
121		cw := httputil.NewChunkedWriter(buf)
122		cw.Write(data)
123		cw.Close()
124	} else {
125		buf.Write(data)
126	}
127
128	mv.traileroffset = int64(buf.Len())
129
130	req.Body = ioutil.NopCloser(bytes.NewReader(data))
131
132	if req.Trailer != nil {
133		req.Trailer.Write(buf)
134	} else if mv.chunked {
135		fmt.Fprint(buf, "\r\n")
136	}
137
138	mv.message = buf.Bytes()
139
140	return nil
141}
142
143// SnapshotResponse reads the response into the MessageView. If mv.headersOnly
144// is false it will also read the body into memory and replace the existing
145// body with the in-memory copy. This method is semantically a no-op.
146func (mv *MessageView) SnapshotResponse(res *http.Response) error {
147	buf := new(bytes.Buffer)
148
149	fmt.Fprintf(buf, "HTTP/%d.%d %s\r\n", res.ProtoMajor, res.ProtoMinor, res.Status)
150
151	if tec := len(res.TransferEncoding); tec > 0 {
152		mv.chunked = res.TransferEncoding[tec-1] == "chunked"
153		fmt.Fprintf(buf, "Transfer-Encoding: %s\r\n", strings.Join(res.TransferEncoding, ", "))
154	}
155	if !mv.chunked && res.ContentLength >= 0 {
156		fmt.Fprintf(buf, "Content-Length: %d\r\n", res.ContentLength)
157	}
158
159	mv.compress = res.Header.Get("Content-Encoding")
160	// Do not uncompress if we have don't have the full contents.
161	if res.StatusCode == http.StatusNoContent || res.StatusCode == http.StatusPartialContent {
162		mv.compress = ""
163	}
164
165	res.Header.WriteSubset(buf, map[string]bool{
166		"Content-Length":    true,
167		"Transfer-Encoding": true,
168	})
169
170	fmt.Fprint(buf, "\r\n")
171
172	mv.bodyoffset = int64(buf.Len())
173	mv.traileroffset = int64(buf.Len())
174
175	ct := res.Header.Get("Content-Type")
176	if mv.skipBody && !mv.matchContentType(ct) || res.Body == nil {
177		mv.message = buf.Bytes()
178		return nil
179	}
180
181	data, err := ioutil.ReadAll(res.Body)
182	if err != nil {
183		return err
184	}
185	res.Body.Close()
186
187	if mv.chunked {
188		cw := httputil.NewChunkedWriter(buf)
189		cw.Write(data)
190		cw.Close()
191	} else {
192		buf.Write(data)
193	}
194
195	mv.traileroffset = int64(buf.Len())
196
197	res.Body = ioutil.NopCloser(bytes.NewReader(data))
198
199	if res.Trailer != nil {
200		res.Trailer.Write(buf)
201	} else if mv.chunked {
202		fmt.Fprint(buf, "\r\n")
203	}
204
205	mv.message = buf.Bytes()
206
207	return nil
208}
209
210// Reader returns the an io.ReadCloser that reads the full HTTP message.
211func (mv *MessageView) Reader(opts ...Option) (io.ReadCloser, error) {
212	hr := mv.HeaderReader()
213	br, err := mv.BodyReader(opts...)
214	if err != nil {
215		return nil, err
216	}
217	tr := mv.TrailerReader()
218
219	return struct {
220		io.Reader
221		io.Closer
222	}{
223		Reader: io.MultiReader(hr, br, tr),
224		Closer: br,
225	}, nil
226}
227
228// HeaderReader returns an io.Reader that reads the HTTP Status-Line or
229// HTTP Request-Line and headers.
230func (mv *MessageView) HeaderReader() io.Reader {
231	r := bytes.NewReader(mv.message)
232	return io.NewSectionReader(r, 0, mv.bodyoffset)
233}
234
235// BodyReader returns an io.ReadCloser that reads the HTTP request or response
236// body. If mv.skipBody was set the reader will immediately return io.EOF.
237//
238// If the Decode option is passed the body will be unchunked if
239// Transfer-Encoding is set to "chunked", and will decode the following
240// Content-Encodings: gzip, deflate.
241func (mv *MessageView) BodyReader(opts ...Option) (io.ReadCloser, error) {
242	var r io.Reader
243
244	conf := &config{}
245	for _, o := range opts {
246		o(conf)
247	}
248
249	br := bytes.NewReader(mv.message)
250	r = io.NewSectionReader(br, mv.bodyoffset, mv.traileroffset-mv.bodyoffset)
251
252	if !conf.decode {
253		return ioutil.NopCloser(r), nil
254	}
255
256	if mv.chunked {
257		r = httputil.NewChunkedReader(r)
258	}
259	switch mv.compress {
260	case "gzip":
261		gr, err := gzip.NewReader(r)
262		if err != nil {
263			return nil, err
264		}
265		return gr, nil
266	case "deflate":
267		return flate.NewReader(r), nil
268	default:
269		return ioutil.NopCloser(r), nil
270	}
271}
272
273// TrailerReader returns an io.Reader that reads the HTTP request or response
274// trailers, if present.
275func (mv *MessageView) TrailerReader() io.Reader {
276	r := bytes.NewReader(mv.message)
277	end := int64(len(mv.message)) - mv.traileroffset
278
279	return io.NewSectionReader(r, mv.traileroffset, end)
280}
281
282func (mv *MessageView) matchContentType(mct string) bool {
283	for _, ct := range mv.cts {
284		if strings.HasPrefix(mct, ct) {
285			return true
286		}
287	}
288
289	return false
290}
291