1package s3crypto
2
import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"fmt"
	"io"
	"io/ioutil"
)
10
// aesGCM is the AES-GCM symmetric content cipher. Go's cipher.AEAD
// interface seals and opens a complete message in a single call, so the
// readers returned by Encrypt and Decrypt load the entire payload into
// memory before producing any output: this cipher is not streamed.
//
// NOTE(review): the same nonce is used for every Seal/Open performed
// through a given value. GCM requires a nonce to be unique per key, so a
// value should be used to encrypt at most one payload.
type aesGCM struct {
	aead  cipher.AEAD // AES block cipher wrapped in GCM mode by newAESGCM
	nonce []byte      // IV passed as the nonce to every Seal/Open call
}
18
19// newAESGCM creates a new AES GCM cipher. Expects keys to be of
20// the correct size.
21//
22// Example:
23//
24//	cd := &s3crypto.CipherData{
25//		Key: key,
26//		"IV": iv,
27//	}
28//	cipher, err := s3crypto.newAESGCM(cd)
29func newAESGCM(cd CipherData) (Cipher, error) {
30	block, err := aes.NewCipher(cd.Key)
31	if err != nil {
32		return nil, err
33	}
34
35	aesgcm, err := cipher.NewGCM(block)
36	if err != nil {
37		return nil, err
38	}
39
40	return &aesGCM{aesgcm, cd.IV}, nil
41}
42
43// Encrypt will encrypt the data using AES GCM
44// Tag will be included as the last 16 bytes of the slice
45func (c *aesGCM) Encrypt(src io.Reader) io.Reader {
46	reader := &gcmEncryptReader{
47		encrypter: c.aead,
48		nonce:     c.nonce,
49		src:       src,
50	}
51	return reader
52}
53
// gcmEncryptReader is the io.Reader returned by aesGCM.Encrypt. On the
// first Read it drains src completely and seals the whole plaintext with
// a single AEAD call. It then serves the ciphertext, with the tag that
// Seal appends, from an in-memory buffer.
type gcmEncryptReader struct {
	encrypter cipher.AEAD
	nonce     []byte
	src       io.Reader
	// buf holds the sealed output; nil until src has been read successfully.
	buf *bytes.Buffer
	// err is the first error from reading src, returned on every later Read.
	err error
}

// Read implements io.Reader. An error from reading src is returned on
// this call and on every subsequent call. As a result, a caller that
// retries after a failure can never receive ciphertext for a truncated
// plaintext.
func (reader *gcmEncryptReader) Read(data []byte) (int, error) {
	if reader.err != nil {
		return 0, reader.err
	}
	if reader.buf == nil {
		b, err := ioutil.ReadAll(reader.src)
		if err != nil {
			// Without remembering the error, the next Read would seal
			// whatever remained of src and return it as valid ciphertext.
			reader.err = err
			return 0, err
		}
		// Passing b[:0] as dst is the crypto/cipher-documented way to
		// reuse the plaintext's storage for the sealed output.
		b = reader.encrypter.Seal(b[:0], reader.nonce, b, nil)
		reader.buf = bytes.NewBuffer(b)
	}
	return reader.buf.Read(data)
}
73
74// Decrypt will decrypt the data using AES GCM
75func (c *aesGCM) Decrypt(src io.Reader) io.Reader {
76	return &gcmDecryptReader{
77		decrypter: c.aead,
78		nonce:     c.nonce,
79		src:       src,
80	}
81}
82
// gcmDecryptReader is the io.Reader returned by aesGCM.Decrypt. On the
// first Read it drains src completely and opens it with a single AEAD
// call, which both authenticates and decrypts. It then serves the
// plaintext from an in-memory buffer. No plaintext is released unless
// authentication succeeds.
type gcmDecryptReader struct {
	decrypter cipher.AEAD
	nonce     []byte
	src       io.Reader
	// buf holds the opened plaintext; nil until src was read and authenticated.
	buf *bytes.Buffer
	// err is the first error from reading src or from Open; returned on every later Read.
	err error
}

// Read implements io.Reader. The first failure, whether from reading src
// or from authentication in Open, is returned on this call and on every
// later call. A retry therefore reports the original cause rather than
// an unrelated failure from re-reading a partially consumed src.
func (reader *gcmDecryptReader) Read(data []byte) (int, error) {
	if reader.err != nil {
		return 0, reader.err
	}
	if reader.buf == nil {
		b, err := ioutil.ReadAll(reader.src)
		if err != nil {
			reader.err = err
			return 0, err
		}
		// Passing b[:0] as dst is the crypto/cipher-documented way to
		// reuse the ciphertext's storage for the decrypted output.
		b, err = reader.decrypter.Open(b[:0], reader.nonce, b, nil)
		if err != nil {
			reader.err = err
			return 0, err
		}

		reader.buf = bytes.NewBuffer(b)
	}
	return reader.buf.Read(data)
}
106