1package encryption
2
3import (
4	"crypto/rand"
5	"encoding/base64"
6	"fmt"
7	"io"
8	"testing"
9
10	"github.com/stretchr/testify/assert"
11)
12
13func TestEncodeAndDecodeAccessToken(t *testing.T) {
14	const secret = "0123456789abcdefghijklmnopqrstuv"
15	const token = "my access token"
16	cfb, err := NewCFBCipher([]byte(secret))
17	assert.NoError(t, err)
18	c := NewBase64Cipher(cfb)
19
20	encoded, err := c.Encrypt([]byte(token))
21	assert.Equal(t, nil, err)
22
23	decoded, err := c.Decrypt(encoded)
24	assert.Equal(t, nil, err)
25
26	assert.NotEqual(t, []byte(token), encoded)
27	assert.Equal(t, []byte(token), decoded)
28}
29
30func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
31	const secretBase64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk="
32	const token = "my access token"
33
34	secret, err := base64.URLEncoding.DecodeString(secretBase64)
35	assert.Equal(t, nil, err)
36	cfb, err := NewCFBCipher([]byte(secret))
37	assert.NoError(t, err)
38	c := NewBase64Cipher(cfb)
39
40	encoded, err := c.Encrypt([]byte(token))
41	assert.Equal(t, nil, err)
42
43	decoded, err := c.Decrypt(encoded)
44	assert.Equal(t, nil, err)
45
46	assert.NotEqual(t, []byte(token), encoded)
47	assert.Equal(t, []byte(token), decoded)
48}
49
50func TestEncryptAndDecrypt(t *testing.T) {
51	// Test our 2 cipher types
52	cipherInits := map[string]func([]byte) (Cipher, error){
53		"CFB": NewCFBCipher,
54		"GCM": NewGCMCipher,
55	}
56	for name, initCipher := range cipherInits {
57		t.Run(name, func(t *testing.T) {
58			// Test all 3 valid AES sizes
59			for _, secretSize := range []int{16, 24, 32} {
60				t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
61					secret := make([]byte, secretSize)
62					_, err := io.ReadFull(rand.Reader, secret)
63					assert.Equal(t, nil, err)
64
65					// Test Standard & Base64 wrapped
66					cstd, err := initCipher(secret)
67					assert.Equal(t, nil, err)
68
69					cb64 := NewBase64Cipher(cstd)
70
71					ciphers := map[string]Cipher{
72						"Standard": cstd,
73						"Base64":   cb64,
74					}
75
76					for cName, c := range ciphers {
77						t.Run(cName, func(t *testing.T) {
78							// Test various sizes sessions might be
79							for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
80								t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
81									runEncryptAndDecrypt(t, c, dataSize)
82								})
83							}
84						})
85					}
86				})
87			}
88		})
89	}
90}
91
92func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) {
93	data := make([]byte, dataSize)
94	_, err := io.ReadFull(rand.Reader, data)
95	assert.Equal(t, nil, err)
96
97	// Ensure our Encrypt function doesn't encrypt in place
98	immutableData := make([]byte, len(data))
99	copy(immutableData, data)
100
101	encrypted, err := c.Encrypt(data)
102	assert.Equal(t, nil, err)
103	assert.NotEqual(t, encrypted, data)
104	// Encrypt didn't operate in-place on []byte
105	assert.Equal(t, data, immutableData)
106
107	// Ensure our Decrypt function doesn't decrypt in place
108	immutableEnc := make([]byte, len(encrypted))
109	copy(immutableEnc, encrypted)
110
111	decrypted, err := c.Decrypt(encrypted)
112	assert.Equal(t, nil, err)
113	// Original data back
114	assert.Equal(t, data, decrypted)
115	// Decrypt didn't operate in-place on []byte
116	assert.Equal(t, encrypted, immutableEnc)
117	// Encrypt/Decrypt actually did something
118	assert.NotEqual(t, encrypted, decrypted)
119}
120
121func TestDecryptCFBWrongSecret(t *testing.T) {
122	secret1 := []byte("0123456789abcdefghijklmnopqrstuv")
123	secret2 := []byte("9876543210abcdefghijklmnopqrstuv")
124
125	c1, err := NewCFBCipher(secret1)
126	assert.Equal(t, nil, err)
127
128	c2, err := NewCFBCipher(secret2)
129	assert.Equal(t, nil, err)
130
131	data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df")
132
133	ciphertext, err := c1.Encrypt(data)
134	assert.Equal(t, nil, err)
135
136	wrongData, err := c2.Decrypt(ciphertext)
137	assert.Equal(t, nil, err)
138	assert.NotEqual(t, data, wrongData)
139}
140
141func TestDecryptGCMWrongSecret(t *testing.T) {
142	secret1 := []byte("0123456789abcdefghijklmnopqrstuv")
143	secret2 := []byte("9876543210abcdefghijklmnopqrstuv")
144
145	c1, err := NewGCMCipher(secret1)
146	assert.Equal(t, nil, err)
147
148	c2, err := NewGCMCipher(secret2)
149	assert.Equal(t, nil, err)
150
151	data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df")
152
153	ciphertext, err := c1.Encrypt(data)
154	assert.Equal(t, nil, err)
155
156	// GCM is authenticated - this should lead to message authentication failed
157	_, err = c2.Decrypt(ciphertext)
158	assert.Error(t, err)
159}
160
161// Encrypt with GCM, Decrypt with CFB: Results in Garbage data
162func TestGCMtoCFBErrors(t *testing.T) {
163	// Test all 3 valid AES sizes
164	for _, secretSize := range []int{16, 24, 32} {
165		t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
166			secret := make([]byte, secretSize)
167			_, err := io.ReadFull(rand.Reader, secret)
168			assert.Equal(t, nil, err)
169
170			gcm, err := NewGCMCipher(secret)
171			assert.Equal(t, nil, err)
172
173			cfb, err := NewCFBCipher(secret)
174			assert.Equal(t, nil, err)
175
176			// Test various sizes sessions might be
177			for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
178				t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
179					data := make([]byte, dataSize)
180					_, err := io.ReadFull(rand.Reader, data)
181					assert.Equal(t, nil, err)
182
183					encrypted, err := gcm.Encrypt(data)
184					assert.Equal(t, nil, err)
185					assert.NotEqual(t, encrypted, data)
186
187					decrypted, err := cfb.Decrypt(encrypted)
188					assert.Equal(t, nil, err)
189					// Data is mangled
190					assert.NotEqual(t, data, decrypted)
191					assert.NotEqual(t, encrypted, decrypted)
192				})
193			}
194		})
195	}
196}
197
198// Encrypt with CFB, Decrypt with GCM: Results in errors
199func TestCFBtoGCMErrors(t *testing.T) {
200	// Test all 3 valid AES sizes
201	for _, secretSize := range []int{16, 24, 32} {
202		t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
203			secret := make([]byte, secretSize)
204			_, err := io.ReadFull(rand.Reader, secret)
205			assert.Equal(t, nil, err)
206
207			gcm, err := NewGCMCipher(secret)
208			assert.Equal(t, nil, err)
209
210			cfb, err := NewCFBCipher(secret)
211			assert.Equal(t, nil, err)
212
213			// Test various sizes sessions might be
214			for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
215				t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
216					data := make([]byte, dataSize)
217					_, err := io.ReadFull(rand.Reader, data)
218					assert.Equal(t, nil, err)
219
220					encrypted, err := cfb.Encrypt(data)
221					assert.Equal(t, nil, err)
222					assert.NotEqual(t, encrypted, data)
223
224					// GCM is authenticated - this should lead to message authentication failed
225					_, err = gcm.Decrypt(encrypted)
226					assert.Error(t, err)
227				})
228			}
229		})
230	}
231}
232