1package jwt 2 3import ( 4 "context" 5 "crypto/x509" 6 "encoding/base64" 7 "encoding/json" 8 "encoding/pem" 9 "io/ioutil" 10 "net/http" 11 "net/http/httptest" 12 "os" 13 "strings" 14 "testing" 15 "time" 16 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 jose "gopkg.in/square/go-jose.v2" 20 "gopkg.in/square/go-jose.v2/jwt" 21 22 "github.com/grafana/grafana/pkg/infra/remotecache" 23 "github.com/grafana/grafana/pkg/setting" 24) 25 26type scenarioContext struct { 27 ctx context.Context 28 cfg *setting.Cfg 29 authJWTSvc *AuthService 30} 31 32type cachingScenarioContext struct { 33 scenarioContext 34 reqCount *int 35} 36 37type configureFunc func(*testing.T, *setting.Cfg) 38type scenarioFunc func(*testing.T, scenarioContext) 39type cachingScenarioFunc func(*testing.T, cachingScenarioContext) 40 41const subject = "foo-subj" 42 43func TestVerifyUsingPKIXPublicKeyFile(t *testing.T) { 44 key := rsaKeys[0] 45 unknownKey := rsaKeys[1] 46 47 scenario(t, "verifies a token", func(t *testing.T, sc scenarioContext) { 48 token := sign(t, key, jwt.Claims{ 49 Subject: subject, 50 }) 51 verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) 52 require.NoError(t, err) 53 assert.Equal(t, verifiedClaims["sub"], subject) 54 }, configurePKIXPublicKeyFile) 55 56 scenario(t, "rejects a token signed by unknown key", func(t *testing.T, sc scenarioContext) { 57 token := sign(t, unknownKey, jwt.Claims{ 58 Subject: subject, 59 }) 60 _, err := sc.authJWTSvc.Verify(sc.ctx, token) 61 require.Error(t, err) 62 }, configurePKIXPublicKeyFile) 63} 64 65func TestVerifyUsingJWKSetFile(t *testing.T) { 66 configure := func(t *testing.T, cfg *setting.Cfg) { 67 t.Helper() 68 69 file, err := ioutil.TempFile(os.TempDir(), "jwk-*.json") 70 require.NoError(t, err) 71 t.Cleanup(func() { 72 if err := os.Remove(file.Name()); err != nil { 73 panic(err) 74 } 75 }) 76 77 require.NoError(t, json.NewEncoder(file).Encode(jwksPublic)) 78 require.NoError(t, file.Close()) 79 80 cfg.JWTAuthJWKSetFile = file.Name() 81 } 82 83 scenario(t, "verifies a token signed with a key from the set", func(t *testing.T, sc scenarioContext) { 84 token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}) 85 verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) 86 require.NoError(t, err) 87 assert.Equal(t, verifiedClaims["sub"], subject) 88 }, configure) 89 90 scenario(t, "verifies a token signed with another key from the set", func(t *testing.T, sc scenarioContext) { 91 token := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}) 92 verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) 93 require.NoError(t, err) 94 assert.Equal(t, verifiedClaims["sub"], subject) 95 }, configure) 96 97 scenario(t, "rejects a token signed with a key not from the set", func(t *testing.T, sc scenarioContext) { 98 token := sign(t, jwKeys[2], jwt.Claims{Subject: subject}) 99 _, err := sc.authJWTSvc.Verify(sc.ctx, token) 100 require.Error(t, err) 101 }, configure) 102} 103 104func TestVerifyUsingJWKSetURL(t *testing.T) { 105 t.Run("should refuse to start with non-https URL", func(t *testing.T) { 106 var err error 107 108 _, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) { 109 cfg.JWTAuthJWKSetURL = "https://example.com/.well-known/jwks.json" 110 }) 111 require.NoError(t, err) 112 113 _, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) { 114 cfg.JWTAuthJWKSetURL = "http://example.com/.well-known/jwks.json" 115 }) 116 require.Error(t, err) 117 }) 118 119 jwkHTTPScenario(t, "verifies a token signed with a key from the set", func(t *testing.T, sc scenarioContext) { 120 token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}) 121 verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) 122 require.NoError(t, err) 123 assert.Equal(t, verifiedClaims["sub"], subject) 124 }) 125 126 jwkHTTPScenario(t, "verifies a token signed with another key from the set", func(t *testing.T, sc scenarioContext) { 127 token := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}) 128 verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) 129 require.NoError(t, err) 130 assert.Equal(t, verifiedClaims["sub"], subject) 131 }) 132 133 jwkHTTPScenario(t, "rejects a token signed with a key not from the set", func(t *testing.T, sc scenarioContext) { 134 token := sign(t, jwKeys[2], jwt.Claims{Subject: subject}) 135 _, err := sc.authJWTSvc.Verify(sc.ctx, token) 136 require.Error(t, err) 137 }) 138} 139 140func TestCachingJWKHTTPResponse(t *testing.T) { 141 jwkCachingScenario(t, "caches the jwk response", func(t *testing.T, sc cachingScenarioContext) { 142 for i := 0; i < 5; i++ { 143 token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}) 144 _, err := sc.authJWTSvc.Verify(sc.ctx, token) 145 require.NoError(t, err, "verify call %d", i+1) 146 } 147 148 assert.Equal(t, 1, *sc.reqCount) 149 }) 150 151 jwkCachingScenario(t, "respects TTL setting (while cached)", func(t *testing.T, sc cachingScenarioContext) { 152 var err error 153 154 token0 := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}) 155 token1 := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}) 156 157 _, err = sc.authJWTSvc.Verify(sc.ctx, token0) 158 require.NoError(t, err) 159 _, err = sc.authJWTSvc.Verify(sc.ctx, token1) 160 require.Error(t, err) 161 162 assert.Equal(t, 1, *sc.reqCount) 163 }, func(t *testing.T, cfg *setting.Cfg) { 164 // Arbitrary high value, several times what the test should take. 165 cfg.JWTAuthCacheTTL = time.Minute 166 }) 167 168 jwkCachingScenario(t, "does not cache the response when TTL is zero", func(t *testing.T, sc cachingScenarioContext) { 169 for i := 0; i < 2; i++ { 170 _, err := sc.authJWTSvc.Verify(sc.ctx, sign(t, &jwKeys[i], jwt.Claims{Subject: subject})) 171 require.NoError(t, err, "verify call %d", i+1) 172 } 173 174 assert.Equal(t, 2, *sc.reqCount) 175 }, func(t *testing.T, cfg *setting.Cfg) { 176 cfg.JWTAuthCacheTTL = 0 177 }) 178} 179 180func TestSignatureWithNoneAlgorithm(t *testing.T) { 181 scenario(t, "rejects a token signed with \"none\" algorithm", func(t *testing.T, sc scenarioContext) { 182 token := signNone(t, jwt.Claims{Subject: "foo"}) 183 _, err := sc.authJWTSvc.Verify(sc.ctx, token) 184 require.Error(t, err) 185 }, configurePKIXPublicKeyFile) 186} 187 188func TestClaimValidation(t *testing.T) { 189 key := rsaKeys[0] 190 191 scenario(t, "validates iss field for equality", func(t *testing.T, sc scenarioContext) { 192 tokenValid := sign(t, key, jwt.Claims{Issuer: "http://foo"}) 193 tokenInvalid := sign(t, key, jwt.Claims{Issuer: "http://bar"}) 194 195 _, err := sc.authJWTSvc.Verify(sc.ctx, tokenValid) 196 require.NoError(t, err) 197 198 _, err = sc.authJWTSvc.Verify(sc.ctx, tokenInvalid) 199 require.Error(t, err) 200 }, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) { 201 cfg.JWTAuthExpectClaims = `{"iss": "http://foo"}` 202 }) 203 204 scenario(t, "validates sub field for equality", func(t *testing.T, sc scenarioContext) { 205 var err error 206 207 tokenValid := sign(t, key, jwt.Claims{Subject: "foo"}) 208 tokenInvalid := sign(t, key, jwt.Claims{Subject: "bar"}) 209 210 _, err = sc.authJWTSvc.Verify(sc.ctx, tokenValid) 211 require.NoError(t, err) 212 213 _, err = sc.authJWTSvc.Verify(sc.ctx, tokenInvalid) 214 require.Error(t, err) 215 }, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) { 216 cfg.JWTAuthExpectClaims = `{"sub": "foo"}` 217 }) 218 219 scenario(t, "validates aud field for inclusion", func(t *testing.T, sc scenarioContext) { 220 var err error 221 222 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"bar", "foo"}})) 223 require.NoError(t, err) 224 225 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"foo", "bar", "baz"}})) 226 require.NoError(t, err) 227 228 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"foo"}})) 229 require.Error(t, err) 230 231 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"bar", "baz"}})) 232 require.Error(t, err) 233 234 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"baz"}})) 235 require.Error(t, err) 236 }, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) { 237 cfg.JWTAuthExpectClaims = `{"aud": ["foo", "bar"]}` 238 }) 239 240 scenario(t, "validates non-registered (custom) claims for equality", func(t *testing.T, sc scenarioContext) { 241 var err error 242 243 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo", "my-number": 123})) 244 require.NoError(t, err) 245 246 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "bar", "my-number": 123})) 247 require.Error(t, err) 248 249 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo", "my-number": 100})) 250 require.Error(t, err) 251 252 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo"})) 253 require.Error(t, err) 254 255 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-number": 123})) 256 require.Error(t, err) 257 }, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) { 258 cfg.JWTAuthExpectClaims = `{"my-str": "foo", "my-number": 123}` 259 }) 260 261 scenario(t, "validates exp claim of the token", func(t *testing.T, sc scenarioContext) { 262 var err error 263 264 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour))})) 265 require.NoError(t, err) 266 267 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Expiry: jwt.NewNumericDate(time.Now().Add(-time.Hour))})) 268 require.Error(t, err) 269 }, configurePKIXPublicKeyFile) 270 271 scenario(t, "validates nbf claim of the token", func(t *testing.T, sc scenarioContext) { 272 var err error 273 274 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Hour))})) 275 require.NoError(t, err) 276 277 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Hour))})) 278 require.Error(t, err) 279 }, configurePKIXPublicKeyFile) 280 281 scenario(t, "validates iat claim of the token", func(t *testing.T, sc scenarioContext) { 282 var err error 283 284 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Hour))})) 285 require.NoError(t, err) 286 287 _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour))})) 288 require.Error(t, err) 289 }, configurePKIXPublicKeyFile) 290} 291 292func jwkHTTPScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...configureFunc) { 293 t.Helper() 294 t.Run(desc, func(t *testing.T) { 295 ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 296 if err := json.NewEncoder(w).Encode(jwksPublic); err != nil { 297 panic(err) 298 } 299 })) 300 t.Cleanup(ts.Close) 301 302 configure := func(t *testing.T, cfg *setting.Cfg) { 303 cfg.JWTAuthJWKSetURL = ts.URL 304 } 305 runner := scenarioRunner(func(t *testing.T, sc scenarioContext) { 306 keySet := sc.authJWTSvc.keySet.(*keySetHTTP) 307 keySet.client = ts.Client() 308 fn(t, sc) 309 }, append([]configureFunc{configure}, cbs...)...) 310 runner(t) 311 }) 312} 313 314func jwkCachingScenario(t *testing.T, desc string, fn cachingScenarioFunc, cbs ...configureFunc) { 315 t.Helper() 316 317 t.Run(desc, func(t *testing.T) { 318 var reqCount int 319 320 // We run a server that each call responds differently. 321 ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 322 if reqCount++; reqCount > 2 { 323 panic("calling more than two times is not supported") 324 } 325 jwks := jose.JSONWebKeySet{ 326 Keys: []jose.JSONWebKey{jwksPublic.Keys[reqCount-1]}, 327 } 328 if err := json.NewEncoder(w).Encode(jwks); err != nil { 329 panic(err) 330 } 331 })) 332 t.Cleanup(ts.Close) 333 334 configure := func(t *testing.T, cfg *setting.Cfg) { 335 cfg.JWTAuthJWKSetURL = ts.URL 336 cfg.JWTAuthCacheTTL = time.Hour 337 } 338 runner := scenarioRunner(func(t *testing.T, sc scenarioContext) { 339 keySet := sc.authJWTSvc.keySet.(*keySetHTTP) 340 keySet.client = ts.Client() 341 fn(t, cachingScenarioContext{scenarioContext: sc, reqCount: &reqCount}) 342 }, append([]configureFunc{configure}, cbs...)...) 343 344 runner(t) 345 }) 346} 347 348func TestBase64Paddings(t *testing.T) { 349 key := rsaKeys[0] 350 351 scenario(t, "verifies a token with base64 padding (non compliant rfc7515#section-2 but accepted)", func(t *testing.T, sc scenarioContext) { 352 token := sign(t, key, jwt.Claims{ 353 Subject: subject, 354 }) 355 var tokenParts []string 356 for i, part := range strings.Split(token, ".") { 357 // Create parts with different padding numbers to test multiple cases. 358 tokenParts = append(tokenParts, part+strings.Repeat(string(base64.StdPadding), i)) 359 } 360 token = strings.Join(tokenParts, ".") 361 verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) 362 require.NoError(t, err) 363 assert.Equal(t, verifiedClaims["sub"], subject) 364 }, configurePKIXPublicKeyFile) 365} 366 367func scenario(t *testing.T, desc string, fn scenarioFunc, cbs ...configureFunc) { 368 t.Helper() 369 370 t.Run(desc, scenarioRunner(fn, cbs...)) 371} 372 373func initAuthService(t *testing.T, cbs ...configureFunc) (*AuthService, error) { 374 t.Helper() 375 376 cfg := setting.NewCfg() 377 cfg.JWTAuthEnabled = true 378 cfg.JWTAuthExpectClaims = "{}" 379 380 for _, cb := range cbs { 381 cb(t, cfg) 382 } 383 384 service := newService(cfg, remotecache.NewFakeStore(t)) 385 err := service.init() 386 return service, err 387} 388 389func scenarioRunner(fn scenarioFunc, cbs ...configureFunc) func(t *testing.T) { 390 return func(t *testing.T) { 391 authJWTSvc, err := initAuthService(t, cbs...) 392 require.NoError(t, err) 393 394 fn(t, scenarioContext{ 395 ctx: context.Background(), 396 cfg: authJWTSvc.Cfg, 397 authJWTSvc: authJWTSvc, 398 }) 399 } 400} 401 402func configurePKIXPublicKeyFile(t *testing.T, cfg *setting.Cfg) { 403 t.Helper() 404 405 file, err := ioutil.TempFile(os.TempDir(), "public-key-*.pem") 406 require.NoError(t, err) 407 t.Cleanup(func() { 408 if err := os.Remove(file.Name()); err != nil { 409 panic(err) 410 } 411 }) 412 413 blockBytes, err := x509.MarshalPKIXPublicKey(rsaKeys[0].Public()) 414 require.NoError(t, err) 415 416 require.NoError(t, pem.Encode(file, &pem.Block{ 417 Type: "PUBLIC KEY", 418 Bytes: blockBytes, 419 })) 420 require.NoError(t, file.Close()) 421 422 cfg.JWTAuthKeyFile = file.Name() 423} 424