1package jwt
2
3import (
4	"context"
5	"crypto/x509"
6	"encoding/base64"
7	"encoding/json"
8	"encoding/pem"
9	"io/ioutil"
10	"net/http"
11	"net/http/httptest"
12	"os"
13	"strings"
14	"testing"
15	"time"
16
17	"github.com/stretchr/testify/assert"
18	"github.com/stretchr/testify/require"
19	jose "gopkg.in/square/go-jose.v2"
20	"gopkg.in/square/go-jose.v2/jwt"
21
22	"github.com/grafana/grafana/pkg/infra/remotecache"
23	"github.com/grafana/grafana/pkg/setting"
24)
25
26type scenarioContext struct {
27	ctx        context.Context
28	cfg        *setting.Cfg
29	authJWTSvc *AuthService
30}
31
32type cachingScenarioContext struct {
33	scenarioContext
34	reqCount *int
35}
36
37type configureFunc func(*testing.T, *setting.Cfg)
38type scenarioFunc func(*testing.T, scenarioContext)
39type cachingScenarioFunc func(*testing.T, cachingScenarioContext)
40
41const subject = "foo-subj"
42
43func TestVerifyUsingPKIXPublicKeyFile(t *testing.T) {
44	key := rsaKeys[0]
45	unknownKey := rsaKeys[1]
46
47	scenario(t, "verifies a token", func(t *testing.T, sc scenarioContext) {
48		token := sign(t, key, jwt.Claims{
49			Subject: subject,
50		})
51		verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token)
52		require.NoError(t, err)
53		assert.Equal(t, verifiedClaims["sub"], subject)
54	}, configurePKIXPublicKeyFile)
55
56	scenario(t, "rejects a token signed by unknown key", func(t *testing.T, sc scenarioContext) {
57		token := sign(t, unknownKey, jwt.Claims{
58			Subject: subject,
59		})
60		_, err := sc.authJWTSvc.Verify(sc.ctx, token)
61		require.Error(t, err)
62	}, configurePKIXPublicKeyFile)
63}
64
65func TestVerifyUsingJWKSetFile(t *testing.T) {
66	configure := func(t *testing.T, cfg *setting.Cfg) {
67		t.Helper()
68
69		file, err := ioutil.TempFile(os.TempDir(), "jwk-*.json")
70		require.NoError(t, err)
71		t.Cleanup(func() {
72			if err := os.Remove(file.Name()); err != nil {
73				panic(err)
74			}
75		})
76
77		require.NoError(t, json.NewEncoder(file).Encode(jwksPublic))
78		require.NoError(t, file.Close())
79
80		cfg.JWTAuthJWKSetFile = file.Name()
81	}
82
83	scenario(t, "verifies a token signed with a key from the set", func(t *testing.T, sc scenarioContext) {
84		token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject})
85		verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token)
86		require.NoError(t, err)
87		assert.Equal(t, verifiedClaims["sub"], subject)
88	}, configure)
89
90	scenario(t, "verifies a token signed with another key from the set", func(t *testing.T, sc scenarioContext) {
91		token := sign(t, &jwKeys[1], jwt.Claims{Subject: subject})
92		verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token)
93		require.NoError(t, err)
94		assert.Equal(t, verifiedClaims["sub"], subject)
95	}, configure)
96
97	scenario(t, "rejects a token signed with a key not from the set", func(t *testing.T, sc scenarioContext) {
98		token := sign(t, jwKeys[2], jwt.Claims{Subject: subject})
99		_, err := sc.authJWTSvc.Verify(sc.ctx, token)
100		require.Error(t, err)
101	}, configure)
102}
103
104func TestVerifyUsingJWKSetURL(t *testing.T) {
105	t.Run("should refuse to start with non-https URL", func(t *testing.T) {
106		var err error
107
108		_, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) {
109			cfg.JWTAuthJWKSetURL = "https://example.com/.well-known/jwks.json"
110		})
111		require.NoError(t, err)
112
113		_, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) {
114			cfg.JWTAuthJWKSetURL = "http://example.com/.well-known/jwks.json"
115		})
116		require.Error(t, err)
117	})
118
119	jwkHTTPScenario(t, "verifies a token signed with a key from the set", func(t *testing.T, sc scenarioContext) {
120		token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject})
121		verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token)
122		require.NoError(t, err)
123		assert.Equal(t, verifiedClaims["sub"], subject)
124	})
125
126	jwkHTTPScenario(t, "verifies a token signed with another key from the set", func(t *testing.T, sc scenarioContext) {
127		token := sign(t, &jwKeys[1], jwt.Claims{Subject: subject})
128		verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token)
129		require.NoError(t, err)
130		assert.Equal(t, verifiedClaims["sub"], subject)
131	})
132
133	jwkHTTPScenario(t, "rejects a token signed with a key not from the set", func(t *testing.T, sc scenarioContext) {
134		token := sign(t, jwKeys[2], jwt.Claims{Subject: subject})
135		_, err := sc.authJWTSvc.Verify(sc.ctx, token)
136		require.Error(t, err)
137	})
138}
139
140func TestCachingJWKHTTPResponse(t *testing.T) {
141	jwkCachingScenario(t, "caches the jwk response", func(t *testing.T, sc cachingScenarioContext) {
142		for i := 0; i < 5; i++ {
143			token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject})
144			_, err := sc.authJWTSvc.Verify(sc.ctx, token)
145			require.NoError(t, err, "verify call %d", i+1)
146		}
147
148		assert.Equal(t, 1, *sc.reqCount)
149	})
150
151	jwkCachingScenario(t, "respects TTL setting (while cached)", func(t *testing.T, sc cachingScenarioContext) {
152		var err error
153
154		token0 := sign(t, &jwKeys[0], jwt.Claims{Subject: subject})
155		token1 := sign(t, &jwKeys[1], jwt.Claims{Subject: subject})
156
157		_, err = sc.authJWTSvc.Verify(sc.ctx, token0)
158		require.NoError(t, err)
159		_, err = sc.authJWTSvc.Verify(sc.ctx, token1)
160		require.Error(t, err)
161
162		assert.Equal(t, 1, *sc.reqCount)
163	}, func(t *testing.T, cfg *setting.Cfg) {
164		// Arbitrary high value, several times what the test should take.
165		cfg.JWTAuthCacheTTL = time.Minute
166	})
167
168	jwkCachingScenario(t, "does not cache the response when TTL is zero", func(t *testing.T, sc cachingScenarioContext) {
169		for i := 0; i < 2; i++ {
170			_, err := sc.authJWTSvc.Verify(sc.ctx, sign(t, &jwKeys[i], jwt.Claims{Subject: subject}))
171			require.NoError(t, err, "verify call %d", i+1)
172		}
173
174		assert.Equal(t, 2, *sc.reqCount)
175	}, func(t *testing.T, cfg *setting.Cfg) {
176		cfg.JWTAuthCacheTTL = 0
177	})
178}
179
180func TestSignatureWithNoneAlgorithm(t *testing.T) {
181	scenario(t, "rejects a token signed with \"none\" algorithm", func(t *testing.T, sc scenarioContext) {
182		token := signNone(t, jwt.Claims{Subject: "foo"})
183		_, err := sc.authJWTSvc.Verify(sc.ctx, token)
184		require.Error(t, err)
185	}, configurePKIXPublicKeyFile)
186}
187
188func TestClaimValidation(t *testing.T) {
189	key := rsaKeys[0]
190
191	scenario(t, "validates iss field for equality", func(t *testing.T, sc scenarioContext) {
192		tokenValid := sign(t, key, jwt.Claims{Issuer: "http://foo"})
193		tokenInvalid := sign(t, key, jwt.Claims{Issuer: "http://bar"})
194
195		_, err := sc.authJWTSvc.Verify(sc.ctx, tokenValid)
196		require.NoError(t, err)
197
198		_, err = sc.authJWTSvc.Verify(sc.ctx, tokenInvalid)
199		require.Error(t, err)
200	}, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) {
201		cfg.JWTAuthExpectClaims = `{"iss": "http://foo"}`
202	})
203
204	scenario(t, "validates sub field for equality", func(t *testing.T, sc scenarioContext) {
205		var err error
206
207		tokenValid := sign(t, key, jwt.Claims{Subject: "foo"})
208		tokenInvalid := sign(t, key, jwt.Claims{Subject: "bar"})
209
210		_, err = sc.authJWTSvc.Verify(sc.ctx, tokenValid)
211		require.NoError(t, err)
212
213		_, err = sc.authJWTSvc.Verify(sc.ctx, tokenInvalid)
214		require.Error(t, err)
215	}, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) {
216		cfg.JWTAuthExpectClaims = `{"sub": "foo"}`
217	})
218
219	scenario(t, "validates aud field for inclusion", func(t *testing.T, sc scenarioContext) {
220		var err error
221
222		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"bar", "foo"}}))
223		require.NoError(t, err)
224
225		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"foo", "bar", "baz"}}))
226		require.NoError(t, err)
227
228		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"foo"}}))
229		require.Error(t, err)
230
231		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"bar", "baz"}}))
232		require.Error(t, err)
233
234		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"baz"}}))
235		require.Error(t, err)
236	}, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) {
237		cfg.JWTAuthExpectClaims = `{"aud": ["foo", "bar"]}`
238	})
239
240	scenario(t, "validates non-registered (custom) claims for equality", func(t *testing.T, sc scenarioContext) {
241		var err error
242
243		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo", "my-number": 123}))
244		require.NoError(t, err)
245
246		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "bar", "my-number": 123}))
247		require.Error(t, err)
248
249		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo", "my-number": 100}))
250		require.Error(t, err)
251
252		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo"}))
253		require.Error(t, err)
254
255		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-number": 123}))
256		require.Error(t, err)
257	}, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) {
258		cfg.JWTAuthExpectClaims = `{"my-str": "foo", "my-number": 123}`
259	})
260
261	scenario(t, "validates exp claim of the token", func(t *testing.T, sc scenarioContext) {
262		var err error
263
264		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour))}))
265		require.NoError(t, err)
266
267		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Expiry: jwt.NewNumericDate(time.Now().Add(-time.Hour))}))
268		require.Error(t, err)
269	}, configurePKIXPublicKeyFile)
270
271	scenario(t, "validates nbf claim of the token", func(t *testing.T, sc scenarioContext) {
272		var err error
273
274		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Hour))}))
275		require.NoError(t, err)
276
277		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Hour))}))
278		require.Error(t, err)
279	}, configurePKIXPublicKeyFile)
280
281	scenario(t, "validates iat claim of the token", func(t *testing.T, sc scenarioContext) {
282		var err error
283
284		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Hour))}))
285		require.NoError(t, err)
286
287		_, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour))}))
288		require.Error(t, err)
289	}, configurePKIXPublicKeyFile)
290}
291
292func jwkHTTPScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...configureFunc) {
293	t.Helper()
294	t.Run(desc, func(t *testing.T) {
295		ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
296			if err := json.NewEncoder(w).Encode(jwksPublic); err != nil {
297				panic(err)
298			}
299		}))
300		t.Cleanup(ts.Close)
301
302		configure := func(t *testing.T, cfg *setting.Cfg) {
303			cfg.JWTAuthJWKSetURL = ts.URL
304		}
305		runner := scenarioRunner(func(t *testing.T, sc scenarioContext) {
306			keySet := sc.authJWTSvc.keySet.(*keySetHTTP)
307			keySet.client = ts.Client()
308			fn(t, sc)
309		}, append([]configureFunc{configure}, cbs...)...)
310		runner(t)
311	})
312}
313
314func jwkCachingScenario(t *testing.T, desc string, fn cachingScenarioFunc, cbs ...configureFunc) {
315	t.Helper()
316
317	t.Run(desc, func(t *testing.T) {
318		var reqCount int
319
320		// We run a server that each call responds differently.
321		ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
322			if reqCount++; reqCount > 2 {
323				panic("calling more than two times is not supported")
324			}
325			jwks := jose.JSONWebKeySet{
326				Keys: []jose.JSONWebKey{jwksPublic.Keys[reqCount-1]},
327			}
328			if err := json.NewEncoder(w).Encode(jwks); err != nil {
329				panic(err)
330			}
331		}))
332		t.Cleanup(ts.Close)
333
334		configure := func(t *testing.T, cfg *setting.Cfg) {
335			cfg.JWTAuthJWKSetURL = ts.URL
336			cfg.JWTAuthCacheTTL = time.Hour
337		}
338		runner := scenarioRunner(func(t *testing.T, sc scenarioContext) {
339			keySet := sc.authJWTSvc.keySet.(*keySetHTTP)
340			keySet.client = ts.Client()
341			fn(t, cachingScenarioContext{scenarioContext: sc, reqCount: &reqCount})
342		}, append([]configureFunc{configure}, cbs...)...)
343
344		runner(t)
345	})
346}
347
348func TestBase64Paddings(t *testing.T) {
349	key := rsaKeys[0]
350
351	scenario(t, "verifies a token with base64 padding (non compliant rfc7515#section-2 but accepted)", func(t *testing.T, sc scenarioContext) {
352		token := sign(t, key, jwt.Claims{
353			Subject: subject,
354		})
355		var tokenParts []string
356		for i, part := range strings.Split(token, ".") {
357			// Create parts with different padding numbers to test multiple cases.
358			tokenParts = append(tokenParts, part+strings.Repeat(string(base64.StdPadding), i))
359		}
360		token = strings.Join(tokenParts, ".")
361		verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token)
362		require.NoError(t, err)
363		assert.Equal(t, verifiedClaims["sub"], subject)
364	}, configurePKIXPublicKeyFile)
365}
366
367func scenario(t *testing.T, desc string, fn scenarioFunc, cbs ...configureFunc) {
368	t.Helper()
369
370	t.Run(desc, scenarioRunner(fn, cbs...))
371}
372
373func initAuthService(t *testing.T, cbs ...configureFunc) (*AuthService, error) {
374	t.Helper()
375
376	cfg := setting.NewCfg()
377	cfg.JWTAuthEnabled = true
378	cfg.JWTAuthExpectClaims = "{}"
379
380	for _, cb := range cbs {
381		cb(t, cfg)
382	}
383
384	service := newService(cfg, remotecache.NewFakeStore(t))
385	err := service.init()
386	return service, err
387}
388
389func scenarioRunner(fn scenarioFunc, cbs ...configureFunc) func(t *testing.T) {
390	return func(t *testing.T) {
391		authJWTSvc, err := initAuthService(t, cbs...)
392		require.NoError(t, err)
393
394		fn(t, scenarioContext{
395			ctx:        context.Background(),
396			cfg:        authJWTSvc.Cfg,
397			authJWTSvc: authJWTSvc,
398		})
399	}
400}
401
402func configurePKIXPublicKeyFile(t *testing.T, cfg *setting.Cfg) {
403	t.Helper()
404
405	file, err := ioutil.TempFile(os.TempDir(), "public-key-*.pem")
406	require.NoError(t, err)
407	t.Cleanup(func() {
408		if err := os.Remove(file.Name()); err != nil {
409			panic(err)
410		}
411	})
412
413	blockBytes, err := x509.MarshalPKIXPublicKey(rsaKeys[0].Public())
414	require.NoError(t, err)
415
416	require.NoError(t, pem.Encode(file, &pem.Block{
417		Type:  "PUBLIC KEY",
418		Bytes: blockBytes,
419	}))
420	require.NoError(t, file.Close())
421
422	cfg.JWTAuthKeyFile = file.Name()
423}
424