1//go:build go1.9 && s3crypto_integ
2// +build go1.9,s3crypto_integ
3
4package s3crypto_test
5
6import (
7	"bytes"
8	"encoding/base64"
9	"fmt"
10	"io/ioutil"
11	"strings"
12	"testing"
13
14	"github.com/aws/aws-sdk-go/aws"
15	"github.com/aws/aws-sdk-go/awstesting/integration"
16	"github.com/aws/aws-sdk-go/service/kms"
17	"github.com/aws/aws-sdk-go/service/s3"
18	"github.com/aws/aws-sdk-go/service/s3/s3crypto"
19)
20
21func TestInteg_EncryptFixtures(t *testing.T) {
22	sess := integration.SessionWithDefaultRegion("us-west-2")
23
24	const bucket = "aws-s3-shared-tests"
25	const version = "version_2"
26
27	cases := []struct {
28		CEKAlg           string
29		KEK, V1, V2, CEK string
30	}{
31		{
32			CEKAlg: "aes_gcm",
33			KEK:    "kms", V1: "AWS_SDK_TEST_ALIAS", V2: "us-west-2", CEK: "aes_gcm",
34		},
35		{
36			CEKAlg: "aes_cbc",
37			KEK:    "kms", V1: "AWS_SDK_TEST_ALIAS", V2: "us-west-2", CEK: "aes_cbc",
38		},
39	}
40
41	for _, c := range cases {
42		t.Run(c.CEKAlg, func(t *testing.T) {
43			s3Client := s3.New(sess)
44
45			fixtures := getFixtures(t, s3Client, c.CEKAlg, bucket)
46			builder, masterKey := getEncryptFixtureBuilder(t, c.KEK, c.V1, c.V2, c.CEK)
47
48			encClient := s3crypto.NewEncryptionClient(sess, builder)
49
50			for caseKey, plaintext := range fixtures.Plaintexts {
51				_, err := encClient.PutObject(&s3.PutObjectInput{
52					Bucket: aws.String(bucket),
53					Key: aws.String(
54						fmt.Sprintf("%s/%s/language_Go/ciphertext_test_case_%s",
55							fixtures.BaseFolder, version, caseKey),
56					),
57					Body: bytes.NewReader(plaintext),
58					Metadata: map[string]*string{
59						"Masterkey": &masterKey,
60					},
61				})
62				if err != nil {
63					t.Fatalf("failed to upload encrypted fixture, %v", err)
64				}
65			}
66		})
67	}
68}
69
70func TestInteg_DecryptFixtures(t *testing.T) {
71	sess := integration.SessionWithDefaultRegion("us-west-2")
72
73	const bucket = "aws-s3-shared-tests"
74	const version = "version_2"
75
76	cases := []struct {
77		CEKAlg string
78		Lang   string
79	}{
80		{CEKAlg: "aes_cbc", Lang: "Go"},
81		{CEKAlg: "aes_gcm", Lang: "Go"},
82		{CEKAlg: "aes_cbc", Lang: "Java"},
83		{CEKAlg: "aes_gcm", Lang: "Java"},
84	}
85
86	for _, c := range cases {
87		t.Run(c.CEKAlg+"-"+c.Lang, func(t *testing.T) {
88			decClient := s3crypto.NewDecryptionClient(sess)
89			s3Client := s3.New(sess)
90
91			fixtures := getFixtures(t, s3Client, c.CEKAlg, bucket)
92			ciphertexts := decryptFixtures(t, decClient, s3Client, fixtures, bucket, c.Lang, version)
93
94			for caseKey, ciphertext := range ciphertexts {
95				if e, a := len(fixtures.Plaintexts[caseKey]), len(ciphertext); e != a {
96					t.Errorf("expect %v text len, got %v", e, a)
97				}
98				if e, a := fixtures.Plaintexts[caseKey], ciphertext; !bytes.Equal(e, a) {
99					t.Errorf("expect %v text, got %v", e, a)
100				}
101			}
102		})
103	}
104}
105
// testFixtures holds the plaintext fixtures for one content encryption
// algorithm, as loaded by getFixtures.
type testFixtures struct {
	// BaseFolder is the S3 key prefix the fixtures were read from, without a
	// trailing slash (e.g. "crypto_tests/aes_gcm").
	BaseFolder string

	// Plaintexts maps each fixture case identifier to its plaintext bytes.
	Plaintexts map[string][]byte
}
110
111func getFixtures(t *testing.T, s3Client *s3.S3, cekAlg, bucket string) testFixtures {
112	t.Helper()
113
114	prefix := "plaintext_test_case_"
115	baseFolder := "crypto_tests/" + cekAlg
116
117	out, err := s3Client.ListObjects(&s3.ListObjectsInput{
118		Bucket: aws.String(bucket),
119		Prefix: aws.String(baseFolder + "/" + prefix),
120	})
121	if err != nil {
122		t.Fatalf("unable to list fixtures %v", err)
123	}
124
125	plaintexts := map[string][]byte{}
126	for _, obj := range out.Contents {
127		ptObj, err := s3Client.GetObject(&s3.GetObjectInput{
128			Bucket: aws.String(bucket),
129			Key:    obj.Key,
130		})
131		if err != nil {
132			t.Fatalf("unable to get fixture object %s, %v", *obj.Key, err)
133		}
134		caseKey := strings.TrimPrefix(*obj.Key, baseFolder+"/"+prefix)
135		plaintext, err := ioutil.ReadAll(ptObj.Body)
136		if err != nil {
137			t.Fatalf("unable to read fixture object %s, %v", *obj.Key, err)
138		}
139
140		plaintexts[caseKey] = plaintext
141	}
142
143	return testFixtures{
144		BaseFolder: baseFolder,
145		Plaintexts: plaintexts,
146	}
147}
148
149func getEncryptFixtureBuilder(t *testing.T, kek, v1, v2, cek string,
150) (builder s3crypto.ContentCipherBuilder, masterKey string) {
151	t.Helper()
152
153	var handler s3crypto.CipherDataGenerator
154	switch kek {
155	case "kms":
156		arn, err := getAliasInformation(v1, v2)
157		if err != nil {
158			t.Fatalf("failed to get fixture alias info for %s, %v", v1, err)
159		}
160
161		masterKey = base64.StdEncoding.EncodeToString([]byte(arn))
162		if err != nil {
163			t.Fatalf("failed to encode alias's arn %v", err)
164		}
165
166		kmsSvc := kms.New(integration.Session, &aws.Config{
167			Region: &v2,
168		})
169		handler = s3crypto.NewKMSKeyGenerator(kmsSvc, arn)
170	default:
171		t.Fatalf("unknown fixture KEK, %v", kek)
172	}
173
174	switch cek {
175	case "aes_gcm":
176		builder = s3crypto.AESGCMContentCipherBuilder(handler)
177	case "aes_cbc":
178		builder = s3crypto.AESCBCContentCipherBuilder(handler, s3crypto.AESCBCPadder)
179	default:
180		t.Fatalf("unknown fixture CEK, %v", cek)
181	}
182
183	return builder, masterKey
184}
185
186func getAliasInformation(alias, region string) (string, error) {
187	arn := ""
188	svc := kms.New(integration.Session, &aws.Config{
189		Region: &region,
190	})
191
192	truncated := true
193	var marker *string
194	for truncated {
195		out, err := svc.ListAliases(&kms.ListAliasesInput{
196			Marker: marker,
197		})
198		if err != nil {
199			return arn, err
200		}
201		for _, aliasEntry := range out.Aliases {
202			if *aliasEntry.AliasName == "alias/"+alias {
203				return *aliasEntry.AliasArn, nil
204			}
205		}
206		truncated = *out.Truncated
207		marker = out.NextMarker
208	}
209
210	return "", fmt.Errorf("kms alias %s does not exist", alias)
211}
212
213func decryptFixtures(t *testing.T, decClient *s3crypto.DecryptionClient, s3Client *s3.S3,
214	fixtures testFixtures, bucket, lang, version string,
215) map[string][]byte {
216	t.Helper()
217
218	prefix := "ciphertext_test_case_"
219	lang = "language_" + lang
220
221	ciphertexts := map[string][]byte{}
222	for caseKey := range fixtures.Plaintexts {
223		cipherKey := fixtures.BaseFolder + "/" + version + "/" + lang + "/" + prefix + caseKey
224
225		// To get metadata for encryption key
226		ctObj, err := s3Client.GetObject(&s3.GetObjectInput{
227			Bucket: &bucket,
228			Key:    &cipherKey,
229		})
230		if err != nil {
231			// TODO error?
232			continue
233		}
234
235		// We don't support wrap, so skip it
236		if ctObj.Metadata["X-Amz-Wrap-Alg"] == nil || *ctObj.Metadata["X-Amz-Wrap-Alg"] != "kms" {
237			continue
238		}
239
240		ctObj, err = decClient.GetObject(&s3.GetObjectInput{
241			Bucket: &bucket,
242			Key:    &cipherKey,
243		})
244		if err != nil {
245			t.Fatalf("failed to get encrypted object %v", err)
246		}
247
248		ciphertext, err := ioutil.ReadAll(ctObj.Body)
249		if err != nil {
250			t.Fatalf("failed to read object data %v", err)
251		}
252		ciphertexts[caseKey] = ciphertext
253	}
254
255	return ciphertexts
256}
257