1// Copyright 2020 Google LLC.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package idtoken
6
7import (
8	"bytes"
9	"context"
10	"crypto"
11	"crypto/ecdsa"
12	"crypto/elliptic"
13	"crypto/rand"
14	"crypto/rsa"
15	"encoding/base64"
16	"encoding/json"
17	"io/ioutil"
18	"math/big"
19	"net/http"
20	"testing"
21
22	"google.golang.org/api/option"
23)
24
// Fixtures shared by the token builders and the validator tests in this file.
const (
	// testAudience is the "aud" claim written into every generated token and
	// the audience passed to Validate.
	testAudience = "test-audience"

	// keyID is the "kid" header of every generated token; the stubbed cert
	// endpoint serves a key under this ID for the positive test cases.
	keyID = "1234"
)
29
30func TestValidateRS256(t *testing.T) {
31	idToken, pk := createRS256JWT(t)
32	tests := []struct {
33		name    string
34		keyID   string
35		n       *big.Int
36		e       int
37		wantErr bool
38	}{
39		{name: "works", keyID: keyID, n: pk.N, e: pk.E, wantErr: false},
40		{name: "no matching key", keyID: "5678", n: pk.N, e: pk.E, wantErr: true},
41		{name: "sig does not match", keyID: keyID, n: new(big.Int).SetBytes([]byte("42")), e: 42, wantErr: true},
42	}
43	for _, tt := range tests {
44		t.Run(tt.name, func(t *testing.T) {
45			client := &http.Client{
46				Transport: RoundTripFn(func(req *http.Request) *http.Response {
47					cr := certResponse{
48						Keys: []jwk{
49							{
50								Kid: tt.keyID,
51								N:   base64.RawURLEncoding.EncodeToString(tt.n.Bytes()),
52								E:   base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(tt.e)).Bytes()),
53							},
54						},
55					}
56					b, err := json.Marshal(&cr)
57					if err != nil {
58						t.Fatalf("unable to marshal response: %v", err)
59					}
60					return &http.Response{
61						StatusCode: 200,
62						Body:       ioutil.NopCloser(bytes.NewReader(b)),
63						Header:     make(http.Header),
64					}
65				}),
66			}
67
68			v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
69			if err != nil {
70				t.Fatalf("NewValidator(...) = %q, want nil", err)
71			}
72			payload, err := v.Validate(context.Background(), idToken, testAudience)
73			if !tt.wantErr && err != nil {
74				t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
75			}
76			if !tt.wantErr && payload.Audience != testAudience {
77				t.Fatalf("got %v, want %v", payload.Audience, testAudience)
78			}
79		})
80	}
81}
82
83func TestValidateES256(t *testing.T) {
84	idToken, pk := createES256JWT(t)
85	tests := []struct {
86		name    string
87		keyID   string
88		x       *big.Int
89		y       *big.Int
90		wantErr bool
91	}{
92		{name: "works", keyID: keyID, x: pk.X, y: pk.Y, wantErr: false},
93		{name: "no matching key", keyID: "5678", x: pk.X, y: pk.Y, wantErr: true},
94		{name: "sig does not match", keyID: keyID, x: new(big.Int), y: new(big.Int), wantErr: true},
95	}
96	for _, tt := range tests {
97		t.Run(tt.name, func(t *testing.T) {
98			client := &http.Client{
99				Transport: RoundTripFn(func(req *http.Request) *http.Response {
100					cr := certResponse{
101						Keys: []jwk{
102							{
103								Kid: tt.keyID,
104								X:   base64.RawURLEncoding.EncodeToString(tt.x.Bytes()),
105								Y:   base64.RawURLEncoding.EncodeToString(tt.y.Bytes()),
106							},
107						},
108					}
109					b, err := json.Marshal(&cr)
110					if err != nil {
111						t.Fatalf("unable to marshal response: %v", err)
112					}
113					return &http.Response{
114						StatusCode: 200,
115						Body:       ioutil.NopCloser(bytes.NewReader(b)),
116						Header:     make(http.Header),
117					}
118				}),
119			}
120
121			v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
122			if err != nil {
123				t.Fatalf("NewValidator(...) = %q, want nil", err)
124			}
125			payload, err := v.Validate(context.Background(), idToken, testAudience)
126			if !tt.wantErr && err != nil {
127				t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
128			}
129			if !tt.wantErr && payload.Audience != testAudience {
130				t.Fatalf("got %v, want %v", payload.Audience, testAudience)
131			}
132		})
133	}
134}
135
136func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) {
137	t.Helper()
138	token := commonToken(t, "ES256")
139	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
140	if err != nil {
141		t.Fatalf("unable to generate key: %v", err)
142	}
143	r, s, err := ecdsa.Sign(rand.Reader, privateKey, token.hashedContent())
144	if err != nil {
145		t.Fatalf("unable to sign content: %v", err)
146	}
147	var sig []byte
148	sig = append(sig, r.Bytes()...)
149	sig = append(sig, s.Bytes()...)
150	token.signature = base64.RawURLEncoding.EncodeToString(sig)
151	return token.String(), privateKey.PublicKey
152}
153
154func createRS256JWT(t *testing.T) (string, rsa.PublicKey) {
155	t.Helper()
156	token := commonToken(t, "RS256")
157	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
158	if err != nil {
159		t.Fatalf("unable to generate key: %v", err)
160	}
161	sig, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, token.hashedContent())
162	if err != nil {
163		t.Fatalf("unable to sign content: %v", err)
164	}
165	token.signature = base64.RawURLEncoding.EncodeToString(sig)
166	return token.String(), privateKey.PublicKey
167}
168
169func commonToken(t *testing.T, alg string) *jwt {
170	t.Helper()
171	header := jwtHeader{
172		KeyID:     keyID,
173		Algorithm: alg,
174		Type:      "JWT",
175	}
176	payload := Payload{
177		Issuer:   "example.com",
178		Audience: testAudience,
179	}
180
181	hb, err := json.Marshal(&header)
182	if err != nil {
183		t.Fatalf("unable to marshall header: %v", err)
184	}
185	pb, err := json.Marshal(&payload)
186	if err != nil {
187		t.Fatalf("unable to marshall payload: %v", err)
188	}
189	eb := base64.RawURLEncoding.EncodeToString(hb)
190	ep := base64.RawURLEncoding.EncodeToString(pb)
191	return &jwt{
192		header:  eb,
193		payload: ep,
194	}
195}
196
// RoundTripFn adapts a plain function into an http.RoundTripper, which lets
// tests stub out HTTP traffic. The function's response is returned as-is,
// and the error result is always nil.
type RoundTripFn func(req *http.Request) *http.Response

// RoundTrip implements http.RoundTripper by invoking fn on req.
func (fn RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) {
	resp := fn(req)
	return resp, nil
}
200