1package keyenc_test
2
3import (
4	"bytes"
5	"crypto/aes"
6	"crypto/ecdsa"
7	"encoding/hex"
8	"testing"
9
10	"github.com/lestrrat-go/jwx/jwe/internal/keyenc"
11	"github.com/lestrrat-go/jwx/jwk"
12	"github.com/stretchr/testify/assert"
13)
14
// mustHexDecode decodes the hex string s (upper- or lower-case digits) into
// raw bytes. It is meant only for hard-coded test vectors: malformed input
// indicates a broken test, so the decoding error is raised as a panic.
func mustHexDecode(s string) []byte {
	decoded, err := hex.DecodeString(s)
	if err == nil {
		return decoded
	}
	panic(err)
}
22
// vector is a single key-wrap test case. Every field holds hex-encoded
// bytes that the test decodes with mustHexDecode before use.
type vector struct {
	// Kek is the key-encryption key handed to aes.NewCipher, Data is the
	// key material to wrap, and Expected is the wrapped output to compare.
	Kek, Data, Expected string
}
28
29func TestRFC3394_Wrap(t *testing.T) {
30	vectors := []vector{
31		{
32			Kek:      "000102030405060708090A0B0C0D0E0F",
33			Data:     "00112233445566778899AABBCCDDEEFF",
34			Expected: "1FA68B0A8112B447AEF34BD8FB5A7B829D3E862371D2CFE5",
35		},
36		{
37			Kek:      "000102030405060708090A0B0C0D0E0F1011121314151617",
38			Data:     "00112233445566778899AABBCCDDEEFF",
39			Expected: "96778B25AE6CA435F92B5B97C050AED2468AB8A17AD84E5D",
40		},
41		{
42			Kek:      "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F",
43			Data:     "00112233445566778899AABBCCDDEEFF0001020304050607",
44			Expected: "A8F9BC1612C68B3FF6E6F4FBE30E71E4769C8B80A32CB8958CD5D17D6B254DA1",
45		},
46	}
47
48	for _, v := range vectors {
49		t.Logf("kek      = %s", v.Kek)
50		t.Logf("data     = %s", v.Data)
51		t.Logf("expected = %s", v.Expected)
52
53		kek := mustHexDecode(v.Kek)
54		data := mustHexDecode(v.Data)
55		expected := mustHexDecode(v.Expected)
56
57		block, err := aes.NewCipher(kek)
58		if !assert.NoError(t, err, "NewCipher is successful") {
59			return
60		}
61		out, err := keyenc.Wrap(block, data)
62		if !assert.NoError(t, err, "Wrap is successful") {
63			return
64		}
65
66		if !assert.Equal(t, expected, out, "Wrap generates expected output") {
67			return
68		}
69
70		unwrapped, err := keyenc.Unwrap(block, out)
71		if !assert.NoError(t, err, "Unwrap is successful") {
72			return
73		}
74
75		if !assert.Equal(t, data, unwrapped, "Unwrapped data matches") {
76			return
77		}
78	}
79}
80
81func TestDeriveECDHES(t *testing.T) {
82	// stolen from go-jose
83	// Example keys from JWA, Appendix C
84	var aliceKey ecdsa.PrivateKey
85	var bobKey ecdsa.PrivateKey
86
87	const aliceKeySrc = `{"kty":"EC",
88      "crv":"P-256",
89      "x":"gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0",
90      "y":"SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps",
91      "d":"0_NxaRPUMQoAJt50Gz8YiTr8gRTwyEaCumd-MToTmIo"
92     }`
93	const bobKeySrc = `{"kty":"EC",
94      "crv":"P-256",
95      "x":"weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ",
96      "y":"e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck",
97      "d":"VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw"
98     }`
99
100	aliceWebKey, err := jwk.ParseKey([]byte(aliceKeySrc))
101	if !assert.NoError(t, err, `jwk.ParseKey should succeed`) {
102		return
103	}
104	if !assert.NoError(t, aliceWebKey.Raw(&aliceKey), `aliceWebKey.Raw should succeed`) {
105		return
106	}
107
108	bobWebKey, err := jwk.ParseKey([]byte(bobKeySrc))
109	if !assert.NoError(t, err, `jwk.ParseKey should succeed`) {
110		return
111	}
112	if !assert.NoError(t, bobWebKey.Raw(&bobKey), `bobWebKey.Raw should succeed`) {
113		return
114	}
115
116	apuData := []byte("Alice")
117	apvData := []byte("Bob")
118
119	expected := []byte{86, 170, 141, 234, 248, 35, 109, 32, 92, 34, 40, 205, 113, 167, 16, 26}
120
121	output, err := keyenc.DeriveECDHES([]byte("A128GCM"), apuData, apvData, &bobKey, &aliceKey.PublicKey, 16)
122	if !assert.NoError(t, err, `keyenc.DeriveECDHES should succeed`) {
123		return
124	}
125
126	if !assert.Equal(t, output, expected, `result should match`) {
127		return
128	}
129}
130
131func TestKeyWrap(t *testing.T) {
132	// stolen from go-jose
133	// Test vectors from: http://csrc.nist.gov/groups/ST/toolkit/documents/kms/key-wrap.pdf
134	kek0, _ := hex.DecodeString("000102030405060708090A0B0C0D0E0F")
135	cek0, _ := hex.DecodeString("00112233445566778899AABBCCDDEEFF")
136
137	expected0, _ := hex.DecodeString("1FA68B0A8112B447AEF34BD8FB5A7B829D3E862371D2CFE5")
138
139	kek1, _ := hex.DecodeString("000102030405060708090A0B0C0D0E0F1011121314151617")
140	cek1, _ := hex.DecodeString("00112233445566778899AABBCCDDEEFF")
141
142	expected1, _ := hex.DecodeString("96778B25AE6CA435F92B5B97C050AED2468AB8A17AD84E5D")
143
144	kek2, _ := hex.DecodeString("000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F")
145	cek2, _ := hex.DecodeString("00112233445566778899AABBCCDDEEFF0001020304050607")
146
147	expected2, _ := hex.DecodeString("A8F9BC1612C68B3FF6E6F4FBE30E71E4769C8B80A32CB8958CD5D17D6B254DA1")
148
149	block0, _ := aes.NewCipher(kek0)
150	block1, _ := aes.NewCipher(kek1)
151	block2, _ := aes.NewCipher(kek2)
152
153	out0, _ := keyenc.Wrap(block0, cek0)
154	out1, _ := keyenc.Wrap(block1, cek1)
155	out2, _ := keyenc.Wrap(block2, cek2)
156
157	if !bytes.Equal(out0, expected0) {
158		t.Error("output 0 not as expected, got", out0, "wanted", expected0)
159	}
160
161	if !bytes.Equal(out1, expected1) {
162		t.Error("output 1 not as expected, got", out1, "wanted", expected1)
163	}
164
165	if !bytes.Equal(out2, expected2) {
166		t.Error("output 2 not as expected, got", out2, "wanted", expected2)
167	}
168
169	unwrap0, _ := keyenc.Unwrap(block0, out0)
170	unwrap1, _ := keyenc.Unwrap(block1, out1)
171	unwrap2, _ := keyenc.Unwrap(block2, out2)
172
173	if !bytes.Equal(unwrap0, cek0) {
174		t.Error("key unwrap did not return original input, got", unwrap0, "wanted", cek0)
175	}
176
177	if !bytes.Equal(unwrap1, cek1) {
178		t.Error("key unwrap did not return original input, got", unwrap1, "wanted", cek1)
179	}
180
181	if !bytes.Equal(unwrap2, cek2) {
182		t.Error("key unwrap did not return original input, got", unwrap2, "wanted", cek2)
183	}
184}
185