1package oidc 2 3import ( 4 "errors" 5 "testing" 6 "time" 7 8 "github.com/stretchr/testify/assert" 9 "github.com/stretchr/testify/require" 10 "golang.org/x/oauth2" 11) 12 13func TestNewToken(t *testing.T) { 14 t.Parallel() 15 _, priv := TestGenerateKeys(t) 16 testJWT := testDefaultJWT(t, priv, 1*time.Minute, "123456789", nil) 17 testAccessToken := "test_access_token" 18 testRefreshToken := "test_refresh_token" 19 testExpiry := time.Now().Add(1 * time.Minute) 20 testUnderlying := &oauth2.Token{ 21 AccessToken: testAccessToken, 22 RefreshToken: testRefreshToken, 23 Expiry: testExpiry, 24 } 25 26 testUnderlyingZeroExpiry := &oauth2.Token{ 27 AccessToken: testAccessToken, 28 RefreshToken: testRefreshToken, 29 } 30 testNow := func() time.Time { 31 return time.Now().Add(-1 * time.Minute) 32 } 33 34 tests := []struct { 35 name string 36 idToken IDToken 37 oauthToken *oauth2.Token 38 opts []Option 39 want *Tk 40 wantNowFunc func() time.Time 41 wantIDToken IDToken 42 wantAccessToken AccessToken 43 wantRefreshToken RefreshToken 44 wantTokenSource oauth2.TokenSource 45 wantExpiry time.Time 46 wantExpired bool 47 wantValid bool 48 wantErr bool 49 wantIsErr error 50 }{ 51 { 52 name: "valid", 53 idToken: IDToken(testJWT), 54 oauthToken: testUnderlying, 55 opts: []Option{WithNow(testNow)}, 56 want: &Tk{ 57 idToken: IDToken(testJWT), 58 underlying: testUnderlying, 59 nowFunc: testNow, 60 }, 61 wantIDToken: IDToken(testJWT), 62 wantAccessToken: AccessToken(testAccessToken), 63 wantRefreshToken: RefreshToken(testRefreshToken), 64 wantTokenSource: oauth2.StaticTokenSource(testUnderlying), 65 wantExpiry: testExpiry, 66 wantExpired: false, 67 wantValid: true, 68 }, 69 { 70 name: "valid-def-now-func", 71 idToken: IDToken(testJWT), 72 oauthToken: testUnderlying, 73 opts: []Option{}, 74 want: &Tk{ 75 idToken: IDToken(testJWT), 76 underlying: testUnderlying, 77 }, 78 wantIDToken: IDToken(testJWT), 79 wantAccessToken: AccessToken(testAccessToken), 80 wantRefreshToken: RefreshToken(testRefreshToken), 81 wantTokenSource: oauth2.StaticTokenSource(testUnderlying), 82 wantExpiry: testExpiry, 83 wantExpired: false, 84 wantValid: true, 85 }, 86 { 87 name: "valid-without-accessToken", 88 idToken: IDToken(testJWT), 89 want: &Tk{ 90 idToken: IDToken(testJWT), 91 }, 92 wantIDToken: IDToken(testJWT), 93 wantExpired: true, 94 wantValid: false, 95 }, 96 { 97 name: "valid-with-accessToken-and-zero-expiry", 98 idToken: IDToken(testJWT), 99 oauthToken: testUnderlyingZeroExpiry, 100 want: &Tk{ 101 idToken: IDToken(testJWT), 102 underlying: testUnderlyingZeroExpiry, 103 }, 104 wantIDToken: IDToken(testJWT), 105 wantAccessToken: AccessToken(testAccessToken), 106 wantRefreshToken: RefreshToken(testRefreshToken), 107 wantTokenSource: oauth2.StaticTokenSource(testUnderlyingZeroExpiry), 108 wantExpired: false, 109 wantValid: true, 110 }, 111 { 112 name: "empty-idToken", 113 idToken: IDToken(""), 114 oauthToken: &oauth2.Token{ 115 AccessToken: testAccessToken, 116 }, 117 wantErr: true, 118 wantIsErr: ErrInvalidParameter, 119 }, 120 } 121 for _, tt := range tests { 122 t.Run(tt.name, func(t *testing.T) { 123 assert, require := assert.New(t), require.New(t) 124 got, err := NewToken(tt.idToken, tt.oauthToken, tt.opts...) 125 if tt.wantErr { 126 require.Error(err) 127 assert.Truef(errors.Is(err, tt.wantIsErr), "wanted \"%s\" but got \"%s\"", tt.wantIsErr, err) 128 return 129 } 130 require.NoError(err) 131 assert.Equalf(tt.want.underlying, got.underlying, "NewToken() = %v, want %v", got.underlying, tt.want.underlying) 132 assert.Equalf(tt.wantIDToken, got.IDToken(), "t.IDToken() = %v, want %v", tt.wantIDToken, got.IDToken()) 133 assert.Equalf(tt.wantAccessToken, got.AccessToken(), "t.AccessToken() = %v, want %v", tt.wantAccessToken, got.AccessToken()) 134 assert.Equalf(tt.wantRefreshToken, got.RefreshToken(), "t.RefreshToken() = %v, want %v", tt.wantRefreshToken, got.RefreshToken()) 135 assert.Equalf(tt.wantExpiry, got.Expiry(), "t.Expiry() = %v, want %v", tt.wantExpiry, got.Expiry()) 136 assert.Equalf(tt.wantTokenSource, got.StaticTokenSource(), "t.StaticTokenSource() = %v, want %v", tt.wantTokenSource, got.StaticTokenSource()) 137 assert.Equalf(tt.wantExpired, got.IsExpired(), "t.Expired() = %v, want %v", tt.wantExpired, got.IsExpired()) 138 assert.Equalf(tt.wantValid, got.Valid(), "t.Valid() = %v, want %v", tt.wantValid, got.Valid()) 139 testAssertEqualFunc(t, tt.want.nowFunc, got.nowFunc, "now = %p,want %p", tt.want.nowFunc, got.nowFunc) 140 }) 141 } 142} 143 144func TestUnmarshalClaims(t *testing.T) { 145 // UnmarshalClaims testing is covered by other tests but we do have just a 146 // few more test to add here. 147 t.Parallel() 148 t.Run("jwt-without-3-parts", func(t *testing.T) { 149 assert, require := assert.New(t), require.New(t) 150 var claims map[string]interface{} 151 jwt := "one.two" 152 err := UnmarshalClaims(jwt, &claims) 153 require.Error(err) 154 assert.Truef(errors.Is(err, ErrInvalidParameter), "wanted \"%s\" but got \"%s\"", ErrInvalidParameter, err) 155 }) 156} 157