1// Copyright 2020 Google LLC. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package idtoken 6 7import ( 8 "bytes" 9 "context" 10 "crypto" 11 "crypto/ecdsa" 12 "crypto/elliptic" 13 "crypto/rand" 14 "crypto/rsa" 15 "encoding/base64" 16 "encoding/json" 17 "io/ioutil" 18 "math/big" 19 "net/http" 20 "testing" 21 22 "google.golang.org/api/option" 23) 24 25const ( 26 keyID = "1234" 27 testAudience = "test-audience" 28) 29 30func TestValidateRS256(t *testing.T) { 31 idToken, pk := createRS256JWT(t) 32 tests := []struct { 33 name string 34 keyID string 35 n *big.Int 36 e int 37 wantErr bool 38 }{ 39 {name: "works", keyID: keyID, n: pk.N, e: pk.E, wantErr: false}, 40 {name: "no matching key", keyID: "5678", n: pk.N, e: pk.E, wantErr: true}, 41 {name: "sig does not match", keyID: keyID, n: new(big.Int).SetBytes([]byte("42")), e: 42, wantErr: true}, 42 } 43 for _, tt := range tests { 44 t.Run(tt.name, func(t *testing.T) { 45 client := &http.Client{ 46 Transport: RoundTripFn(func(req *http.Request) *http.Response { 47 cr := certResponse{ 48 Keys: []jwk{ 49 { 50 Kid: tt.keyID, 51 N: base64.RawURLEncoding.EncodeToString(tt.n.Bytes()), 52 E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(tt.e)).Bytes()), 53 }, 54 }, 55 } 56 b, err := json.Marshal(&cr) 57 if err != nil { 58 t.Fatalf("unable to marshal response: %v", err) 59 } 60 return &http.Response{ 61 StatusCode: 200, 62 Body: ioutil.NopCloser(bytes.NewReader(b)), 63 Header: make(http.Header), 64 } 65 }), 66 } 67 68 v, err := NewValidator(context.Background(), option.WithHTTPClient(client)) 69 if err != nil { 70 t.Fatalf("NewValidator(...) = %q, want nil", err) 71 } 72 payload, err := v.Validate(context.Background(), idToken, testAudience) 73 if !tt.wantErr && err != nil { 74 t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err) 75 } 76 if !tt.wantErr && payload.Audience != testAudience { 77 t.Fatalf("got %v, want %v", payload.Audience, testAudience) 78 } 79 }) 80 } 81} 82 83func TestValidateES256(t *testing.T) { 84 idToken, pk := createES256JWT(t) 85 tests := []struct { 86 name string 87 keyID string 88 x *big.Int 89 y *big.Int 90 wantErr bool 91 }{ 92 {name: "works", keyID: keyID, x: pk.X, y: pk.Y, wantErr: false}, 93 {name: "no matching key", keyID: "5678", x: pk.X, y: pk.Y, wantErr: true}, 94 {name: "sig does not match", keyID: keyID, x: new(big.Int), y: new(big.Int), wantErr: true}, 95 } 96 for _, tt := range tests { 97 t.Run(tt.name, func(t *testing.T) { 98 client := &http.Client{ 99 Transport: RoundTripFn(func(req *http.Request) *http.Response { 100 cr := certResponse{ 101 Keys: []jwk{ 102 { 103 Kid: tt.keyID, 104 X: base64.RawURLEncoding.EncodeToString(tt.x.Bytes()), 105 Y: base64.RawURLEncoding.EncodeToString(tt.y.Bytes()), 106 }, 107 }, 108 } 109 b, err := json.Marshal(&cr) 110 if err != nil { 111 t.Fatalf("unable to marshal response: %v", err) 112 } 113 return &http.Response{ 114 StatusCode: 200, 115 Body: ioutil.NopCloser(bytes.NewReader(b)), 116 Header: make(http.Header), 117 } 118 }), 119 } 120 121 v, err := NewValidator(context.Background(), option.WithHTTPClient(client)) 122 if err != nil { 123 t.Fatalf("NewValidator(...) = %q, want nil", err) 124 } 125 payload, err := v.Validate(context.Background(), idToken, testAudience) 126 if !tt.wantErr && err != nil { 127 t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err) 128 } 129 if !tt.wantErr && payload.Audience != testAudience { 130 t.Fatalf("got %v, want %v", payload.Audience, testAudience) 131 } 132 }) 133 } 134} 135 136func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) { 137 t.Helper() 138 token := commonToken(t, "ES256") 139 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 140 if err != nil { 141 t.Fatalf("unable to generate key: %v", err) 142 } 143 r, s, err := ecdsa.Sign(rand.Reader, privateKey, token.hashedContent()) 144 if err != nil { 145 t.Fatalf("unable to sign content: %v", err) 146 } 147 var sig []byte 148 sig = append(sig, r.Bytes()...) 149 sig = append(sig, s.Bytes()...) 150 token.signature = base64.RawURLEncoding.EncodeToString(sig) 151 return token.String(), privateKey.PublicKey 152} 153 154func createRS256JWT(t *testing.T) (string, rsa.PublicKey) { 155 t.Helper() 156 token := commonToken(t, "RS256") 157 privateKey, err := rsa.GenerateKey(rand.Reader, 2048) 158 if err != nil { 159 t.Fatalf("unable to generate key: %v", err) 160 } 161 sig, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, token.hashedContent()) 162 if err != nil { 163 t.Fatalf("unable to sign content: %v", err) 164 } 165 token.signature = base64.RawURLEncoding.EncodeToString(sig) 166 return token.String(), privateKey.PublicKey 167} 168 169func commonToken(t *testing.T, alg string) *jwt { 170 t.Helper() 171 header := jwtHeader{ 172 KeyID: keyID, 173 Algorithm: alg, 174 Type: "JWT", 175 } 176 payload := Payload{ 177 Issuer: "example.com", 178 Audience: testAudience, 179 } 180 181 hb, err := json.Marshal(&header) 182 if err != nil { 183 t.Fatalf("unable to marshall header: %v", err) 184 } 185 pb, err := json.Marshal(&payload) 186 if err != nil { 187 t.Fatalf("unable to marshall payload: %v", err) 188 } 189 eb := base64.RawURLEncoding.EncodeToString(hb) 190 ep := base64.RawURLEncoding.EncodeToString(pb) 191 return &jwt{ 192 header: eb, 193 payload: ep, 194 } 195} 196 197type RoundTripFn func(req *http.Request) *http.Response 198 199func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil } 200