1package oidc
2
3import (
4	"errors"
5	"testing"
6	"time"
7
8	"github.com/stretchr/testify/assert"
9	"github.com/stretchr/testify/require"
10	"golang.org/x/oauth2"
11)
12
13func TestNewToken(t *testing.T) {
14	t.Parallel()
15	_, priv := TestGenerateKeys(t)
16	testJWT := testDefaultJWT(t, priv, 1*time.Minute, "123456789", nil)
17	testAccessToken := "test_access_token"
18	testRefreshToken := "test_refresh_token"
19	testExpiry := time.Now().Add(1 * time.Minute)
20	testUnderlying := &oauth2.Token{
21		AccessToken:  testAccessToken,
22		RefreshToken: testRefreshToken,
23		Expiry:       testExpiry,
24	}
25
26	testUnderlyingZeroExpiry := &oauth2.Token{
27		AccessToken:  testAccessToken,
28		RefreshToken: testRefreshToken,
29	}
30	testNow := func() time.Time {
31		return time.Now().Add(-1 * time.Minute)
32	}
33
34	tests := []struct {
35		name             string
36		idToken          IDToken
37		oauthToken       *oauth2.Token
38		opts             []Option
39		want             *Tk
40		wantNowFunc      func() time.Time
41		wantIDToken      IDToken
42		wantAccessToken  AccessToken
43		wantRefreshToken RefreshToken
44		wantTokenSource  oauth2.TokenSource
45		wantExpiry       time.Time
46		wantExpired      bool
47		wantValid        bool
48		wantErr          bool
49		wantIsErr        error
50	}{
51		{
52			name:       "valid",
53			idToken:    IDToken(testJWT),
54			oauthToken: testUnderlying,
55			opts:       []Option{WithNow(testNow)},
56			want: &Tk{
57				idToken:    IDToken(testJWT),
58				underlying: testUnderlying,
59				nowFunc:    testNow,
60			},
61			wantIDToken:      IDToken(testJWT),
62			wantAccessToken:  AccessToken(testAccessToken),
63			wantRefreshToken: RefreshToken(testRefreshToken),
64			wantTokenSource:  oauth2.StaticTokenSource(testUnderlying),
65			wantExpiry:       testExpiry,
66			wantExpired:      false,
67			wantValid:        true,
68		},
69		{
70			name:       "valid-def-now-func",
71			idToken:    IDToken(testJWT),
72			oauthToken: testUnderlying,
73			opts:       []Option{},
74			want: &Tk{
75				idToken:    IDToken(testJWT),
76				underlying: testUnderlying,
77			},
78			wantIDToken:      IDToken(testJWT),
79			wantAccessToken:  AccessToken(testAccessToken),
80			wantRefreshToken: RefreshToken(testRefreshToken),
81			wantTokenSource:  oauth2.StaticTokenSource(testUnderlying),
82			wantExpiry:       testExpiry,
83			wantExpired:      false,
84			wantValid:        true,
85		},
86		{
87			name:    "valid-without-accessToken",
88			idToken: IDToken(testJWT),
89			want: &Tk{
90				idToken: IDToken(testJWT),
91			},
92			wantIDToken: IDToken(testJWT),
93			wantExpired: true,
94			wantValid:   false,
95		},
96		{
97			name:       "valid-with-accessToken-and-zero-expiry",
98			idToken:    IDToken(testJWT),
99			oauthToken: testUnderlyingZeroExpiry,
100			want: &Tk{
101				idToken:    IDToken(testJWT),
102				underlying: testUnderlyingZeroExpiry,
103			},
104			wantIDToken:      IDToken(testJWT),
105			wantAccessToken:  AccessToken(testAccessToken),
106			wantRefreshToken: RefreshToken(testRefreshToken),
107			wantTokenSource:  oauth2.StaticTokenSource(testUnderlyingZeroExpiry),
108			wantExpired:      false,
109			wantValid:        true,
110		},
111		{
112			name:    "empty-idToken",
113			idToken: IDToken(""),
114			oauthToken: &oauth2.Token{
115				AccessToken: testAccessToken,
116			},
117			wantErr:   true,
118			wantIsErr: ErrInvalidParameter,
119		},
120	}
121	for _, tt := range tests {
122		t.Run(tt.name, func(t *testing.T) {
123			assert, require := assert.New(t), require.New(t)
124			got, err := NewToken(tt.idToken, tt.oauthToken, tt.opts...)
125			if tt.wantErr {
126				require.Error(err)
127				assert.Truef(errors.Is(err, tt.wantIsErr), "wanted \"%s\" but got \"%s\"", tt.wantIsErr, err)
128				return
129			}
130			require.NoError(err)
131			assert.Equalf(tt.want.underlying, got.underlying, "NewToken() = %v, want %v", got.underlying, tt.want.underlying)
132			assert.Equalf(tt.wantIDToken, got.IDToken(), "t.IDToken() = %v, want %v", tt.wantIDToken, got.IDToken())
133			assert.Equalf(tt.wantAccessToken, got.AccessToken(), "t.AccessToken() = %v, want %v", tt.wantAccessToken, got.AccessToken())
134			assert.Equalf(tt.wantRefreshToken, got.RefreshToken(), "t.RefreshToken() = %v, want %v", tt.wantRefreshToken, got.RefreshToken())
135			assert.Equalf(tt.wantExpiry, got.Expiry(), "t.Expiry() = %v, want %v", tt.wantExpiry, got.Expiry())
136			assert.Equalf(tt.wantTokenSource, got.StaticTokenSource(), "t.StaticTokenSource() = %v, want %v", tt.wantTokenSource, got.StaticTokenSource())
137			assert.Equalf(tt.wantExpired, got.IsExpired(), "t.Expired() = %v, want %v", tt.wantExpired, got.IsExpired())
138			assert.Equalf(tt.wantValid, got.Valid(), "t.Valid() = %v, want %v", tt.wantValid, got.Valid())
139			testAssertEqualFunc(t, tt.want.nowFunc, got.nowFunc, "now = %p,want %p", tt.want.nowFunc, got.nowFunc)
140		})
141	}
142}
143
144func TestUnmarshalClaims(t *testing.T) {
145	// UnmarshalClaims testing is covered by other tests but we do have just a
146	// few more test to add here.
147	t.Parallel()
148	t.Run("jwt-without-3-parts", func(t *testing.T) {
149		assert, require := assert.New(t), require.New(t)
150		var claims map[string]interface{}
151		jwt := "one.two"
152		err := UnmarshalClaims(jwt, &claims)
153		require.Error(err)
154		assert.Truef(errors.Is(err, ErrInvalidParameter), "wanted \"%s\" but got \"%s\"", ErrInvalidParameter, err)
155	})
156}
157