1// Copyright 2018 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package proxy 16 17import ( 18 "bytes" 19 "encoding/json" 20 "errors" 21 "fmt" 22 "io/ioutil" 23 "log" 24 "net/http" 25 "reflect" 26 "sync" 27 28 "github.com/google/martian/martianlog" 29) 30 31// ForReplaying returns a Proxy configured to replay. 32func ForReplaying(filename string, port int) (*Proxy, error) { 33 p, err := newProxy(filename) 34 if err != nil { 35 return nil, err 36 } 37 lg, err := readLog(filename) 38 if err != nil { 39 return nil, err 40 } 41 calls, err := constructCalls(lg) 42 if err != nil { 43 return nil, err 44 } 45 p.Initial = lg.Initial 46 p.mproxy.SetRoundTripper(&replayRoundTripper{ 47 calls: calls, 48 ignoreHeaders: p.ignoreHeaders, 49 conv: lg.Converter, 50 }) 51 52 // Debug logging. 53 // TODO(jba): factor out from here and ForRecording. 54 logger := martianlog.NewLogger() 55 logger.SetDecode(true) 56 p.mproxy.SetRequestModifier(logger) 57 p.mproxy.SetResponseModifier(logger) 58 59 if err := p.start(port); err != nil { 60 return nil, err 61 } 62 return p, nil 63} 64 65func readLog(filename string) (*Log, error) { 66 bytes, err := ioutil.ReadFile(filename) 67 if err != nil { 68 return nil, err 69 } 70 var lg Log 71 if err := json.Unmarshal(bytes, &lg); err != nil { 72 return nil, fmt.Errorf("%s: %v", filename, err) 73 } 74 if lg.Version != LogVersion { 75 return nil, fmt.Errorf( 76 "httpreplay: read log version %s but current version is %s; re-record the log", 77 lg.Version, LogVersion) 78 } 79 return &lg, nil 80} 81 82// A call is an HTTP request and its matching response. 83type call struct { 84 req *Request 85 res *Response 86} 87 88func constructCalls(lg *Log) ([]*call, error) { 89 ignoreIDs := map[string]bool{} // IDs of requests to ignore 90 callsByID := map[string]*call{} 91 var calls []*call 92 for _, e := range lg.Entries { 93 if ignoreIDs[e.ID] { 94 continue 95 } 96 c, ok := callsByID[e.ID] 97 switch { 98 case !ok: 99 if e.Request == nil { 100 return nil, fmt.Errorf("first entry for ID %s does not have a request", e.ID) 101 } 102 if e.Request.Method == "CONNECT" { 103 // Ignore CONNECT methods. 104 ignoreIDs[e.ID] = true 105 } else { 106 c := &call{e.Request, e.Response} 107 calls = append(calls, c) 108 callsByID[e.ID] = c 109 } 110 case e.Request != nil: 111 if e.Response != nil { 112 return nil, errors.New("entry has both request and response") 113 } 114 c.req = e.Request 115 case e.Response != nil: 116 c.res = e.Response 117 default: 118 return nil, errors.New("entry has neither request nor response") 119 } 120 } 121 for _, c := range calls { 122 if c.req == nil || c.res == nil { 123 return nil, fmt.Errorf("missing request or response: %+v", c) 124 } 125 } 126 return calls, nil 127} 128 129type replayRoundTripper struct { 130 mu sync.Mutex 131 calls []*call 132 ignoreHeaders map[string]bool 133 conv *Converter 134} 135 136func (r *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 137 if req.Body != nil { 138 defer req.Body.Close() 139 } 140 creq, err := r.conv.convertRequest(req) 141 if err != nil { 142 return nil, err 143 } 144 r.mu.Lock() 145 defer r.mu.Unlock() 146 for i, call := range r.calls { 147 if call == nil { 148 continue 149 } 150 if requestsMatch(creq, call.req, r.ignoreHeaders) { 151 r.calls[i] = nil // nil out this call so we don't reuse it 152 return toHTTPResponse(call.res, req), nil 153 } 154 } 155 return nil, fmt.Errorf("no matching request for %+v", req) 156} 157 158// Report whether the incoming request in matches the candidate request cand. 159func requestsMatch(in, cand *Request, ignoreHeaders map[string]bool) bool { 160 if in.Method != cand.Method { 161 return false 162 } 163 if in.URL != cand.URL { 164 return false 165 } 166 if in.MediaType != cand.MediaType { 167 return false 168 } 169 if len(in.BodyParts) != len(cand.BodyParts) { 170 return false 171 } 172 for i, p1 := range in.BodyParts { 173 if !bytes.Equal(p1, cand.BodyParts[i]) { 174 return false 175 } 176 } 177 // Check headers last. See DebugHeaders. 178 return headersMatch(in.Header, cand.Header, ignoreHeaders) 179} 180 181// DebugHeaders helps to determine whether a header should be ignored. 182// When true, if requests have the same method, URL and body but differ 183// in a header, the first mismatched header is logged. 184var DebugHeaders = false 185 186func headersMatch(in, cand http.Header, ignores map[string]bool) bool { 187 for k1, v1 := range in { 188 if ignores[k1] { 189 continue 190 } 191 v2 := cand[k1] 192 if v2 == nil { 193 if DebugHeaders { 194 log.Printf("header %s: present in incoming request but not candidate", k1) 195 } 196 return false 197 } 198 if !reflect.DeepEqual(v1, v2) { 199 if DebugHeaders { 200 log.Printf("header %s: incoming %v, candidate %v", k1, v1, v2) 201 } 202 return false 203 } 204 } 205 for k2 := range cand { 206 if ignores[k2] { 207 continue 208 } 209 if in[k2] == nil { 210 if DebugHeaders { 211 log.Printf("header %s: not in incoming request but present in candidate", k2) 212 } 213 return false 214 } 215 } 216 return true 217} 218