1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package oauth2
6
7import (
8	"context"
9	"errors"
10	"fmt"
11	"io"
12	"io/ioutil"
13	"net/http"
14	"net/http/httptest"
15	"net/url"
16	"testing"
17	"time"
18
19	"golang.org/x/oauth2/internal"
20)
21
// mockTransport is an http.RoundTripper whose behavior is supplied by the
// test: every RoundTrip call is handed straight to rt.
type mockTransport struct {
	// rt handles each outgoing request. RoundTrip panics if it is nil.
	rt func(*http.Request) (*http.Response, error)
}

// RoundTrip implements http.RoundTripper by delegating to m.rt.
func (m *mockTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	return m.rt(r)
}
29
30func newConf(url string) *Config {
31	return &Config{
32		ClientID:     "CLIENT_ID",
33		ClientSecret: "CLIENT_SECRET",
34		RedirectURL:  "REDIRECT_URL",
35		Scopes:       []string{"scope1", "scope2"},
36		Endpoint: Endpoint{
37			AuthURL:  url + "/auth",
38			TokenURL: url + "/token",
39		},
40	}
41}
42
43func TestAuthCodeURL(t *testing.T) {
44	conf := newConf("server")
45	url := conf.AuthCodeURL("foo", AccessTypeOffline, ApprovalForce)
46	const want = "server/auth?access_type=offline&client_id=CLIENT_ID&prompt=consent&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo"
47	if got := url; got != want {
48		t.Errorf("got auth code URL = %q; want %q", got, want)
49	}
50}
51
52func TestAuthCodeURL_CustomParam(t *testing.T) {
53	conf := newConf("server")
54	param := SetAuthURLParam("foo", "bar")
55	url := conf.AuthCodeURL("baz", param)
56	const want = "server/auth?client_id=CLIENT_ID&foo=bar&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=baz"
57	if got := url; got != want {
58		t.Errorf("got auth code = %q; want %q", got, want)
59	}
60}
61
62func TestAuthCodeURL_Optional(t *testing.T) {
63	conf := &Config{
64		ClientID: "CLIENT_ID",
65		Endpoint: Endpoint{
66			AuthURL:  "/auth-url",
67			TokenURL: "/token-url",
68		},
69	}
70	url := conf.AuthCodeURL("")
71	const want = "/auth-url?client_id=CLIENT_ID&response_type=code"
72	if got := url; got != want {
73		t.Fatalf("got auth code = %q; want %q", got, want)
74	}
75}
76
77func TestURLUnsafeClientConfig(t *testing.T) {
78	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
79		if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want {
80			t.Errorf("Authorization header = %q; want %q", got, want)
81		}
82
83		w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
84		w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
85	}))
86	defer ts.Close()
87	conf := newConf(ts.URL)
88	conf.ClientID = "CLIENT_ID??"
89	conf.ClientSecret = "CLIENT_SECRET??"
90	_, err := conf.Exchange(context.Background(), "exchange-code")
91	if err != nil {
92		t.Error(err)
93	}
94}
95
96func TestExchangeRequest(t *testing.T) {
97	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98		if r.URL.String() != "/token" {
99			t.Errorf("Unexpected exchange request URL %q", r.URL)
100		}
101		headerAuth := r.Header.Get("Authorization")
102		if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want {
103			t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want)
104		}
105		headerContentType := r.Header.Get("Content-Type")
106		if headerContentType != "application/x-www-form-urlencoded" {
107			t.Errorf("Unexpected Content-Type header %q", headerContentType)
108		}
109		body, err := ioutil.ReadAll(r.Body)
110		if err != nil {
111			t.Errorf("Failed reading request body: %s.", err)
112		}
113		if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
114			t.Errorf("Unexpected exchange payload; got %q", body)
115		}
116		w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
117		w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
118	}))
119	defer ts.Close()
120	conf := newConf(ts.URL)
121	tok, err := conf.Exchange(context.Background(), "exchange-code")
122	if err != nil {
123		t.Error(err)
124	}
125	if !tok.Valid() {
126		t.Fatalf("Token invalid. Got: %#v", tok)
127	}
128	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
129		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
130	}
131	if tok.TokenType != "bearer" {
132		t.Errorf("Unexpected token type, %#v.", tok.TokenType)
133	}
134	scope := tok.Extra("scope")
135	if scope != "user" {
136		t.Errorf("Unexpected value for scope: %v", scope)
137	}
138}
139
140func TestExchangeRequest_CustomParam(t *testing.T) {
141	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
142		if r.URL.String() != "/token" {
143			t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
144		}
145		headerAuth := r.Header.Get("Authorization")
146		if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
147			t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
148		}
149		headerContentType := r.Header.Get("Content-Type")
150		if headerContentType != "application/x-www-form-urlencoded" {
151			t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
152		}
153		body, err := ioutil.ReadAll(r.Body)
154		if err != nil {
155			t.Errorf("Failed reading request body: %s.", err)
156		}
157		if string(body) != "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
158			t.Errorf("Unexpected exchange payload, %v is found.", string(body))
159		}
160		w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
161		w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
162	}))
163	defer ts.Close()
164	conf := newConf(ts.URL)
165
166	param := SetAuthURLParam("foo", "bar")
167	tok, err := conf.Exchange(context.Background(), "exchange-code", param)
168	if err != nil {
169		t.Error(err)
170	}
171	if !tok.Valid() {
172		t.Fatalf("Token invalid. Got: %#v", tok)
173	}
174	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
175		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
176	}
177	if tok.TokenType != "bearer" {
178		t.Errorf("Unexpected token type, %#v.", tok.TokenType)
179	}
180	scope := tok.Extra("scope")
181	if scope != "user" {
182		t.Errorf("Unexpected value for scope: %v", scope)
183	}
184}
185
186func TestExchangeRequest_JSONResponse(t *testing.T) {
187	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
188		if r.URL.String() != "/token" {
189			t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
190		}
191		headerAuth := r.Header.Get("Authorization")
192		if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
193			t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
194		}
195		headerContentType := r.Header.Get("Content-Type")
196		if headerContentType != "application/x-www-form-urlencoded" {
197			t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
198		}
199		body, err := ioutil.ReadAll(r.Body)
200		if err != nil {
201			t.Errorf("Failed reading request body: %s.", err)
202		}
203		if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
204			t.Errorf("Unexpected exchange payload, %v is found.", string(body))
205		}
206		w.Header().Set("Content-Type", "application/json")
207		w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`))
208	}))
209	defer ts.Close()
210	conf := newConf(ts.URL)
211	tok, err := conf.Exchange(context.Background(), "exchange-code")
212	if err != nil {
213		t.Error(err)
214	}
215	if !tok.Valid() {
216		t.Fatalf("Token invalid. Got: %#v", tok)
217	}
218	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
219		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
220	}
221	if tok.TokenType != "bearer" {
222		t.Errorf("Unexpected token type, %#v.", tok.TokenType)
223	}
224	scope := tok.Extra("scope")
225	if scope != "user" {
226		t.Errorf("Unexpected value for scope: %v", scope)
227	}
228	expiresIn := tok.Extra("expires_in")
229	if expiresIn != float64(86400) {
230		t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn)
231	}
232}
233
234func TestExtraValueRetrieval(t *testing.T) {
235	values := url.Values{}
236	kvmap := map[string]string{
237		"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
238		"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1",
239		"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
240		"untrimmed": "  untrimmed  ",
241	}
242	for key, value := range kvmap {
243		values.Set(key, value)
244	}
245
246	tok := Token{raw: values}
247	scope := tok.Extra("scope")
248	if got, want := scope, "user"; got != want {
249		t.Errorf("got scope = %q; want %q", got, want)
250	}
251	serverTime := tok.Extra("server_time")
252	if got, want := serverTime, 1443571905.5606415; got != want {
253		t.Errorf("got server_time value = %v; want %v", got, want)
254	}
255	refererIP := tok.Extra("referer_ip")
256	if got, want := refererIP, "10.0.0.1"; got != want {
257		t.Errorf("got referer_ip value = %v, want %v", got, want)
258	}
259	expiresIn := tok.Extra("expires_in")
260	if got, want := expiresIn, 86400.92; got != want {
261		t.Errorf("got expires_in value = %v, want %v", got, want)
262	}
263	requestID := tok.Extra("request_id")
264	if got, want := requestID, int64(86400); got != want {
265		t.Errorf("got request_id value = %v, want %v", got, want)
266	}
267	untrimmed := tok.Extra("untrimmed")
268	if got, want := untrimmed, "  untrimmed  "; got != want {
269		t.Errorf("got untrimmed = %q; want %q", got, want)
270	}
271}
272
// day is a 24-hour duration, used by the expires_in tests in this file.
const day time.Duration = 24 * time.Hour
274
275func TestExchangeRequest_JSONResponse_Expiry(t *testing.T) {
276	seconds := int32(day.Seconds())
277	for _, c := range []struct {
278		name        string
279		expires     string
280		want        bool
281		nullExpires bool
282	}{
283		{"normal", fmt.Sprintf(`"expires_in": %d`, seconds), true, false},
284		{"paypal", fmt.Sprintf(`"expires_in": "%d"`, seconds), true, false},
285		{"issue_239", fmt.Sprintf(`"expires_in": null`), true, true},
286
287		{"wrong_type", `"expires_in": false`, false, false},
288		{"wrong_type2", `"expires_in": {}`, false, false},
289		{"wrong_value", `"expires_in": "zzz"`, false, false},
290	} {
291		t.Run(c.name, func(t *testing.T) {
292			testExchangeRequest_JSONResponse_expiry(t, c.expires, c.want, c.nullExpires)
293		})
294	}
295}
296
297func testExchangeRequest_JSONResponse_expiry(t *testing.T, exp string, want, nullExpires bool) {
298	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
299		w.Header().Set("Content-Type", "application/json")
300		w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp)))
301	}))
302	defer ts.Close()
303	conf := newConf(ts.URL)
304	t1 := time.Now().Add(day)
305	tok, err := conf.Exchange(context.Background(), "exchange-code")
306	t2 := t1.Add(day)
307
308	if got := (err == nil); got != want {
309		if want {
310			t.Errorf("unexpected error: got %v", err)
311		} else {
312			t.Errorf("unexpected success")
313		}
314	}
315	if !want {
316		return
317	}
318	if !tok.Valid() {
319		t.Fatalf("Token invalid. Got: %#v", tok)
320	}
321	expiry := tok.Expiry
322
323	if nullExpires && expiry.IsZero() {
324		return
325	}
326	if expiry.Before(t1) || expiry.After(t2) {
327		t.Errorf("Unexpected value for Expiry: %v (should be between %v and %v)", expiry, t1, t2)
328	}
329}
330
331func TestExchangeRequest_BadResponse(t *testing.T) {
332	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
333		w.Header().Set("Content-Type", "application/json")
334		w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
335	}))
336	defer ts.Close()
337	conf := newConf(ts.URL)
338	_, err := conf.Exchange(context.Background(), "code")
339	if err == nil {
340		t.Error("expected error from missing access_token")
341	}
342}
343
344func TestExchangeRequest_BadResponseType(t *testing.T) {
345	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
346		w.Header().Set("Content-Type", "application/json")
347		w.Write([]byte(`{"access_token":123,  "scope": "user", "token_type": "bearer"}`))
348	}))
349	defer ts.Close()
350	conf := newConf(ts.URL)
351	_, err := conf.Exchange(context.Background(), "exchange-code")
352	if err == nil {
353		t.Error("expected error from non-string access_token")
354	}
355}
356
357func TestExchangeRequest_NonBasicAuth(t *testing.T) {
358	internal.ResetAuthCache()
359	tr := &mockTransport{
360		rt: func(r *http.Request) (w *http.Response, err error) {
361			headerAuth := r.Header.Get("Authorization")
362			if headerAuth != "" {
363				t.Errorf("Unexpected authorization header %q", headerAuth)
364			}
365			return nil, errors.New("no response")
366		},
367	}
368	c := &http.Client{Transport: tr}
369	conf := &Config{
370		ClientID: "CLIENT_ID",
371		Endpoint: Endpoint{
372			AuthURL:   "https://accounts.google.com/auth",
373			TokenURL:  "https://accounts.google.com/token",
374			AuthStyle: AuthStyleInParams,
375		},
376	}
377
378	ctx := context.WithValue(context.Background(), HTTPClient, c)
379	conf.Exchange(ctx, "code")
380}
381
382func TestPasswordCredentialsTokenRequest(t *testing.T) {
383	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
384		defer r.Body.Close()
385		expected := "/token"
386		if r.URL.String() != expected {
387			t.Errorf("URL = %q; want %q", r.URL, expected)
388		}
389		headerAuth := r.Header.Get("Authorization")
390		expected = "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="
391		if headerAuth != expected {
392			t.Errorf("Authorization header = %q; want %q", headerAuth, expected)
393		}
394		headerContentType := r.Header.Get("Content-Type")
395		expected = "application/x-www-form-urlencoded"
396		if headerContentType != expected {
397			t.Errorf("Content-Type header = %q; want %q", headerContentType, expected)
398		}
399		body, err := ioutil.ReadAll(r.Body)
400		if err != nil {
401			t.Errorf("Failed reading request body: %s.", err)
402		}
403		expected = "grant_type=password&password=password1&scope=scope1+scope2&username=user1"
404		if string(body) != expected {
405			t.Errorf("res.Body = %q; want %q", string(body), expected)
406		}
407		w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
408		w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
409	}))
410	defer ts.Close()
411	conf := newConf(ts.URL)
412	tok, err := conf.PasswordCredentialsToken(context.Background(), "user1", "password1")
413	if err != nil {
414		t.Error(err)
415	}
416	if !tok.Valid() {
417		t.Fatalf("Token invalid. Got: %#v", tok)
418	}
419	expected := "90d64460d14870c08c81352a05dedd3465940a7c"
420	if tok.AccessToken != expected {
421		t.Errorf("AccessToken = %q; want %q", tok.AccessToken, expected)
422	}
423	expected = "bearer"
424	if tok.TokenType != expected {
425		t.Errorf("TokenType = %q; want %q", tok.TokenType, expected)
426	}
427}
428
429func TestTokenRefreshRequest(t *testing.T) {
430	internal.ResetAuthCache()
431	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
432		if r.URL.String() == "/somethingelse" {
433			return
434		}
435		if r.URL.String() != "/token" {
436			t.Errorf("Unexpected token refresh request URL %q", r.URL)
437		}
438		headerContentType := r.Header.Get("Content-Type")
439		if headerContentType != "application/x-www-form-urlencoded" {
440			t.Errorf("Unexpected Content-Type header %q", headerContentType)
441		}
442		body, _ := ioutil.ReadAll(r.Body)
443		if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
444			t.Errorf("Unexpected refresh token payload %q", body)
445		}
446		w.Header().Set("Content-Type", "application/json")
447		io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`)
448	}))
449	defer ts.Close()
450	conf := newConf(ts.URL)
451	c := conf.Client(context.Background(), &Token{RefreshToken: "REFRESH_TOKEN"})
452	c.Get(ts.URL + "/somethingelse")
453}
454
455func TestFetchWithNoRefreshToken(t *testing.T) {
456	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
457		if r.URL.String() == "/somethingelse" {
458			return
459		}
460		if r.URL.String() != "/token" {
461			t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
462		}
463		headerContentType := r.Header.Get("Content-Type")
464		if headerContentType != "application/x-www-form-urlencoded" {
465			t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
466		}
467		body, _ := ioutil.ReadAll(r.Body)
468		if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
469			t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
470		}
471	}))
472	defer ts.Close()
473	conf := newConf(ts.URL)
474	c := conf.Client(context.Background(), nil)
475	_, err := c.Get(ts.URL + "/somethingelse")
476	if err == nil {
477		t.Errorf("Fetch should return an error if no refresh token is set")
478	}
479}
480
481func TestTokenRetrieveError(t *testing.T) {
482	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
483		if r.URL.String() != "/token" {
484			t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
485		}
486		w.Header().Set("Content-type", "application/json")
487		w.WriteHeader(http.StatusBadRequest)
488		w.Write([]byte(`{"error": "invalid_grant"}`))
489	}))
490	defer ts.Close()
491	conf := newConf(ts.URL)
492	_, err := conf.Exchange(context.Background(), "exchange-code")
493	if err == nil {
494		t.Fatalf("got no error, expected one")
495	}
496	_, ok := err.(*RetrieveError)
497	if !ok {
498		t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
499	}
500	// Test error string for backwards compatibility
501	expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
502	if errStr := err.Error(); errStr != expected {
503		t.Fatalf("got %#v, expected %#v", errStr, expected)
504	}
505}
506
507func TestRefreshToken_RefreshTokenReplacement(t *testing.T) {
508	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
509		w.Header().Set("Content-Type", "application/json")
510		w.Write([]byte(`{"access_token":"ACCESS_TOKEN",  "scope": "user", "token_type": "bearer", "refresh_token": "NEW_REFRESH_TOKEN"}`))
511		return
512	}))
513	defer ts.Close()
514	conf := newConf(ts.URL)
515	tkr := conf.TokenSource(context.Background(), &Token{RefreshToken: "OLD_REFRESH_TOKEN"})
516	tk, err := tkr.Token()
517	if err != nil {
518		t.Errorf("got err = %v; want none", err)
519		return
520	}
521	if want := "NEW_REFRESH_TOKEN"; tk.RefreshToken != want {
522		t.Errorf("RefreshToken = %q; want %q", tk.RefreshToken, want)
523	}
524}
525
526func TestRefreshToken_RefreshTokenPreservation(t *testing.T) {
527	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
528		w.Header().Set("Content-Type", "application/json")
529		w.Write([]byte(`{"access_token":"ACCESS_TOKEN",  "scope": "user", "token_type": "bearer"}`))
530		return
531	}))
532	defer ts.Close()
533	conf := newConf(ts.URL)
534	const oldRefreshToken = "OLD_REFRESH_TOKEN"
535	tkr := conf.TokenSource(context.Background(), &Token{RefreshToken: oldRefreshToken})
536	tk, err := tkr.Token()
537	if err != nil {
538		t.Fatalf("got err = %v; want none", err)
539	}
540	if tk.RefreshToken != oldRefreshToken {
541		t.Errorf("RefreshToken = %q; want %q", tk.RefreshToken, oldRefreshToken)
542	}
543}
544
545func TestConfigClientWithToken(t *testing.T) {
546	tok := &Token{
547		AccessToken: "abc123",
548	}
549	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
550		if got, want := r.Header.Get("Authorization"), fmt.Sprintf("Bearer %s", tok.AccessToken); got != want {
551			t.Errorf("Authorization header = %q; want %q", got, want)
552		}
553		return
554	}))
555	defer ts.Close()
556	conf := newConf(ts.URL)
557
558	c := conf.Client(context.Background(), tok)
559	req, err := http.NewRequest("GET", ts.URL, nil)
560	if err != nil {
561		t.Error(err)
562	}
563	_, err = c.Do(req)
564	if err != nil {
565		t.Error(err)
566	}
567}
568