1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package proxy
16
17import (
18	"bytes"
19	"encoding/json"
20	"errors"
21	"fmt"
22	"io/ioutil"
23	"log"
24	"net/http"
25	"reflect"
26	"sync"
27
28	"github.com/google/martian/martianlog"
29)
30
31// ForReplaying returns a Proxy configured to replay.
32func ForReplaying(filename string, port int) (*Proxy, error) {
33	p, err := newProxy(filename)
34	if err != nil {
35		return nil, err
36	}
37	lg, err := readLog(filename)
38	if err != nil {
39		return nil, err
40	}
41	calls, err := constructCalls(lg)
42	if err != nil {
43		return nil, err
44	}
45	p.Initial = lg.Initial
46	p.mproxy.SetRoundTripper(&replayRoundTripper{
47		calls:         calls,
48		ignoreHeaders: p.ignoreHeaders,
49		conv:          lg.Converter,
50	})
51
52	// Debug logging.
53	// TODO(jba): factor out from here and ForRecording.
54	logger := martianlog.NewLogger()
55	logger.SetDecode(true)
56	p.mproxy.SetRequestModifier(logger)
57	p.mproxy.SetResponseModifier(logger)
58
59	if err := p.start(port); err != nil {
60		return nil, err
61	}
62	return p, nil
63}
64
65func readLog(filename string) (*Log, error) {
66	bytes, err := ioutil.ReadFile(filename)
67	if err != nil {
68		return nil, err
69	}
70	var lg Log
71	if err := json.Unmarshal(bytes, &lg); err != nil {
72		return nil, fmt.Errorf("%s: %v", filename, err)
73	}
74	if lg.Version != LogVersion {
75		return nil, fmt.Errorf(
76			"httpreplay: read log version %s but current version is %s; re-record the log",
77			lg.Version, LogVersion)
78	}
79	return &lg, nil
80}
81
82// A call is an HTTP request and its matching response.
83type call struct {
84	req *Request
85	res *Response
86}
87
88func constructCalls(lg *Log) ([]*call, error) {
89	ignoreIDs := map[string]bool{} // IDs of requests to ignore
90	callsByID := map[string]*call{}
91	var calls []*call
92	for _, e := range lg.Entries {
93		if ignoreIDs[e.ID] {
94			continue
95		}
96		c, ok := callsByID[e.ID]
97		switch {
98		case !ok:
99			if e.Request == nil {
100				return nil, fmt.Errorf("first entry for ID %s does not have a request", e.ID)
101			}
102			if e.Request.Method == "CONNECT" {
103				// Ignore CONNECT methods.
104				ignoreIDs[e.ID] = true
105			} else {
106				c := &call{e.Request, e.Response}
107				calls = append(calls, c)
108				callsByID[e.ID] = c
109			}
110		case e.Request != nil:
111			if e.Response != nil {
112				return nil, errors.New("entry has both request and response")
113			}
114			c.req = e.Request
115		case e.Response != nil:
116			c.res = e.Response
117		default:
118			return nil, errors.New("entry has neither request nor response")
119		}
120	}
121	for _, c := range calls {
122		if c.req == nil || c.res == nil {
123			return nil, fmt.Errorf("missing request or response: %+v", c)
124		}
125	}
126	return calls, nil
127}
128
129type replayRoundTripper struct {
130	mu            sync.Mutex
131	calls         []*call
132	ignoreHeaders map[string]bool
133	conv          *Converter
134}
135
136func (r *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
137	if req.Body != nil {
138		defer req.Body.Close()
139	}
140	creq, err := r.conv.convertRequest(req)
141	if err != nil {
142		return nil, err
143	}
144	r.mu.Lock()
145	defer r.mu.Unlock()
146	for i, call := range r.calls {
147		if call == nil {
148			continue
149		}
150		if requestsMatch(creq, call.req, r.ignoreHeaders) {
151			r.calls[i] = nil // nil out this call so we don't reuse it
152			return toHTTPResponse(call.res, req), nil
153		}
154	}
155	return nil, fmt.Errorf("no matching request for %+v", req)
156}
157
158// Report whether the incoming request in matches the candidate request cand.
159func requestsMatch(in, cand *Request, ignoreHeaders map[string]bool) bool {
160	if in.Method != cand.Method {
161		return false
162	}
163	if in.URL != cand.URL {
164		return false
165	}
166	if in.MediaType != cand.MediaType {
167		return false
168	}
169	if len(in.BodyParts) != len(cand.BodyParts) {
170		return false
171	}
172	for i, p1 := range in.BodyParts {
173		if !bytes.Equal(p1, cand.BodyParts[i]) {
174			return false
175		}
176	}
177	// Check headers last. See DebugHeaders.
178	return headersMatch(in.Header, cand.Header, ignoreHeaders)
179}
180
// DebugHeaders helps to determine whether a header should be ignored.
// When true, if requests have the same method, URL and body but differ
// in a header, the first mismatched header is logged.
var DebugHeaders = false

// headersMatch reports whether the incoming headers in and the candidate
// headers cand agree on every header whose key is not set in ignores.
//
// Two header sets agree when they contain the same non-ignored keys and the
// value lists for each key are deeply equal. When DebugHeaders is true, the
// first mismatch encountered is logged; because map iteration order is
// unspecified, which mismatch that is may vary between runs.
func headersMatch(in, cand http.Header, ignores map[string]bool) bool {
	for k1, v1 := range in {
		if ignores[k1] {
			continue
		}
		// Test presence with the comma-ok form: a key that is present with a
		// nil value list must not be mistaken for a missing key.
		v2, ok := cand[k1]
		if !ok {
			if DebugHeaders {
				log.Printf("header %s: present in incoming request but not candidate", k1)
			}
			return false
		}
		if !reflect.DeepEqual(v1, v2) {
			if DebugHeaders {
				log.Printf("header %s: incoming %v, candidate %v", k1, v1, v2)
			}
			return false
		}
	}
	// Values of shared keys were compared above; only keys that exist solely
	// in the candidate remain to be detected.
	for k2 := range cand {
		if ignores[k2] {
			continue
		}
		if _, ok := in[k2]; !ok {
			if DebugHeaders {
				log.Printf("header %s: not in incoming request but present in candidate", k2)
			}
			return false
		}
	}
	return true
}
218