1package s3crypto
2
3import (
4	"bytes"
5	"crypto/aes"
6	"crypto/cipher"
7	"io"
8)
9
10// AESCBC is a symmetric crypto algorithm. This algorithm
11// requires a padder due to CBC needing to be of the same block
12// size. AES CBC is vulnerable to Padding Oracle attacks and
13// so should be avoided when possible.
14type aesCBC struct {
15	encrypter cipher.BlockMode
16	decrypter cipher.BlockMode
17	padder    Padder
18}
19
20// newAESCBC creates a new AES CBC cipher. Expects keys to be of
21// the correct size.
22func newAESCBC(cd CipherData, padder Padder) (Cipher, error) {
23	block, err := aes.NewCipher(cd.Key)
24	if err != nil {
25		return nil, err
26	}
27
28	encrypter := cipher.NewCBCEncrypter(block, cd.IV)
29	decrypter := cipher.NewCBCDecrypter(block, cd.IV)
30
31	return &aesCBC{encrypter, decrypter, padder}, nil
32}
33
34// Encrypt will encrypt the data using AES CBC by returning
35// an io.Reader. The io.Reader will encrypt the data as Read
36// is called.
37func (c *aesCBC) Encrypt(src io.Reader) io.Reader {
38	reader := &cbcEncryptReader{
39		encrypter: c.encrypter,
40		src:       src,
41		padder:    c.padder,
42	}
43	return reader
44}
45
46type cbcEncryptReader struct {
47	encrypter cipher.BlockMode
48	src       io.Reader
49	padder    Padder
50	size      int
51	buf       bytes.Buffer
52}
53
54// Read will read from our io.Reader and encrypt the data as necessary.
55// Due to padding, we have to do some logic that when we encounter an
56// end of file to pad properly.
57func (reader *cbcEncryptReader) Read(data []byte) (int, error) {
58	n, err := reader.src.Read(data)
59	reader.size += n
60	blockSize := reader.encrypter.BlockSize()
61	reader.buf.Write(data[:n])
62
63	if err == io.EOF {
64		b := make([]byte, getSliceSize(blockSize, reader.buf.Len(), len(data)))
65		n, err = reader.buf.Read(b)
66		if err != nil && err != io.EOF {
67			return n, err
68		}
69		// The buffer is now empty, we can now pad the data
70		if reader.buf.Len() == 0 {
71			b, err = reader.padder.Pad(b[:n], reader.size)
72			if err != nil {
73				return n, err
74			}
75			n = len(b)
76			err = io.EOF
77		}
78		// We only want to encrypt if we have read anything
79		if n > 0 {
80			reader.encrypter.CryptBlocks(data, b)
81		}
82		return n, err
83	}
84
85	if err != nil {
86		return n, err
87	}
88
89	if size := reader.buf.Len(); size >= blockSize {
90		nBlocks := size / blockSize
91		if size > len(data) {
92			nBlocks = len(data) / blockSize
93		}
94
95		if nBlocks > 0 {
96			b := make([]byte, nBlocks*blockSize)
97			n, _ = reader.buf.Read(b)
98			reader.encrypter.CryptBlocks(data, b[:n])
99		}
100	} else {
101		n = 0
102	}
103	return n, nil
104}
105
106// Decrypt will decrypt the data using AES CBC
107func (c *aesCBC) Decrypt(src io.Reader) io.Reader {
108	return &cbcDecryptReader{
109		decrypter: c.decrypter,
110		src:       src,
111		padder:    c.padder,
112	}
113}
114
115type cbcDecryptReader struct {
116	decrypter cipher.BlockMode
117	src       io.Reader
118	padder    Padder
119	buf       bytes.Buffer
120}
121
122// Read will read from our io.Reader and decrypt the data as necessary.
123// Due to padding, we have to do some logic that when we encounter an
124// end of file to pad properly.
125func (reader *cbcDecryptReader) Read(data []byte) (int, error) {
126	n, err := reader.src.Read(data)
127	blockSize := reader.decrypter.BlockSize()
128	reader.buf.Write(data[:n])
129
130	if err == io.EOF {
131		b := make([]byte, getSliceSize(blockSize, reader.buf.Len(), len(data)))
132		n, err = reader.buf.Read(b)
133		if err != nil && err != io.EOF {
134			return n, err
135		}
136		// We only want to decrypt if we have read anything
137		if n > 0 {
138			reader.decrypter.CryptBlocks(data, b)
139		}
140
141		if reader.buf.Len() == 0 {
142			b, err = reader.padder.Unpad(data[:n])
143			n = len(b)
144			if err != nil {
145				return n, err
146			}
147			err = io.EOF
148		}
149		return n, err
150	}
151
152	if err != nil {
153		return n, err
154	}
155
156	if size := reader.buf.Len(); size >= blockSize {
157		nBlocks := size / blockSize
158		if size > len(data) {
159			nBlocks = len(data) / blockSize
160		}
161		// The last block is always padded. This will allow us to unpad
162		// when we receive an io.EOF error
163		nBlocks -= blockSize
164
165		if nBlocks > 0 {
166			b := make([]byte, nBlocks*blockSize)
167			n, _ = reader.buf.Read(b)
168			reader.decrypter.CryptBlocks(data, b[:n])
169		} else {
170			n = 0
171		}
172	}
173
174	return n, nil
175}
176
// getSliceSize returns how many buffered bytes to process in one step: the
// smaller of bufSize and dataSize, rounded down to a whole number of blocks,
// minus one block. If that leaves nothing (zero or negative), a single
// blockSize is returned instead.
func getSliceSize(blockSize, bufSize, dataSize int) int {
	limit := bufSize
	if dataSize < bufSize {
		limit = dataSize
	}

	if n := limit - limit%blockSize - blockSize; n > 0 {
		return n
	}
	return blockSize
}
191