1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// +build go1.8
16
17package proxy
18
19import (
20	"bytes"
21	"encoding/json"
22	"errors"
23	"fmt"
24	"io"
25	"io/ioutil"
26	"log"
27	"mime"
28	"mime/multipart"
29	"net/http"
30	"reflect"
31	"strings"
32	"sync"
33
34	"github.com/google/martian/martianlog"
35)
36
37// ForReplaying returns a Proxy configured to replay.
38func ForReplaying(filename string, port int) (*Proxy, error) {
39	p, err := newProxy(filename)
40	if err != nil {
41		return nil, err
42	}
43	calls, initial, err := readLog(filename)
44	if err != nil {
45		return nil, err
46	}
47	p.mproxy.SetRoundTripper(&replayRoundTripper{
48		calls:         calls,
49		ignoreHeaders: p.ignoreHeaders,
50	})
51	p.Initial = initial
52
53	// Debug logging.
54	// TODO(jba): factor out from here and ForRecording.
55	logger := martianlog.NewLogger()
56	logger.SetDecode(true)
57	p.mproxy.SetRequestModifier(logger)
58	p.mproxy.SetResponseModifier(logger)
59
60	if err := p.start(port); err != nil {
61		return nil, err
62	}
63	return p, nil
64}
65
// A call is an HTTP request and its matching response.
//
// Calls are built by readLog from the entries of a replay log, and consumed by
// replayRoundTripper.RoundTrip, which matches incoming requests against req and
// reqBody and answers with res.
type call struct {
	req     *Request     // recorded request
	reqBody *requestBody // parsed request body
	res     *Response    // recorded response; may be supplied by a later log entry
}
72
73func readLog(filename string) ([]*call, []byte, error) {
74	bytes, err := ioutil.ReadFile(filename)
75	if err != nil {
76		return nil, nil, err
77	}
78	var lg Log
79	if err := json.Unmarshal(bytes, &lg); err != nil {
80		return nil, nil, fmt.Errorf("%s: %v", filename, err)
81	}
82	if lg.Version != LogVersion {
83		return nil, nil, fmt.Errorf("httpreplay proxy: read log version %s but current version is %s",
84			lg.Version, LogVersion)
85	}
86	ignoreIDs := map[string]bool{} // IDs of requests to ignore
87	callsByID := map[string]*call{}
88	var calls []*call
89	for _, e := range lg.Entries {
90		if ignoreIDs[e.ID] {
91			continue
92		}
93		c, ok := callsByID[e.ID]
94		switch {
95		case !ok:
96			if e.Request == nil {
97				return nil, nil, fmt.Errorf("first entry for ID %s does not have a request", e.ID)
98			}
99			if e.Request.Method == "CONNECT" {
100				// Ignore CONNECT methods.
101				ignoreIDs[e.ID] = true
102			} else {
103				reqBody, err := newRequestBodyFromLog(e.Request)
104				if err != nil {
105					return nil, nil, err
106				}
107				c := &call{e.Request, reqBody, e.Response}
108				calls = append(calls, c)
109				callsByID[e.ID] = c
110			}
111		case e.Request != nil:
112			if e.Response != nil {
113				return nil, nil, errors.New("HAR entry has both request and response")
114			}
115			c.req = e.Request
116		case e.Response != nil:
117			c.res = e.Response
118		default:
119			return nil, nil, errors.New("HAR entry has neither request nor response")
120		}
121	}
122	for _, c := range calls {
123		if c.req == nil || c.res == nil {
124			return nil, nil, fmt.Errorf("missing request or response: %+v", c)
125		}
126	}
127	return calls, lg.Initial, nil
128}
129
130type replayRoundTripper struct {
131	mu            sync.Mutex
132	calls         []*call
133	ignoreHeaders map[string]bool
134}
135
136func (r *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
137	reqBody, err := newRequestBodyFromHTTP(req)
138	if err != nil {
139		return nil, err
140	}
141	r.mu.Lock()
142	defer r.mu.Unlock()
143	for i, call := range r.calls {
144		if call == nil {
145			continue
146		}
147		if requestsMatch(req, reqBody, call.req, call.reqBody, r.ignoreHeaders) {
148			r.calls[i] = nil // nil out this call so we don't reuse it
149			return toHTTPResponse(call.res, req), nil
150		}
151	}
152	return nil, fmt.Errorf("no matching request for %+v", req)
153}
154
155// Headers that shouldn't be compared, because they may differ on different executions
156// of the same code, or may not be present during record or replay.
157var ignoreHeaders = map[string]bool{}
158
159func init() {
160	// Sensitive headers are redacted in the log, so they won't be equal to incoming values.
161	for h := range sensitiveHeaders {
162		ignoreHeaders[h] = true
163	}
164	for _, h := range []string{
165		"Content-Type", // handled by requestBody
166		"Connection",
167		"Date",
168		"Host",
169		"Transfer-Encoding",
170		"Via",
171		"X-Forwarded-For",
172		"X-Forwarded-Host",
173		"X-Forwarded-Proto",
174		"X-Forwarded-Url",
175		"X-Cloud-Trace-Context", // OpenCensus traces have a random ID
176		"X-Goog-Api-Client",     // can differ for, e.g., different Go versions
177	} {
178		ignoreHeaders[h] = true
179	}
180}
181
182// Report whether the incoming request in matches the candidate request cand.
183func requestsMatch(in *http.Request, inBody *requestBody, cand *Request, candBody *requestBody, ignoreHeaders map[string]bool) bool {
184	if in.Method != cand.Method {
185		return false
186	}
187	if in.URL.String() != cand.URL {
188		return false
189	}
190	if !inBody.equal(candBody) {
191		return false
192	}
193	// Check headers last. See DebugHeaders.
194	return headersMatch(in.Header, cand.Header, ignoreHeaders)
195}
196
// A requestBody is the parsed body of a request. A multipart body is split into
// its parts; any other body is kept whole as a single part.
//
// The replaying proxy needs to understand multipart bodies because the
// boundaries are generated randomly, so two bodies carrying the same data cannot
// simply be compared byte for byte.
type requestBody struct {
	// parts holds the parts of the body, or the entire body as its only
	// element when the media type is not multipart.
	parts [][]byte

	// mediaType is the media type portion of the Content-Type header.
	mediaType string
}
206
207func newRequestBodyFromHTTP(req *http.Request) (*requestBody, error) {
208	defer req.Body.Close()
209	return newRequestBody(req.Header.Get("Content-Type"), req.Body)
210}
211
212func newRequestBodyFromLog(req *Request) (*requestBody, error) {
213	if req.Body == nil {
214		return nil, nil
215	}
216	return newRequestBody(req.Header.Get("Content-Type"), bytes.NewReader(req.Body))
217}
218
219// newRequestBody parses the Content-Type header, reads the body, and splits it into
220// parts if necessary.
221func newRequestBody(contentType string, body io.Reader) (*requestBody, error) {
222	if contentType == "" {
223		// No content-type header. There should not be a body.
224		if _, err := body.Read(make([]byte, 1)); err != io.EOF {
225			return nil, errors.New("no Content-Type, but body")
226		}
227		return nil, nil
228	}
229	mediaType, params, err := mime.ParseMediaType(contentType)
230	if err != nil {
231		return nil, err
232	}
233	rb := &requestBody{mediaType: mediaType}
234	if strings.HasPrefix(mediaType, "multipart/") {
235		mr := multipart.NewReader(body, params["boundary"])
236		for {
237			p, err := mr.NextPart()
238			if err == io.EOF {
239				break
240			}
241			if err != nil {
242				return nil, err
243			}
244			part, err := ioutil.ReadAll(p)
245			if err != nil {
246				return nil, err
247			}
248			// TODO(jba): care about part headers?
249			rb.parts = append(rb.parts, part)
250		}
251	} else {
252		bytes, err := ioutil.ReadAll(body)
253		if err != nil {
254			return nil, err
255		}
256		rb.parts = [][]byte{bytes}
257	}
258	return rb, nil
259}
260
261func (r1 *requestBody) equal(r2 *requestBody) bool {
262	if r1 == nil || r2 == nil {
263		return r1 == r2
264	}
265	if r1.mediaType != r2.mediaType {
266		return false
267	}
268	if len(r1.parts) != len(r2.parts) {
269		return false
270	}
271	for i, p1 := range r1.parts {
272		if !bytes.Equal(p1, r2.parts[i]) {
273			return false
274		}
275	}
276	return true
277}
278
// DebugHeaders helps to determine whether a header should be ignored.
// When true, if requests have the same method, URL and body but differ
// in a header, the first mismatched header is logged.
var DebugHeaders = false

// headersMatch reports whether in and cand carry the same headers. Any header
// whose name maps to true in ignores is skipped. Values are compared with
// reflect.DeepEqual, so the order and number of values for a header matter.
// When DebugHeaders is set, the first difference encountered is logged.
func headersMatch(in, cand http.Header, ignores map[string]bool) bool {
	debugf := func(format string, args ...interface{}) {
		if DebugHeaders {
			log.Printf(format, args...)
		}
	}
	for name, inVals := range in {
		if ignores[name] {
			continue
		}
		candVals := cand[name]
		switch {
		case candVals == nil:
			debugf("header %s: present in incoming request but not candidate", name)
			return false
		case !reflect.DeepEqual(inVals, candVals):
			debugf("header %s: incoming %v, candidate %v", name, inVals, candVals)
			return false
		}
	}
	for name := range cand {
		if !ignores[name] && in[name] == nil {
			debugf("header %s: not in incoming request but present in candidate", name)
			return false
		}
	}
	return true
}
316