1package oauth2
2
3import (
4	"errors"
5	"io"
6	"net/http"
7	"net/http/httptest"
8	"testing"
9	"time"
10)
11
12type tokenSource struct{ token *Token }
13
14func (t *tokenSource) Token() (*Token, error) {
15	return t.token, nil
16}
17
// TestTransportNilTokenSource checks that a request sent through a Transport
// with no Source yields an error and no response.
func TestTransportNilTokenSource(t *testing.T) {
	tr := &Transport{}
	server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
	defer server.Close()
	client := &http.Client{Transport: tr}
	resp, err := client.Get(server.URL)
	if err == nil {
		t.Errorf("got no errors, want an error with nil token source")
	}
	if resp != nil {
		t.Errorf("Response = %v; want nil", resp)
		// The response is unexpected, but its body must still be closed so the
		// failing test does not leak the underlying connection.
		resp.Body.Close()
	}
}
31
// readCloseCounter is a request-body stub with Read and Close methods. It never
// yields any bytes and records how many times it has been closed.
type readCloseCounter struct {
	CloseCount int   // number of Close calls observed so far
	ReadErr    error // returned verbatim by every Read call
}

// Read leaves the buffer untouched and reports zero bytes together with
// r.ReadErr.
func (r *readCloseCounter) Read(_ []byte) (int, error) {
	err := r.ReadErr
	return 0, err
}

// Close increments CloseCount and always succeeds.
func (r *readCloseCounter) Close() error {
	r.CloseCount++
	return nil
}
45
// TestTransportCloseRequestBody checks that when a request fails because the
// Transport has no Source, the request body is still closed exactly once.
func TestTransportCloseRequestBody(t *testing.T) {
	tr := &Transport{}
	server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
	defer server.Close()
	client := &http.Client{Transport: tr}
	body := &readCloseCounter{
		ReadErr: errors.New("readCloseCounter.Read not implemented"),
	}
	resp, err := client.Post(server.URL, "application/json", body)
	if err == nil {
		t.Errorf("got no errors, want an error with nil token source")
	}
	if resp != nil {
		t.Errorf("Response = %v; want nil", resp)
		// Close the unexpected response body so a failure does not leak the
		// connection. This does not touch the request body counted below.
		resp.Body.Close()
	}
	if expected := 1; body.CloseCount != expected {
		t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
	}
}
65
// TestTransportCloseRequestBodySuccess checks that a successful request through
// a Transport with a static token closes the request body exactly once.
func TestTransportCloseRequestBodySuccess(t *testing.T) {
	tr := &Transport{
		Source: StaticTokenSource(&Token{
			AccessToken: "abc",
		}),
	}
	server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
	defer server.Close()
	client := &http.Client{Transport: tr}
	body := &readCloseCounter{
		ReadErr: io.EOF,
	}
	resp, err := client.Post(server.URL, "application/json", body)
	if err != nil {
		t.Errorf("got error %v; expected none", err)
	}
	if resp == nil {
		t.Errorf("Response is nil; expected non-nil")
	} else {
		// The response body belongs to the caller and must be closed, or the
		// connection leaks. This is independent of the request body counted below.
		resp.Body.Close()
	}
	if expected := 1; body.CloseCount != expected {
		t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
	}
}
89
90func TestTransportTokenSource(t *testing.T) {
91	ts := &tokenSource{
92		token: &Token{
93			AccessToken: "abc",
94		},
95	}
96	tr := &Transport{
97		Source: ts,
98	}
99	server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
100		if got, want := r.Header.Get("Authorization"), "Bearer abc"; got != want {
101			t.Errorf("Authorization header = %q; want %q", got, want)
102		}
103	})
104	defer server.Close()
105	client := &http.Client{Transport: tr}
106	res, err := client.Get(server.URL)
107	if err != nil {
108		t.Fatal(err)
109	}
110	res.Body.Close()
111}
112
// TestTransportTokenSourceTypes checks that lower-case TokenType values are
// sent with their canonical capitalization in the Authorization header.
// See https://github.com/golang/oauth2/issues/113.
func TestTransportTokenSourceTypes(t *testing.T) {
	const val = "abc"
	tests := []struct {
		key  string
		val  string
		want string
	}{
		{key: "bearer", val: val, want: "Bearer abc"},
		{key: "mac", val: val, want: "MAC abc"},
		{key: "basic", val: val, want: "Basic abc"},
	}
	for _, tc := range tests {
		// Each case runs as its own subtest. Its deferred server.Close then fires
		// when the case finishes, instead of every server staying up until the
		// outer test returns. A failure in one case no longer aborts the rest.
		// The handler closure captures tc safely because t.Run (without
		// t.Parallel) finishes before the loop advances.
		t.Run(tc.key, func(t *testing.T) {
			ts := &tokenSource{
				token: &Token{
					AccessToken: tc.val,
					TokenType:   tc.key,
				},
			}
			tr := &Transport{
				Source: ts,
			}
			server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
				if got, want := r.Header.Get("Authorization"), tc.want; got != want {
					// Report the token type under test; the access token is
					// identical in every case and would not identify the failure.
					t.Errorf("Authorization header (%q) = %q; want %q", tc.key, got, want)
				}
			})
			defer server.Close()
			client := &http.Client{Transport: tr}
			res, err := client.Get(server.URL)
			if err != nil {
				t.Fatal(err)
			}
			res.Body.Close()
		})
	}
}
149
// TestTokenValidNoAccessToken checks that the zero Token, which carries no
// access token, is not reported as valid.
func TestTokenValidNoAccessToken(t *testing.T) {
	if valid := (&Token{}).Valid(); valid {
		t.Errorf("got valid with no access token; want invalid")
	}
}
156
// TestExpiredWithExpiry checks that a token whose Expiry lies in the past is
// invalid even though it carries an access token.
func TestExpiredWithExpiry(t *testing.T) {
	token := &Token{
		// A non-empty access token is required here. Without it the token is
		// already invalid for that reason alone (see TestTokenValidNoAccessToken),
		// and the expiry check would never be exercised.
		AccessToken: "abc",
		Expiry:      time.Now().Add(-5 * time.Hour),
	}
	if token.Valid() {
		t.Errorf("got valid with expired token; want invalid")
	}
}
165
166func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server {
167	return httptest.NewServer(http.HandlerFunc(handler))
168}
169