1/*-
2 * Copyright 2018 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"fmt"
21	"testing"
22)
23
24type signWrapper struct {
25	pk      *JSONWebKey
26	wrapped payloadSigner
27	algs    []SignatureAlgorithm
28}
29
30var _ = OpaqueSigner(&signWrapper{})
31
32func (sw *signWrapper) Algs() []SignatureAlgorithm {
33	return sw.algs
34}
35
36func (sw *signWrapper) Public() *JSONWebKey {
37	return sw.pk
38}
39
40func (sw *signWrapper) SignPayload(payload []byte, alg SignatureAlgorithm) ([]byte, error) {
41	sig, err := sw.wrapped.signPayload(payload, alg)
42	if err != nil {
43		return nil, err
44	}
45	return sig.Signature, nil
46}
47
48type verifyWrapper struct {
49	wrapped []payloadVerifier
50}
51
52var _ = OpaqueVerifier(&verifyWrapper{})
53
54func (vw *verifyWrapper) VerifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
55	if len(vw.wrapped) == 0 {
56		return fmt.Errorf("error: verifier had no keys")
57	}
58	var err error
59	for _, v := range vw.wrapped {
60		err = v.verifyPayload(payload, signature, alg)
61		if err == nil {
62			return nil
63		}
64	}
65	return err
66}
67
68type keyEncryptWrapper struct {
69	kid     string
70	wrapped keyEncrypter
71	algs    []KeyAlgorithm
72}
73
74var _ = OpaqueKeyEncrypter(&keyEncryptWrapper{})
75
76func (kew *keyEncryptWrapper) KeyID() string {
77	return kew.kid
78}
79
80func (kew *keyEncryptWrapper) Algs() []KeyAlgorithm {
81	return kew.algs
82}
83
84func (kew *keyEncryptWrapper) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
85	info, err := kew.wrapped.encryptKey(cek, alg)
86	if err != nil {
87		return recipientInfo{}, err
88	}
89
90	return info, nil
91}
92
93type keyDecryptWrapper struct {
94	wrapped keyDecrypter
95}
96
97var _ = OpaqueKeyDecrypter(&keyDecryptWrapper{})
98
99func (kdw *keyDecryptWrapper) DecryptKey(encryptedKey []byte, header Header) ([]byte, error) {
100	rawHeader := rawHeader{}
101
102	err := rawHeader.set(headerKeyID, header.KeyID)
103	if err != nil {
104		return nil, err
105	}
106	err = rawHeader.set(headerAlgorithm, header.Algorithm)
107	if err != nil {
108		return nil, err
109	}
110	err = rawHeader.set(headerNonce, header.Nonce)
111	if err != nil {
112		return nil, err
113	}
114	err = rawHeader.set(headerJWK, header.JSONWebKey)
115	if err != nil {
116		return nil, err
117	}
118	for k, v := range header.ExtraHeaders {
119		err = rawHeader.set(k, v)
120		if err != nil {
121			return nil, err
122		}
123	}
124
125	recipient := &recipientInfo{
126		encryptedKey: encryptedKey,
127	}
128
129	var generator randomKeyGenerator
130	cipher := getContentCipher(rawHeader.getEncryption())
131	if cipher != nil {
132		generator = randomKeyGenerator{
133			size: cipher.keySize(),
134		}
135	}
136
137	return kdw.wrapped.decryptKey(rawHeader, recipient, generator)
138}
139
140func TestRoundtripsJWSOpaque(t *testing.T) {
141	sigAlgs := []SignatureAlgorithm{RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512, EdDSA}
142
143	serializers := []func(*JSONWebSignature) (string, error){
144		func(obj *JSONWebSignature) (string, error) { return obj.CompactSerialize() },
145		func(obj *JSONWebSignature) (string, error) { return obj.FullSerialize(), nil },
146	}
147
148	corrupter := func(obj *JSONWebSignature) {}
149
150	for _, alg := range sigAlgs {
151		signingKey, verificationKey := GenerateSigningTestKey(alg)
152
153		for i, serializer := range serializers {
154			sw := makeOpaqueSigner(t, signingKey, alg)
155			vw := makeOpaqueVerifier(t, []interface{}{verificationKey}, alg)
156
157			err := RoundtripJWS(alg, serializer, corrupter, sw, verificationKey, "test_nonce")
158			if err != nil {
159				t.Error(err, alg, i)
160			}
161
162			err = RoundtripJWS(alg, serializer, corrupter, signingKey, vw, "test_nonce")
163			if err != nil {
164				t.Error(err, alg, i)
165			}
166
167			err = RoundtripJWS(alg, serializer, corrupter, sw, vw, "test_nonce")
168			if err != nil {
169				t.Error(err, alg, i)
170			}
171		}
172	}
173}
174
175func makeOpaqueSigner(t *testing.T, signingKey interface{}, alg SignatureAlgorithm) *signWrapper {
176	ri, err := makeJWSRecipient(alg, signingKey)
177	if err != nil {
178		t.Fatal(err)
179	}
180	return &signWrapper{
181		wrapped: ri.signer,
182		algs:    []SignatureAlgorithm{alg},
183		pk:      &JSONWebKey{Key: ri.publicKey()},
184	}
185}
186
187func makeOpaqueVerifier(t *testing.T, verificationKey []interface{}, alg SignatureAlgorithm) *verifyWrapper {
188	var verifiers []payloadVerifier
189	for _, vk := range verificationKey {
190		verifier, err := newVerifier(vk)
191		if err != nil {
192			t.Fatal(err)
193		}
194		verifiers = append(verifiers, verifier)
195	}
196	return &verifyWrapper{wrapped: verifiers}
197}
198
199func makeOpaqueKeyEncrypter(t *testing.T, signingKey interface{}, alg KeyAlgorithm, kid string) *keyEncryptWrapper {
200	rki, err := makeJWERecipient(alg, signingKey)
201	if err != nil {
202		t.Fatal(err, alg)
203	}
204	return &keyEncryptWrapper{
205		wrapped: rki.keyEncrypter,
206		algs:    []KeyAlgorithm{alg},
207		kid:     kid,
208	}
209}
210
211func makeOpaqueKeyDecrypter(t *testing.T, decryptionKey interface{}, alg KeyAlgorithm) *keyDecryptWrapper {
212	kd, err := newDecrypter(decryptionKey)
213	if err != nil {
214		t.Fatal(err)
215	}
216
217	return &keyDecryptWrapper{
218		wrapped: kd,
219	}
220}
221
222func TestOpaqueSignerKeyRotation(t *testing.T) {
223
224	sigAlgs := []SignatureAlgorithm{RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512, EdDSA}
225
226	serializers := []func(*JSONWebSignature) (string, error){
227		func(obj *JSONWebSignature) (string, error) { return obj.CompactSerialize() },
228		func(obj *JSONWebSignature) (string, error) { return obj.FullSerialize(), nil },
229	}
230
231	for _, alg := range sigAlgs {
232		for i, serializer := range serializers {
233			sk1, pk1 := GenerateSigningTestKey(alg)
234			sk2, pk2 := GenerateSigningTestKey(alg)
235
236			sw := makeOpaqueSigner(t, sk1, alg)
237			sw.pk.KeyID = "first"
238			vw := makeOpaqueVerifier(t, []interface{}{pk1, pk2}, alg)
239
240			signer, err := NewSigner(
241				SigningKey{Algorithm: alg, Key: sw},
242				&SignerOptions{NonceSource: staticNonceSource("test_nonce")},
243			)
244			if err != nil {
245				t.Fatal(err, alg, i)
246			}
247
248			jws1, err := signer.Sign([]byte("foo bar baz"))
249			if err != nil {
250				t.Fatal(err, alg, i)
251			}
252			jws1 = rtSerialize(t, serializer, jws1, vw)
253			if kid := jws1.Signatures[0].Protected.KeyID; kid != "first" {
254				t.Errorf("expected kid %q but got %q", "first", kid)
255			}
256
257			swNext := makeOpaqueSigner(t, sk2, alg)
258			swNext.pk.KeyID = "next"
259			sw.wrapped = swNext.wrapped
260			sw.pk = swNext.pk
261
262			jws2, err := signer.Sign([]byte("foo bar baz next"))
263			if err != nil {
264				t.Error(err, alg, i)
265			}
266			jws2 = rtSerialize(t, serializer, jws2, vw)
267			if kid := jws2.Signatures[0].Protected.KeyID; kid != "next" {
268				t.Errorf("expected kid %q but got %q", "next", kid)
269			}
270		}
271	}
272}
273
274func rtSerialize(t *testing.T, serializer func(*JSONWebSignature) (string, error), sig *JSONWebSignature, vk interface{}) *JSONWebSignature {
275	b, err := serializer(sig)
276	if err != nil {
277		t.Fatal(err)
278	}
279	sig, err = ParseSigned(b)
280	if err != nil {
281		t.Fatal(err)
282	}
283	if _, err := sig.Verify(vk); err != nil {
284		t.Fatal(err)
285	}
286	return sig
287}
288
289func TestOpaqueKeyRoundtripJWE(t *testing.T) {
290	keyAlgs := []KeyAlgorithm{
291		ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW, A128KW, A192KW, A256KW,
292		RSA1_5, RSA_OAEP, RSA_OAEP_256, A128GCMKW, A192GCMKW, A256GCMKW,
293		PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW,
294	}
295	encAlgs := []ContentEncryption{A128GCM, A192GCM, A256GCM, A128CBC_HS256, A192CBC_HS384, A256CBC_HS512}
296	kid := "test-kid"
297
298	serializers := []func(*JSONWebEncryption) (string, error){
299		func(obj *JSONWebEncryption) (string, error) { return obj.CompactSerialize() },
300		func(obj *JSONWebEncryption) (string, error) { return obj.FullSerialize(), nil },
301	}
302
303	for _, alg := range keyAlgs {
304		for _, enc := range encAlgs {
305			for _, testKey := range generateTestKeys(alg, enc) {
306				for _, serializer := range serializers {
307					kew := makeOpaqueKeyEncrypter(t, testKey.enc, alg, kid)
308					encrypter, err := NewEncrypter(
309						enc,
310						Recipient{
311							Algorithm: alg,
312							Key:       kew,
313						},
314						&EncrypterOptions{},
315					)
316					if err != nil {
317						t.Fatal(err, alg)
318					}
319
320					jwe, err := encrypter.Encrypt([]byte("foo bar"))
321					if err != nil {
322						t.Fatal(err, alg)
323					}
324
325					dw := makeOpaqueKeyDecrypter(t, testKey.dec, alg)
326					jwe = jweSerialize(t, serializer, jwe, dw)
327					if jwe.Header.KeyID != kid {
328						t.Errorf("expected jwe kid to equal %s but got %s", kid, jwe.Header.KeyID)
329					}
330
331					out, err := jwe.Decrypt(dw)
332					if err != nil {
333						t.Fatal(err, out)
334					}
335					if string(out) != "foo bar" {
336						t.Errorf("expected decrypted jwe to equal %s but got %s", "foo bar", string(out))
337					}
338				}
339			}
340		}
341	}
342}
343
344func jweSerialize(t *testing.T, serializer func(*JSONWebEncryption) (string, error), jwe *JSONWebEncryption, d OpaqueKeyDecrypter) *JSONWebEncryption {
345	b, err := serializer(jwe)
346	if err != nil {
347		t.Fatal(err)
348	}
349	jwe, err = ParseEncrypted(b)
350	if err != nil {
351		t.Fatal(err)
352	}
353	if _, err := jwe.Decrypt(d); err != nil {
354		t.Fatal(err)
355	}
356	return jwe
357}
358