1package transit
2
3import (
4	"context"
5	"testing"
6
7	"github.com/hashicorp/vault/sdk/logical"
8	"github.com/mitchellh/mapstructure"
9)
10
11// Case1: Ensure that batch encryption did not affect the normal flow of
12// encrypting the plaintext with a pre-existing key.
13func TestTransit_BatchEncryptionCase1(t *testing.T) {
14	var resp *logical.Response
15	var err error
16
17	b, s := createBackendWithStorage(t)
18
19	// Create the policy
20	policyReq := &logical.Request{
21		Operation: logical.UpdateOperation,
22		Path:      "keys/existing_key",
23		Storage:   s,
24	}
25	resp, err = b.HandleRequest(context.Background(), policyReq)
26	if err != nil || (resp != nil && resp.IsError()) {
27		t.Fatalf("err:%v resp:%#v", err, resp)
28	}
29
30	plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox"
31
32	encData := map[string]interface{}{
33		"plaintext": plaintext,
34	}
35
36	encReq := &logical.Request{
37		Operation: logical.UpdateOperation,
38		Path:      "encrypt/existing_key",
39		Storage:   s,
40		Data:      encData,
41	}
42	resp, err = b.HandleRequest(context.Background(), encReq)
43	if err != nil || (resp != nil && resp.IsError()) {
44		t.Fatalf("err:%v resp:%#v", err, resp)
45	}
46
47	ciphertext := resp.Data["ciphertext"]
48
49	decData := map[string]interface{}{
50		"ciphertext": ciphertext,
51	}
52	decReq := &logical.Request{
53		Operation: logical.UpdateOperation,
54		Path:      "decrypt/existing_key",
55		Storage:   s,
56		Data:      decData,
57	}
58	resp, err = b.HandleRequest(context.Background(), decReq)
59	if err != nil || (resp != nil && resp.IsError()) {
60		t.Fatalf("err:%v resp:%#v", err, resp)
61	}
62
63	if resp.Data["plaintext"] != plaintext {
64		t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
65	}
66}
67
68// Case2: Ensure that batch encryption did not affect the normal flow of
69// encrypting the plaintext with the key upserted.
70func TestTransit_BatchEncryptionCase2(t *testing.T) {
71	var resp *logical.Response
72	var err error
73	b, s := createBackendWithStorage(t)
74
75	// Upsert the key and encrypt the data
76	plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
77
78	encData := map[string]interface{}{
79		"plaintext": plaintext,
80	}
81
82	encReq := &logical.Request{
83		Operation: logical.CreateOperation,
84		Path:      "encrypt/upserted_key",
85		Storage:   s,
86		Data:      encData,
87	}
88	resp, err = b.HandleRequest(context.Background(), encReq)
89	if err != nil || (resp != nil && resp.IsError()) {
90		t.Fatalf("err:%v resp:%#v", err, resp)
91	}
92
93	ciphertext := resp.Data["ciphertext"]
94	decData := map[string]interface{}{
95		"ciphertext": ciphertext,
96	}
97
98	policyReq := &logical.Request{
99		Operation: logical.ReadOperation,
100		Path:      "keys/upserted_key",
101		Storage:   s,
102	}
103
104	resp, err = b.HandleRequest(context.Background(), policyReq)
105	if err != nil || (resp != nil && resp.IsError()) {
106		t.Fatalf("err:%v resp:%#v", err, resp)
107	}
108
109	decReq := &logical.Request{
110		Operation: logical.UpdateOperation,
111		Path:      "decrypt/upserted_key",
112		Storage:   s,
113		Data:      decData,
114	}
115	resp, err = b.HandleRequest(context.Background(), decReq)
116	if err != nil || (resp != nil && resp.IsError()) {
117		t.Fatalf("err:%v resp:%#v", err, resp)
118	}
119
120	if resp.Data["plaintext"] != plaintext {
121		t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
122	}
123}
124
125// Case3: If batch encryption input is not base64 encoded, it should fail.
126func TestTransit_BatchEncryptionCase3(t *testing.T) {
127	var err error
128
129	b, s := createBackendWithStorage(t)
130
131	batchInput := `[{"plaintext":"dGhlIHF1aWNrIGJyb3duIGZveA=="}]`
132	batchData := map[string]interface{}{
133		"batch_input": batchInput,
134	}
135
136	batchReq := &logical.Request{
137		Operation: logical.CreateOperation,
138		Path:      "encrypt/upserted_key",
139		Storage:   s,
140		Data:      batchData,
141	}
142	_, err = b.HandleRequest(context.Background(), batchReq)
143	if err == nil {
144		t.Fatal("expected an error")
145	}
146}
147
148// Case4: Test batch encryption with an existing key
149func TestTransit_BatchEncryptionCase4(t *testing.T) {
150	var resp *logical.Response
151	var err error
152
153	b, s := createBackendWithStorage(t)
154
155	policyReq := &logical.Request{
156		Operation: logical.UpdateOperation,
157		Path:      "keys/existing_key",
158		Storage:   s,
159	}
160	resp, err = b.HandleRequest(context.Background(), policyReq)
161	if err != nil || (resp != nil && resp.IsError()) {
162		t.Fatalf("err:%v resp:%#v", err, resp)
163	}
164
165	batchInput := []interface{}{
166		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
167		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
168	}
169
170	batchData := map[string]interface{}{
171		"batch_input": batchInput,
172	}
173	batchReq := &logical.Request{
174		Operation: logical.UpdateOperation,
175		Path:      "encrypt/existing_key",
176		Storage:   s,
177		Data:      batchData,
178	}
179	resp, err = b.HandleRequest(context.Background(), batchReq)
180	if err != nil || (resp != nil && resp.IsError()) {
181		t.Fatalf("err:%v resp:%#v", err, resp)
182	}
183
184	batchResponseItems := resp.Data["batch_results"].([]BatchResponseItem)
185
186	decReq := &logical.Request{
187		Operation: logical.UpdateOperation,
188		Path:      "decrypt/existing_key",
189		Storage:   s,
190	}
191
192	plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
193
194	for _, item := range batchResponseItems {
195		decReq.Data = map[string]interface{}{
196			"ciphertext": item.Ciphertext,
197		}
198		resp, err = b.HandleRequest(context.Background(), decReq)
199		if err != nil || (resp != nil && resp.IsError()) {
200			t.Fatalf("err:%v resp:%#v", err, resp)
201		}
202
203		if resp.Data["plaintext"] != plaintext {
204			t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
205		}
206	}
207}
208
209// Case5: Test batch encryption with an existing derived key
210func TestTransit_BatchEncryptionCase5(t *testing.T) {
211	var resp *logical.Response
212	var err error
213
214	b, s := createBackendWithStorage(t)
215
216	policyData := map[string]interface{}{
217		"derived": true,
218	}
219
220	policyReq := &logical.Request{
221		Operation: logical.UpdateOperation,
222		Path:      "keys/existing_key",
223		Storage:   s,
224		Data:      policyData,
225	}
226
227	resp, err = b.HandleRequest(context.Background(), policyReq)
228	if err != nil || (resp != nil && resp.IsError()) {
229		t.Fatalf("err:%v resp:%#v", err, resp)
230	}
231
232	batchInput := []interface{}{
233		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
234		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
235	}
236
237	batchData := map[string]interface{}{
238		"batch_input": batchInput,
239	}
240
241	batchReq := &logical.Request{
242		Operation: logical.UpdateOperation,
243		Path:      "encrypt/existing_key",
244		Storage:   s,
245		Data:      batchData,
246	}
247	resp, err = b.HandleRequest(context.Background(), batchReq)
248	if err != nil || (resp != nil && resp.IsError()) {
249		t.Fatalf("err:%v resp:%#v", err, resp)
250	}
251
252	batchResponseItems := resp.Data["batch_results"].([]BatchResponseItem)
253
254	decReq := &logical.Request{
255		Operation: logical.UpdateOperation,
256		Path:      "decrypt/existing_key",
257		Storage:   s,
258	}
259
260	plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
261
262	for _, item := range batchResponseItems {
263		decReq.Data = map[string]interface{}{
264			"ciphertext": item.Ciphertext,
265			"context":    "dmlzaGFsCg==",
266		}
267		resp, err = b.HandleRequest(context.Background(), decReq)
268		if err != nil || (resp != nil && resp.IsError()) {
269			t.Fatalf("err:%v resp:%#v", err, resp)
270		}
271
272		if resp.Data["plaintext"] != plaintext {
273			t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
274		}
275	}
276}
277
278// Case6: Test batch encryption with an upserted non-derived key
279func TestTransit_BatchEncryptionCase6(t *testing.T) {
280	var resp *logical.Response
281	var err error
282
283	b, s := createBackendWithStorage(t)
284
285	batchInput := []interface{}{
286		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
287		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
288	}
289
290	batchData := map[string]interface{}{
291		"batch_input": batchInput,
292	}
293	batchReq := &logical.Request{
294		Operation: logical.CreateOperation,
295		Path:      "encrypt/upserted_key",
296		Storage:   s,
297		Data:      batchData,
298	}
299	resp, err = b.HandleRequest(context.Background(), batchReq)
300	if err != nil || (resp != nil && resp.IsError()) {
301		t.Fatalf("err:%v resp:%#v", err, resp)
302	}
303
304	batchResponseItems := resp.Data["batch_results"].([]BatchResponseItem)
305
306	decReq := &logical.Request{
307		Operation: logical.UpdateOperation,
308		Path:      "decrypt/upserted_key",
309		Storage:   s,
310	}
311
312	plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
313
314	for _, responseItem := range batchResponseItems {
315		var item BatchResponseItem
316		if err := mapstructure.Decode(responseItem, &item); err != nil {
317			t.Fatal(err)
318		}
319		decReq.Data = map[string]interface{}{
320			"ciphertext": item.Ciphertext,
321		}
322		resp, err = b.HandleRequest(context.Background(), decReq)
323		if err != nil || (resp != nil && resp.IsError()) {
324			t.Fatalf("err:%v resp:%#v", err, resp)
325		}
326
327		if resp.Data["plaintext"] != plaintext {
328			t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
329		}
330	}
331}
332
333// Case7: Test batch encryption with an upserted derived key
334func TestTransit_BatchEncryptionCase7(t *testing.T) {
335	var resp *logical.Response
336	var err error
337
338	b, s := createBackendWithStorage(t)
339
340	batchInput := []interface{}{
341		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
342		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
343	}
344
345	batchData := map[string]interface{}{
346		"batch_input": batchInput,
347	}
348	batchReq := &logical.Request{
349		Operation: logical.CreateOperation,
350		Path:      "encrypt/upserted_key",
351		Storage:   s,
352		Data:      batchData,
353	}
354	resp, err = b.HandleRequest(context.Background(), batchReq)
355	if err != nil || (resp != nil && resp.IsError()) {
356		t.Fatalf("err:%v resp:%#v", err, resp)
357	}
358
359	batchResponseItems := resp.Data["batch_results"].([]BatchResponseItem)
360
361	decReq := &logical.Request{
362		Operation: logical.UpdateOperation,
363		Path:      "decrypt/upserted_key",
364		Storage:   s,
365	}
366
367	plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
368
369	for _, item := range batchResponseItems {
370		decReq.Data = map[string]interface{}{
371			"ciphertext": item.Ciphertext,
372			"context":    "dmlzaGFsCg==",
373		}
374		resp, err = b.HandleRequest(context.Background(), decReq)
375		if err != nil || (resp != nil && resp.IsError()) {
376			t.Fatalf("err:%v resp:%#v", err, resp)
377		}
378
379		if resp.Data["plaintext"] != plaintext {
380			t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
381		}
382	}
383}
384
385// Case8: If plaintext is not base64 encoded, encryption should fail
386func TestTransit_BatchEncryptionCase8(t *testing.T) {
387	var resp *logical.Response
388	var err error
389
390	b, s := createBackendWithStorage(t)
391
392	// Create the policy
393	policyReq := &logical.Request{
394		Operation: logical.UpdateOperation,
395		Path:      "keys/existing_key",
396		Storage:   s,
397	}
398	resp, err = b.HandleRequest(context.Background(), policyReq)
399	if err != nil || (resp != nil && resp.IsError()) {
400		t.Fatalf("err:%v resp:%#v", err, resp)
401	}
402
403	batchInput := []interface{}{
404		map[string]interface{}{"plaintext": "simple_plaintext"},
405	}
406	batchData := map[string]interface{}{
407		"batch_input": batchInput,
408	}
409	batchReq := &logical.Request{
410		Operation: logical.UpdateOperation,
411		Path:      "encrypt/existing_key",
412		Storage:   s,
413		Data:      batchData,
414	}
415	resp, err = b.HandleRequest(context.Background(), batchReq)
416	if err != nil || (resp != nil && resp.IsError()) {
417		t.Fatalf("err:%v resp:%#v", err, resp)
418	}
419
420	plaintext := "simple plaintext"
421
422	encData := map[string]interface{}{
423		"plaintext": plaintext,
424	}
425
426	encReq := &logical.Request{
427		Operation: logical.UpdateOperation,
428		Path:      "encrypt/existing_key",
429		Storage:   s,
430		Data:      encData,
431	}
432	resp, err = b.HandleRequest(context.Background(), encReq)
433	if err == nil {
434		t.Fatal("expected an error")
435	}
436}
437
438// Case9: If both plaintext and batch inputs are supplied, plaintext should be
439// ignored.
440func TestTransit_BatchEncryptionCase9(t *testing.T) {
441	var resp *logical.Response
442	var err error
443
444	b, s := createBackendWithStorage(t)
445
446	batchInput := []interface{}{
447		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
448		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
449	}
450	plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
451	batchData := map[string]interface{}{
452		"batch_input": batchInput,
453		"plaintext":   plaintext,
454	}
455	batchReq := &logical.Request{
456		Operation: logical.CreateOperation,
457		Path:      "encrypt/upserted_key",
458		Storage:   s,
459		Data:      batchData,
460	}
461	resp, err = b.HandleRequest(context.Background(), batchReq)
462	if err != nil || (resp != nil && resp.IsError()) {
463		t.Fatalf("err:%v resp:%#v", err, resp)
464	}
465
466	_, ok := resp.Data["ciphertext"]
467	if ok {
468		t.Fatal("ciphertext field should not be set")
469	}
470}
471
472// Case10: Inconsistent presence of 'context' in batch input should be caught
473func TestTransit_BatchEncryptionCase10(t *testing.T) {
474	var err error
475
476	b, s := createBackendWithStorage(t)
477
478	batchInput := []interface{}{
479		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
480		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
481	}
482
483	batchData := map[string]interface{}{
484		"batch_input": batchInput,
485	}
486
487	batchReq := &logical.Request{
488		Operation: logical.CreateOperation,
489		Path:      "encrypt/upserted_key",
490		Storage:   s,
491		Data:      batchData,
492	}
493	_, err = b.HandleRequest(context.Background(), batchReq)
494	if err == nil {
495		t.Fatalf("expected an error")
496	}
497}
498
499// Case11: Incorrect inputs for context and nonce should not fail the operation
500func TestTransit_BatchEncryptionCase11(t *testing.T) {
501	var err error
502
503	b, s := createBackendWithStorage(t)
504
505	batchInput := []interface{}{
506		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
507		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "not-encoded"},
508	}
509
510	batchData := map[string]interface{}{
511		"batch_input": batchInput,
512	}
513	batchReq := &logical.Request{
514		Operation: logical.CreateOperation,
515		Path:      "encrypt/upserted_key",
516		Storage:   s,
517		Data:      batchData,
518	}
519	_, err = b.HandleRequest(context.Background(), batchReq)
520	if err != nil {
521		t.Fatal(err)
522	}
523}
524
525// Case12: Invalid batch input
526func TestTransit_BatchEncryptionCase12(t *testing.T) {
527	var err error
528	b, s := createBackendWithStorage(t)
529
530	batchInput := []interface{}{
531		map[string]interface{}{},
532		"unexpected_interface",
533	}
534
535	batchData := map[string]interface{}{
536		"batch_input": batchInput,
537	}
538	batchReq := &logical.Request{
539		Operation: logical.CreateOperation,
540		Path:      "encrypt/upserted_key",
541		Storage:   s,
542		Data:      batchData,
543	}
544	_, err = b.HandleRequest(context.Background(), batchReq)
545	if err == nil {
546		t.Fatalf("expected an error")
547	}
548}
549