1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package webtest
6
7import (
8	"bufio"
9	"bytes"
10	"encoding/hex"
11	"flag"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"log"
16	"net/http"
17	"os"
18	"sort"
19	"strconv"
20	"strings"
21	"sync"
22	"unicode/utf8"
23
24	web "cmd/go/internal/web2"
25)
26
// mode holds the value of the -webtest flag. As used in this file:
// "bypass" makes Hook a no-op, "record" makes Do fetch and save any
// response it has no canned entry for, and every other value (the default
// is "replay") makes Do answer such requests with a synthetic 599 response.
var mode = flag.String("webtest", "replay", "set webtest `mode` - record, replay, bypass")
28
29func Hook() {
30	if *mode == "bypass" {
31		return
32	}
33	web.SetHTTPDoForTesting(Do)
34}
35
// Unhook passes nil to web.SetHTTPDoForTesting, clearing whatever function
// Hook or Print installed (presumably restoring the web package's normal
// HTTP behavior; confirm against web.SetHTTPDoForTesting).
func Unhook() {
	web.SetHTTPDoForTesting(nil)
}
39
// Print passes DoPrint to web.SetHTTPDoForTesting, so that requests are
// performed for real and a transcript of each response is written to
// standard error. Unlike Hook, it does not consult the -webtest flag.
func Print() {
	web.SetHTTPDoForTesting(DoPrint)
}
43
// responses is the table of canned responses registered by Serve and
// consulted by Do.
var responses struct {
	mu    sync.Mutex            // guards byURL
	byURL map[string]*respEntry // keyed by the URL string passed to Serve; created lazily by Serve
}
48
// respEntry is one canned HTTP response.
type respEntry struct {
	status string      // full status line, such as "200 OK"
	code   int         // numeric status code, the first field of status
	hdr    http.Header // response headers
	body   []byte      // response body
}
55
56func Serve(url string, status string, hdr http.Header, body []byte) {
57	if status == "" {
58		status = "200 OK"
59	}
60	code, err := strconv.Atoi(strings.Fields(status)[0])
61	if err != nil {
62		panic("bad Serve status - " + status + " - " + err.Error())
63	}
64
65	responses.mu.Lock()
66	defer responses.mu.Unlock()
67
68	if responses.byURL == nil {
69		responses.byURL = make(map[string]*respEntry)
70	}
71	responses.byURL[url] = &respEntry{status: status, code: code, hdr: web.CopyHeader(hdr), body: body}
72}
73
74func Do(req *http.Request) (*http.Response, error) {
75	if req.Method != "GET" {
76		return nil, fmt.Errorf("bad method - must be GET")
77	}
78
79	responses.mu.Lock()
80	e := responses.byURL[req.URL.String()]
81	responses.mu.Unlock()
82
83	if e == nil {
84		if *mode == "record" {
85			loaded.mu.Lock()
86			if len(loaded.did) != 1 {
87				loaded.mu.Unlock()
88				return nil, fmt.Errorf("cannot use -webtest=record with multiple loaded response files")
89			}
90			var file string
91			for file = range loaded.did {
92				break
93			}
94			loaded.mu.Unlock()
95			return doSave(file, req)
96		}
97		e = &respEntry{code: 599, status: "599 unexpected request (no canned response)"}
98	}
99	resp := &http.Response{
100		Status:     e.status,
101		StatusCode: e.code,
102		Header:     web.CopyHeader(e.hdr),
103		Body:       ioutil.NopCloser(bytes.NewReader(e.body)),
104	}
105	return resp, nil
106}
107
108func DoPrint(req *http.Request) (*http.Response, error) {
109	return doSave("", req)
110}
111
112func doSave(file string, req *http.Request) (*http.Response, error) {
113	resp, err := http.DefaultClient.Do(req)
114	if err != nil {
115		return nil, err
116	}
117	data, err := ioutil.ReadAll(resp.Body)
118	resp.Body.Close()
119	if err != nil {
120		return nil, err
121	}
122	resp.Body = ioutil.NopCloser(bytes.NewReader(data))
123
124	var f *os.File
125	if file == "" {
126		f = os.Stderr
127	} else {
128		f, err = os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
129		if err != nil {
130			log.Fatal(err)
131		}
132		defer f.Close()
133	}
134
135	fmt.Fprintf(f, "GET %s\n", req.URL.String())
136	fmt.Fprintf(f, "%s\n", resp.Status)
137	var keys []string
138	for k := range resp.Header {
139		keys = append(keys, k)
140	}
141	sort.Strings(keys)
142	for _, k := range keys {
143		if k == "Set-Cookie" {
144			continue
145		}
146		for _, v := range resp.Header[k] {
147			fmt.Fprintf(f, "%s: %s\n", k, v)
148		}
149	}
150	fmt.Fprintf(f, "\n")
151	if utf8.Valid(data) && !bytes.Contains(data, []byte("\nGET")) && !isHexDump(data) {
152		fmt.Fprintf(f, "%s\n\n", data)
153	} else {
154		fmt.Fprintf(f, "%s\n", hex.Dump(data))
155	}
156	return resp, err
157}
158
// loaded records which response files have been passed to LoadOnce.
// Do also reads it in record mode to choose the file to append to.
var loaded struct {
	mu  sync.Mutex      // guards did
	did map[string]bool // set of file names already handed to LoadOnce; created lazily
}
163
164func LoadOnce(file string) {
165	loaded.mu.Lock()
166	if loaded.did[file] {
167		loaded.mu.Unlock()
168		return
169	}
170	if loaded.did == nil {
171		loaded.did = make(map[string]bool)
172	}
173	loaded.did[file] = true
174	loaded.mu.Unlock()
175
176	f, err := os.Open(file)
177	if err != nil {
178		log.Fatal(err)
179	}
180	defer f.Close()
181
182	b := bufio.NewReader(f)
183	var ungetLine string
184	nextLine := func() string {
185		if ungetLine != "" {
186			l := ungetLine
187			ungetLine = ""
188			return l
189		}
190		line, err := b.ReadString('\n')
191		if err != nil {
192			if err == io.EOF {
193				return ""
194			}
195			log.Fatalf("%s: unexpected read error: %v", file, err)
196		}
197		return line
198	}
199
200	for {
201		line := nextLine()
202		if line == "" { // EOF
203			break
204		}
205		line = strings.TrimSpace(line)
206		if strings.HasPrefix(line, "#") || line == "" {
207			continue
208		}
209		if !strings.HasPrefix(line, "GET ") {
210			log.Fatalf("%s: malformed GET line: %s", file, line)
211		}
212		url := line[len("GET "):]
213		status := nextLine()
214		if _, err := strconv.Atoi(strings.Fields(status)[0]); err != nil {
215			log.Fatalf("%s: malformed status line (after GET %s): %s", file, url, status)
216		}
217		hdr := make(http.Header)
218		for {
219			kv := strings.TrimSpace(nextLine())
220			if kv == "" {
221				break
222			}
223			i := strings.Index(kv, ":")
224			if i < 0 {
225				log.Fatalf("%s: malformed header line (after GET %s): %s", file, url, kv)
226			}
227			k, v := kv[:i], strings.TrimSpace(kv[i+1:])
228			hdr[k] = append(hdr[k], v)
229		}
230
231		var body []byte
232	Body:
233		for n := 0; ; n++ {
234			line := nextLine()
235			if n == 0 && isHexDump([]byte(line)) {
236				ungetLine = line
237				b, err := parseHexDump(nextLine)
238				if err != nil {
239					log.Fatalf("%s: malformed hex dump (after GET %s): %v", file, url, err)
240				}
241				body = b
242				break
243			}
244			if line == "" { // EOF
245				for i := 0; i < 2; i++ {
246					if len(body) > 0 && body[len(body)-1] == '\n' {
247						body = body[:len(body)-1]
248					}
249				}
250				break
251			}
252			body = append(body, line...)
253			for line == "\n" {
254				line = nextLine()
255				if strings.HasPrefix(line, "GET ") {
256					ungetLine = line
257					body = body[:len(body)-1]
258					if len(body) > 0 {
259						body = body[:len(body)-1]
260					}
261					break Body
262				}
263				body = append(body, line...)
264			}
265		}
266
267		Serve(url, status, hdr, body)
268	}
269}
270
// isHexDump reports whether data begins like the first line of a hex dump:
// an all-zero address followed by the separator used by "hexdump -C" and
// hex.Dump ("00000000  ") or the shorter form ("0000000 ").
func isHexDump(data []byte) bool {
	for _, start := range [...]string{"00000000  ", "0000000 "} {
		if bytes.HasPrefix(data, []byte(start)) {
			return true
		}
	}
	return false
}
274
// parseHexDump decodes a hex dump, read one line at a time from nextLine,
// back into the bytes it describes. The accepted input is the output of
// "hexdump -C", Plan 9's "xd -b", or Go's hex.Dump. Storing binary golden
// data this way keeps changes to it visible in code review.
//
// Reading stops at the first empty string (end of input) or bare newline
// returned by nextLine. On each line everything from the first '|' onward
// (the text column) is ignored, the first field is the hexadecimal address
// and the remaining fields, at most 16, are hexadecimal bytes. Blank lines
// and a lone "*" (elided zero block) are skipped; if a line's address lies
// beyond the data decoded so far, the gap is filled with zero bytes.
func parseHexDump(nextLine func() string) ([]byte, error) {
	var data []byte
	for line := nextLine(); line != "" && line != "\n"; line = nextLine() {
		// Strip the text column, if present.
		if bar := strings.Index(line, "|"); bar >= 0 {
			line = line[:bar]
		}

		fields := strings.Fields(line)
		switch {
		case len(fields) > 1+16:
			return nil, fmt.Errorf("parsing hex dump: too many fields on line %q", line)
		case len(fields) == 0, len(fields) == 1 && fields[0] == "*":
			// Nothing to decode; an omitted all-zero block is filled in
			// when the next address is seen.
			continue
		}

		rawAddr, rawBytes := fields[0], fields[1:]
		parsed, err := strconv.ParseUint(rawAddr, 16, 0)
		if err != nil {
			return nil, fmt.Errorf("parsing hex dump: invalid address %q", rawAddr)
		}
		if addr := int(parsed); addr > len(data) {
			data = append(data, make([]byte, addr-len(data))...)
		}

		for _, digits := range rawBytes {
			v, err := strconv.ParseUint(digits, 16, 8)
			if err != nil {
				return nil, fmt.Errorf("parsing hexdump: invalid hex byte %q", digits)
			}
			data = append(data, byte(v))
		}
	}
	return data, nil
}
315