1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// authtest is a diagnostic tool for implementations of the GOAUTH protocol
6// described in https://golang.org/issue/26232.
7//
8// It accepts a single URL as an argument, and executes the GOAUTH protocol to
9// fetch and display the headers for that URL.
10//
11// CAUTION: authtest logs the GOAUTH responses, which may include user
12// credentials, to stderr. Do not post its output unless you are certain that
13// all of the credentials involved are fake!
14package main
15
16import (
17	"bufio"
18	"bytes"
19	"flag"
20	"fmt"
21	exec "golang.org/x/sys/execabs"
22	"io"
23	"log"
24	"net/http"
25	"net/textproto"
26	"net/url"
27	"os"
28	"path/filepath"
29	"strings"
30)
31
// v holds the value of the -v command-line flag. According to its help text,
// setting it asks for GOAUTH responses to be logged to stderr.
//
// NOTE(review): nothing in this file appears to read *v; getCreds echoes each
// GOAUTH response to stderr regardless of this flag.
var (
	v = flag.Bool("v", false, "if true, log GOAUTH responses to stderr")
)
33
34func main() {
35	log.SetFlags(log.LstdFlags | log.Lshortfile)
36	flag.Parse()
37	args := flag.Args()
38	if len(args) != 1 {
39		log.Fatalf("usage: [GOAUTH=CMD...] %s URL", filepath.Base(os.Args[0]))
40	}
41
42	resp := try(args[0], nil)
43	if resp.StatusCode == http.StatusOK {
44		return
45	}
46
47	resp = try(args[0], resp)
48	if resp.StatusCode != http.StatusOK {
49		os.Exit(1)
50	}
51}
52
53func try(url string, prev *http.Response) *http.Response {
54	req := new(http.Request)
55	if prev != nil {
56		*req = *prev.Request
57	} else {
58		var err error
59		req, err = http.NewRequest("HEAD", os.Args[1], nil)
60		if err != nil {
61			log.Fatal(err)
62		}
63	}
64
65goauth:
66	for _, argList := range strings.Split(os.Getenv("GOAUTH"), ";") {
67		// TODO(golang.org/issue/26849): If we escape quoted strings in GOFLAGS, use
68		// the same quoting here.
69		args := strings.Split(argList, " ")
70		if len(args) == 0 || args[0] == "" {
71			log.Fatalf("invalid or empty command in GOAUTH")
72		}
73
74		creds, err := getCreds(args, prev)
75		if err != nil {
76			log.Fatal(err)
77		}
78		for _, c := range creds {
79			if c.Apply(req) {
80				fmt.Fprintf(os.Stderr, "# request to %s\n", req.URL)
81				fmt.Fprintf(os.Stderr, "%s %s %s\n", req.Method, req.URL, req.Proto)
82				req.Header.Write(os.Stderr)
83				fmt.Fprintln(os.Stderr)
84				break goauth
85			}
86		}
87	}
88
89	resp, err := http.DefaultClient.Do(req)
90	if err != nil {
91		log.Fatal(err)
92	}
93	defer resp.Body.Close()
94
95	if resp.StatusCode != http.StatusOK && resp.StatusCode < 400 || resp.StatusCode > 500 {
96		log.Fatalf("unexpected status: %v", resp.Status)
97	}
98
99	fmt.Fprintf(os.Stderr, "# response from %s\n", resp.Request.URL)
100	formatHead(os.Stderr, resp)
101	return resp
102}
103
// formatHead writes the head of resp to out. The head consists of a
// "PROTO STATUS" line, the header fields in wire format, and a terminating
// empty line. A failure to write the header fields is fatal.
func formatHead(out io.Writer, resp *http.Response) {
	io.WriteString(out, resp.Proto+" "+resp.Status+"\n")

	err := resp.Header.Write(out)
	if err != nil {
		log.Fatal(err)
	}

	io.WriteString(out, "\n")
}
111
// Cred is a set of HTTP header fields to be applied to requests whose URLs
// fall under one of its URL prefixes.
type Cred struct {
	// URLPrefixes lists the URLs under which Header applies. getCreds only
	// produces https URLs with no query or fragment.
	URLPrefixes []*url.URL

	// Header holds the fields to set on matching requests.
	Header http.Header
}

