1/*
2Copyright 2017 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package aes
18
19import (
20	"bytes"
21	"crypto/aes"
22	"crypto/cipher"
23	"crypto/rand"
24	"encoding/hex"
25	"fmt"
26	"io"
27	"reflect"
28	"testing"
29
30	"k8s.io/apiserver/pkg/storage/value"
31)
32
// TestGCMDataStable guards the on-disk format of GCM-encrypted values. It
// builds a GCM AEAD from a 16-byte AES key and fails if the standard
// library's default nonce size is no longer 12 bytes.
func TestGCMDataStable(t *testing.T) {
	const legacyNonceSize = 12

	blk, err := aes.NewCipher([]byte("0123456789abcdef"))
	if err != nil {
		t.Fatal(err)
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		t.Fatal(err)
	}
	// IMPORTANT: If you must fix this test, then all previously encrypted data from previously compiled versions is broken unless you hardcode the nonce size to 12
	if got := gcm.NonceSize(); got != legacyNonceSize {
		t.Fatalf("The underlying Golang crypto size has changed, old version of AES on disk will not be readable unless the AES implementation is changed to hardcode nonce size.")
	}
}
47
48func TestGCMKeyRotation(t *testing.T) {
49	testErr := fmt.Errorf("test error")
50	block1, err := aes.NewCipher([]byte("abcdefghijklmnop"))
51	if err != nil {
52		t.Fatal(err)
53	}
54	block2, err := aes.NewCipher([]byte("0123456789abcdef"))
55	if err != nil {
56		t.Fatal(err)
57	}
58
59	context := value.DefaultContext([]byte("authenticated_data"))
60
61	p := value.NewPrefixTransformers(testErr,
62		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewGCMTransformer(block1)},
63		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewGCMTransformer(block2)},
64	)
65	out, err := p.TransformToStorage([]byte("firstvalue"), context)
66	if err != nil {
67		t.Fatal(err)
68	}
69	if !bytes.HasPrefix(out, []byte("first:")) {
70		t.Fatalf("unexpected prefix: %q", out)
71	}
72	from, stale, err := p.TransformFromStorage(out, context)
73	if err != nil {
74		t.Fatal(err)
75	}
76	if stale || !bytes.Equal([]byte("firstvalue"), from) {
77		t.Fatalf("unexpected data: %t %q", stale, from)
78	}
79
80	// verify changing the context fails storage
81	_, _, err = p.TransformFromStorage(out, value.DefaultContext([]byte("incorrect_context")))
82	if err == nil {
83		t.Fatalf("expected unauthenticated data")
84	}
85
86	// reverse the order, use the second key
87	p = value.NewPrefixTransformers(testErr,
88		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewGCMTransformer(block2)},
89		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewGCMTransformer(block1)},
90	)
91	from, stale, err = p.TransformFromStorage(out, context)
92	if err != nil {
93		t.Fatal(err)
94	}
95	if !stale || !bytes.Equal([]byte("firstvalue"), from) {
96		t.Fatalf("unexpected data: %t %q", stale, from)
97	}
98}
99
100func TestCBCKeyRotation(t *testing.T) {
101	testErr := fmt.Errorf("test error")
102	block1, err := aes.NewCipher([]byte("abcdefghijklmnop"))
103	if err != nil {
104		t.Fatal(err)
105	}
106	block2, err := aes.NewCipher([]byte("0123456789abcdef"))
107	if err != nil {
108		t.Fatal(err)
109	}
110
111	context := value.DefaultContext([]byte("authenticated_data"))
112
113	p := value.NewPrefixTransformers(testErr,
114		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
115		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
116	)
117	out, err := p.TransformToStorage([]byte("firstvalue"), context)
118	if err != nil {
119		t.Fatal(err)
120	}
121	if !bytes.HasPrefix(out, []byte("first:")) {
122		t.Fatalf("unexpected prefix: %q", out)
123	}
124	from, stale, err := p.TransformFromStorage(out, context)
125	if err != nil {
126		t.Fatal(err)
127	}
128	if stale || !bytes.Equal([]byte("firstvalue"), from) {
129		t.Fatalf("unexpected data: %t %q", stale, from)
130	}
131
132	// verify changing the context fails storage
133	_, _, err = p.TransformFromStorage(out, value.DefaultContext([]byte("incorrect_context")))
134	if err != nil {
135		t.Fatalf("CBC mode does not support authentication: %v", err)
136	}
137
138	// reverse the order, use the second key
139	p = value.NewPrefixTransformers(testErr,
140		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
141		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
142	)
143	from, stale, err = p.TransformFromStorage(out, context)
144	if err != nil {
145		t.Fatal(err)
146	}
147	if !stale || !bytes.Equal([]byte("firstvalue"), from) {
148		t.Fatalf("unexpected data: %t %q", stale, from)
149	}
150}
151
152func BenchmarkGCMRead(b *testing.B) {
153	tests := []struct {
154		keyLength   int
155		valueLength int
156		expectStale bool
157	}{
158		{keyLength: 16, valueLength: 1024, expectStale: false},
159		{keyLength: 32, valueLength: 1024, expectStale: false},
160		{keyLength: 32, valueLength: 16384, expectStale: false},
161		{keyLength: 32, valueLength: 16384, expectStale: true},
162	}
163	for _, t := range tests {
164		name := fmt.Sprintf("%vKeyLength/%vValueLength/%vExpectStale", t.keyLength, t.valueLength, t.expectStale)
165		b.Run(name, func(b *testing.B) {
166			benchmarkGCMRead(b, t.keyLength, t.valueLength, t.expectStale)
167		})
168	}
169}
170
171func BenchmarkGCMWrite(b *testing.B) {
172	tests := []struct {
173		keyLength   int
174		valueLength int
175	}{
176		{keyLength: 16, valueLength: 1024},
177		{keyLength: 32, valueLength: 1024},
178		{keyLength: 32, valueLength: 16384},
179	}
180	for _, t := range tests {
181		name := fmt.Sprintf("%vKeyLength/%vValueLength", t.keyLength, t.valueLength)
182		b.Run(name, func(b *testing.B) {
183			benchmarkGCMWrite(b, t.keyLength, t.valueLength)
184		})
185	}
186}
187
188func benchmarkGCMRead(b *testing.B, keyLength int, valueLength int, expectStale bool) {
189	block1, err := aes.NewCipher(bytes.Repeat([]byte("a"), keyLength))
190	if err != nil {
191		b.Fatal(err)
192	}
193	block2, err := aes.NewCipher(bytes.Repeat([]byte("b"), keyLength))
194	if err != nil {
195		b.Fatal(err)
196	}
197	p := value.NewPrefixTransformers(nil,
198		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewGCMTransformer(block1)},
199		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewGCMTransformer(block2)},
200	)
201
202	context := value.DefaultContext([]byte("authenticated_data"))
203	v := bytes.Repeat([]byte("0123456789abcdef"), valueLength/16)
204
205	out, err := p.TransformToStorage(v, context)
206	if err != nil {
207		b.Fatal(err)
208	}
209	// reverse the key order if expecting stale
210	if expectStale {
211		p = value.NewPrefixTransformers(nil,
212			value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewGCMTransformer(block2)},
213			value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewGCMTransformer(block1)},
214		)
215	}
216
217	b.ResetTimer()
218	for i := 0; i < b.N; i++ {
219		from, stale, err := p.TransformFromStorage(out, context)
220		if err != nil {
221			b.Fatal(err)
222		}
223		if expectStale != stale {
224			b.Fatalf("unexpected data: %q, expect stale %t but got %t", from, expectStale, stale)
225		}
226	}
227	b.StopTimer()
228}
229
230func benchmarkGCMWrite(b *testing.B, keyLength int, valueLength int) {
231	block1, err := aes.NewCipher(bytes.Repeat([]byte("a"), keyLength))
232	if err != nil {
233		b.Fatal(err)
234	}
235	block2, err := aes.NewCipher(bytes.Repeat([]byte("b"), keyLength))
236	if err != nil {
237		b.Fatal(err)
238	}
239	p := value.NewPrefixTransformers(nil,
240		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewGCMTransformer(block1)},
241		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewGCMTransformer(block2)},
242	)
243
244	context := value.DefaultContext([]byte("authenticated_data"))
245	v := bytes.Repeat([]byte("0123456789abcdef"), valueLength/16)
246
247	b.ResetTimer()
248	for i := 0; i < b.N; i++ {
249		_, err := p.TransformToStorage(v, context)
250		if err != nil {
251			b.Fatal(err)
252		}
253	}
254	b.StopTimer()
255}
256
257func BenchmarkCBCRead(b *testing.B) {
258	tests := []struct {
259		keyLength   int
260		valueLength int
261		expectStale bool
262	}{
263		{keyLength: 32, valueLength: 1024, expectStale: false},
264		{keyLength: 32, valueLength: 16384, expectStale: false},
265		{keyLength: 32, valueLength: 16384, expectStale: true},
266	}
267	for _, t := range tests {
268		name := fmt.Sprintf("%vKeyLength/%vValueLength/%vExpectStale", t.keyLength, t.valueLength, t.expectStale)
269		b.Run(name, func(b *testing.B) {
270			benchmarkCBCRead(b, t.keyLength, t.valueLength, t.expectStale)
271		})
272	}
273}
274
275func BenchmarkCBCWrite(b *testing.B) {
276	tests := []struct {
277		keyLength   int
278		valueLength int
279	}{
280		{keyLength: 32, valueLength: 1024},
281		{keyLength: 32, valueLength: 16384},
282	}
283	for _, t := range tests {
284		name := fmt.Sprintf("%vKeyLength/%vValueLength", t.keyLength, t.valueLength)
285		b.Run(name, func(b *testing.B) {
286			benchmarkCBCWrite(b, t.keyLength, t.valueLength)
287		})
288	}
289}
290
291func benchmarkCBCRead(b *testing.B, keyLength int, valueLength int, expectStale bool) {
292	block1, err := aes.NewCipher(bytes.Repeat([]byte("a"), keyLength))
293	if err != nil {
294		b.Fatal(err)
295	}
296	block2, err := aes.NewCipher(bytes.Repeat([]byte("b"), keyLength))
297	if err != nil {
298		b.Fatal(err)
299	}
300	p := value.NewPrefixTransformers(nil,
301		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
302		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
303	)
304
305	context := value.DefaultContext([]byte("authenticated_data"))
306	v := bytes.Repeat([]byte("0123456789abcdef"), valueLength/16)
307
308	out, err := p.TransformToStorage(v, context)
309	if err != nil {
310		b.Fatal(err)
311	}
312	// reverse the key order if expecting stale
313	if expectStale {
314		p = value.NewPrefixTransformers(nil,
315			value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
316			value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
317		)
318	}
319
320	b.ResetTimer()
321	for i := 0; i < b.N; i++ {
322		from, stale, err := p.TransformFromStorage(out, context)
323		if err != nil {
324			b.Fatal(err)
325		}
326		if expectStale != stale {
327			b.Fatalf("unexpected data: %q, expect stale %t but got %t", from, expectStale, stale)
328		}
329	}
330	b.StopTimer()
331}
332
333func benchmarkCBCWrite(b *testing.B, keyLength int, valueLength int) {
334	block1, err := aes.NewCipher(bytes.Repeat([]byte("a"), keyLength))
335	if err != nil {
336		b.Fatal(err)
337	}
338	block2, err := aes.NewCipher(bytes.Repeat([]byte("b"), keyLength))
339	if err != nil {
340		b.Fatal(err)
341	}
342	p := value.NewPrefixTransformers(nil,
343		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
344		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
345	)
346
347	context := value.DefaultContext([]byte("authenticated_data"))
348	v := bytes.Repeat([]byte("0123456789abcdef"), valueLength/16)
349
350	b.ResetTimer()
351	for i := 0; i < b.N; i++ {
352		_, err := p.TransformToStorage(v, context)
353		if err != nil {
354			b.Fatal(err)
355		}
356	}
357	b.StopTimer()
358}
359
360func TestRoundTrip(t *testing.T) {
361	lengths := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 128, 1024}
362
363	aes16block, err := aes.NewCipher([]byte(bytes.Repeat([]byte("a"), 16)))
364	if err != nil {
365		t.Fatal(err)
366	}
367	aes24block, err := aes.NewCipher([]byte(bytes.Repeat([]byte("b"), 24)))
368	if err != nil {
369		t.Fatal(err)
370	}
371	aes32block, err := aes.NewCipher([]byte(bytes.Repeat([]byte("c"), 32)))
372	if err != nil {
373		t.Fatal(err)
374	}
375
376	tests := []struct {
377		name    string
378		context value.Context
379		t       value.Transformer
380	}{
381		{name: "GCM 16 byte key", t: NewGCMTransformer(aes16block)},
382		{name: "GCM 24 byte key", t: NewGCMTransformer(aes24block)},
383		{name: "GCM 32 byte key", t: NewGCMTransformer(aes32block)},
384		{name: "CBC 32 byte key", t: NewCBCTransformer(aes32block)},
385	}
386	for _, tt := range tests {
387		t.Run(tt.name, func(t *testing.T) {
388			context := tt.context
389			if context == nil {
390				context = value.DefaultContext("")
391			}
392			for _, l := range lengths {
393				data := make([]byte, l)
394				if _, err := io.ReadFull(rand.Reader, data); err != nil {
395					t.Fatalf("unable to read sufficient random bytes: %v", err)
396				}
397				original := append([]byte{}, data...)
398
399				ciphertext, err := tt.t.TransformToStorage(data, context)
400				if err != nil {
401					t.Errorf("TransformToStorage error = %v", err)
402					continue
403				}
404
405				result, stale, err := tt.t.TransformFromStorage(ciphertext, context)
406				if err != nil {
407					t.Errorf("TransformFromStorage error = %v", err)
408					continue
409				}
410				if stale {
411					t.Errorf("unexpected stale output")
412					continue
413				}
414
415				switch {
416				case l == 0:
417					if len(result) != 0 {
418						t.Errorf("Round trip failed len=%d\noriginal:\n%s\nresult:\n%s", l, hex.Dump(original), hex.Dump(result))
419					}
420				case !reflect.DeepEqual(original, result):
421					t.Errorf("Round trip failed len=%d\noriginal:\n%s\nresult:\n%s", l, hex.Dump(original), hex.Dump(result))
422				}
423			}
424		})
425	}
426}
427