1package webapp
2
import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"testing"
)
11
12func TestFlow_BrowserURL(t *testing.T) {
13	server := &localServer{
14		listener: &fakeListener{
15			addr: &net.TCPAddr{Port: 12345},
16		},
17	}
18
19	type fields struct {
20		server   *localServer
21		clientID string
22		state    string
23	}
24	type args struct {
25		baseURL string
26		params  BrowserParams
27	}
28	tests := []struct {
29		name    string
30		fields  fields
31		args    args
32		want    string
33		wantErr bool
34	}{
35		{
36			name: "happy path",
37			fields: fields{
38				server: server,
39				state:  "xy/z",
40			},
41			args: args{
42				baseURL: "https://github.com/authorize",
43				params: BrowserParams{
44					ClientID:    "CLIENT-ID",
45					RedirectURI: "http://127.0.0.1/hello",
46					Scopes:      []string{"repo", "read:org"},
47					AllowSignup: true,
48				},
49			},
50			want:    "https://github.com/authorize?client_id=CLIENT-ID&redirect_uri=http%3A%2F%2F127.0.0.1%3A12345%2Fhello&scope=repo+read%3Aorg&state=xy%2Fz",
51			wantErr: false,
52		},
53	}
54	for _, tt := range tests {
55		t.Run(tt.name, func(t *testing.T) {
56			flow := &Flow{
57				server:   tt.fields.server,
58				clientID: tt.fields.clientID,
59				state:    tt.fields.state,
60			}
61			got, err := flow.BrowserURL(tt.args.baseURL, tt.args.params)
62			if (err != nil) != tt.wantErr {
63				t.Errorf("Flow.BrowserURL() error = %v, wantErr %v", err, tt.wantErr)
64				return
65			}
66			if got != tt.want {
67				t.Errorf("Flow.BrowserURL() = %v, want %v", got, tt.want)
68			}
69		})
70	}
71}
72
// apiStub describes one canned HTTP response served by apiClient.
type apiStub struct {
	status      int    // HTTP status code reported by the response
	body        string // response body, served verbatim
	contentType string // value of the Content-Type response header
}

// postArgs records the arguments of a single PostForm call.
type postArgs struct {
	url    string
	params url.Values
}

// apiClient is a test double for an HTTP client that exposes PostForm. It
// replays stubs in order, one per call, and records every call it receives so
// tests can assert on the requests that were made.
type apiClient struct {
	stubs []apiStub
	calls []postArgs

	// postCount is the number of PostForm calls received so far; it is also
	// the index of the stub that the next call will consume.
	postCount int
}

// PostForm records the call and answers it with the next unused stub.
//
// The returned response carries the stub's status code, body and Content-Type
// header. If more calls are made than stubs were configured, the call is still
// recorded and a descriptive error is returned instead of panicking with an
// index-out-of-range, so the failure surfaces through the code under test.
func (c *apiClient) PostForm(u string, params url.Values) (*http.Response, error) {
	idx := c.postCount
	c.postCount++
	c.calls = append(c.calls, postArgs{url: u, params: params})

	if idx >= len(c.stubs) {
		return nil, fmt.Errorf("apiClient: unexpected POST #%d to %q: only %d stub(s) configured", idx+1, u, len(c.stubs))
	}

	stub := c.stubs[idx]
	return &http.Response{
		Body: ioutil.NopCloser(bytes.NewBufferString(stub.body)),
		Header: http.Header{
			"Content-Type": {stub.contentType},
		},
		StatusCode: stub.status,
	}, nil
}
103
104func TestFlow_AccessToken(t *testing.T) {
105	server := &localServer{
106		listener: &fakeListener{
107			addr: &net.TCPAddr{Port: 12345},
108		},
109		resultChan: make(chan CodeResponse),
110	}
111
112	flow := Flow{
113		server:   server,
114		clientID: "CLIENT-ID",
115		state:    "xy/z",
116	}
117
118	client := &apiClient{
119		stubs: []apiStub{
120			{
121				body:        "access_token=ATOKEN&token_type=bearer&scope=repo+gist",
122				status:      200,
123				contentType: "application/x-www-form-urlencoded; charset=utf-8",
124			},
125		},
126	}
127
128	go func() {
129		server.resultChan <- CodeResponse{
130			Code:  "ABC-123",
131			State: "xy/z",
132		}
133	}()
134
135	token, err := flow.AccessToken(client, "https://github.com/access_token", "OAUTH-SEKRIT")
136	if err != nil {
137		t.Fatalf("AccessToken() error: %v", err)
138	}
139
140	if len(client.calls) != 1 {
141		t.Fatalf("expected 1 HTTP POST, got %d", len(client.calls))
142	}
143	apiPost := client.calls[0]
144	if apiPost.url != "https://github.com/access_token" {
145		t.Errorf("HTTP POST to %q", apiPost.url)
146	}
147	if params := apiPost.params.Encode(); params != "client_id=CLIENT-ID&client_secret=OAUTH-SEKRIT&code=ABC-123&state=xy%2Fz" {
148		t.Errorf("HTTP POST params: %v", params)
149	}
150
151	if token.Token != "ATOKEN" {
152		t.Errorf("Token = %q", token.Token)
153	}
154}
155