// Apply reports whether c applies to req. If it does, Apply also replaces any
// existing values in req.Header, for each of c's header fields, with c's
// values.
//
// c applies when req.URL has the same scheme and host as one of c's prefixes,
// and its path either equals the prefix path or lies beneath it. "Beneath"
// means the prefix path ends in "/" or the next character of the request path
// is "/", so a prefix of "/foo" matches "/foo/bar" but not "/foobar".
func (c Cred) Apply(req *http.Request) bool {
	if req.URL == nil || !c.matches(req.URL) {
		return false
	}

	// Header.Add panics on a nil map, so make sure there is one.
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	for k, vs := range c.Header {
		req.Header.Del(k)
		for _, v := range vs {
			req.Header.Add(k, v)
		}
	}
	return true
}

// matches reports whether u falls under one of c's URL prefixes.
func (c Cred) matches(u *url.URL) bool {
	for _, prefix := range c.URLPrefixes {
		// Compare the scheme as well as the host. getCreds accepts only https
		// prefixes, and a credential must not be sent in cleartext to an
		// http:// URL on the same host.
		if !strings.EqualFold(prefix.Scheme, u.Scheme) || prefix.Host != u.Host {
			continue
		}
		if u.Path == prefix.Path {
			return true
		}
		// Here len(u.Path) > len(prefix.Path) whenever HasPrefix holds, so the
		// index is in range.
		if strings.HasPrefix(u.Path, prefix.Path) &&
			(strings.HasSuffix(prefix.Path, "/") || u.Path[len(prefix.Path)] == '/') {
			return true
		}
	}
	return false
}

// String formats c as its URL prefixes, one per line, followed by a blank
// line, the header fields in wire format, and a final blank line.
func (c Cred) String() string {
	var buf strings.Builder
	for _, u := range c.URLPrefixes {
		fmt.Fprintln(&buf, u)
	}
	buf.WriteString("\n")
	c.Header.Write(&buf) // writes to a strings.Builder cannot fail
	buf.WriteString("\n")
	return buf.String()
}
155
156func getCreds(args []string, resp *http.Response) ([]Cred, error) {
157	cmd := exec.Command(args[0], args[1:]...)
158	cmd.Stderr = os.Stderr
159
160	if resp != nil {
161		u := *resp.Request.URL
162		u.RawQuery = ""
163		cmd.Args = append(cmd.Args, u.String())
164	}
165
166	var head strings.Builder
167	if resp != nil {
168		formatHead(&head, resp)
169	}
170	cmd.Stdin = strings.NewReader(head.String())
171
172	fmt.Fprintf(os.Stderr, "# %s\n", strings.Join(cmd.Args, " "))
173	out, err := cmd.Output()
174	if err != nil {
175		return nil, fmt.Errorf("%s: %v", strings.Join(cmd.Args, " "), err)
176	}
177	os.Stderr.Write(out)
178	os.Stderr.WriteString("\n")
179
180	var creds []Cred
181	r := textproto.NewReader(bufio.NewReader(bytes.NewReader(out)))
182	line := 0
183readLoop:
184	for {
185		var prefixes []*url.URL
186		for {
187			prefix, err := r.ReadLine()
188			if err == io.EOF {
189				if len(prefixes) > 0 {
190					return nil, fmt.Errorf("line %d: %v", line, io.ErrUnexpectedEOF)
191				}
192				break readLoop
193			}
194			line++
195
196			if prefix == "" {
197				if len(prefixes) == 0 {
198					return nil, fmt.Errorf("line %d: unexpected newline", line)
199				}
200				break
201			}
202			u, err := url.Parse(prefix)
203			if err != nil {
204				return nil, fmt.Errorf("line %d: malformed URL: %v", line, err)
205			}
206			if u.Scheme != "https" {
207				return nil, fmt.Errorf("line %d: non-HTTPS URL %q", line, prefix)
208			}
209			if len(u.RawQuery) > 0 {
210				return nil, fmt.Errorf("line %d: unexpected query string in URL %q", line, prefix)
211			}
212			if len(u.Fragment) > 0 {
213				return nil, fmt.Errorf("line %d: unexpected fragment in URL %q", line, prefix)
214			}
215			prefixes = append(prefixes, u)
216		}
217
218		header, err := r.ReadMIMEHeader()
219		if err != nil {
220			return nil, fmt.Errorf("headers at line %d: %v", line, err)
221		}
222		if len(header) > 0 {
223			creds = append(creds, Cred{
224				URLPrefixes: prefixes,
225				Header:      http.Header(header),
226			})
227		}
228	}
229
230	return creds, nil
231}
232