// Copyright (c) 2021 Snowflake Computing Inc. All rights reserved.
2
3package gosnowflake
4
5import (
6	"bytes"
7	"crypto/aes"
8	"crypto/cipher"
9	"crypto/rand"
10	"encoding/base64"
11	"encoding/json"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"os"
16	"strconv"
17)
18
// snowflakeFileEncryption is the JSON-decoded encryption material used when
// encrypting or decrypting a transferred file.
//
// QueryStageMasterKey is the base64-encoded master key used to wrap the
// per-file key. QueryID and SMKID are copied into the material descriptor
// that accompanies every encrypted file.
type snowflakeFileEncryption struct {
	QueryStageMasterKey string `json:"queryStageMasterKey,omitempty"`
	QueryID             string `json:"queryId,omitempty"`
	SMKID               int64  `json:"smkId,omitempty"`
}
24
25// PUT requests return a single encryptionMaterial object whereas GET requests
26// return a slice (array) of encryptionMaterial objects, both under the field
27// 'encryptionMaterial'
28type encryptionWrapper struct {
29	snowflakeFileEncryption
30	EncryptionMaterials []snowflakeFileEncryption
31}
32
33// override default behavior for wrapper
34func (ew *encryptionWrapper) UnmarshalJSON(data []byte) error {
35	// if GET, unmarshal slice of encryptionMaterial
36	if err := json.Unmarshal(data, &ew.EncryptionMaterials); err == nil {
37		return err
38	}
39	// else (if PUT), unmarshal the encryptionMaterial itself
40	return json.Unmarshal(data, &ew.snowflakeFileEncryption)
41}
42
// encryptMetadata carries what is needed to decrypt an encrypted file.
//
// key is the base64-encoded, master-key-encrypted file key, iv is the
// base64-encoded CBC initialization vector, and matdesc is the JSON material
// descriptor. The field order matters: the struct is built with positional
// literals elsewhere in this file.
type encryptMetadata struct {
	key     string
	iv      string
	matdesc string
}
48
49// encryptStream encrypts a stream buffer using AES128 block cipher in CBC mode
50// with PKCS5 padding
51func encryptStream(
52	sfe *snowflakeFileEncryption,
53	src io.Reader,
54	out io.Writer,
55	chunkSize int) (*encryptMetadata, error) {
56	if chunkSize == 0 {
57		chunkSize = aes.BlockSize * 4 * 1024
58	}
59	decodedKey, _ := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey)
60	keySize := len(decodedKey)
61
62	fileKey := getSecureRandom(keySize)
63	block, _ := aes.NewCipher(fileKey)
64	ivData := getSecureRandom(block.BlockSize())
65
66	mode := cipher.NewCBCEncrypter(block, ivData)
67	cipherText := make([]byte, chunkSize)
68
69	// encrypt file with CBC
70	var err error
71	padded := false
72	for {
73		chunk := make([]byte, chunkSize)
74		n, err := src.Read(chunk)
75		if n == 0 || err != nil {
76			break
77		} else if n%aes.BlockSize != 0 {
78			chunk = padBytesLength(chunk[:n], aes.BlockSize)
79			padded = true
80		}
81		mode.CryptBlocks(cipherText, chunk)
82		out.Write(cipherText[:len(chunk)])
83	}
84	if err != nil {
85		return nil, err
86	}
87	if !padded {
88		blockSizeCipher := bytes.Repeat([]byte{byte(aes.BlockSize)}, aes.BlockSize)
89		chunk := make([]byte, aes.BlockSize)
90		mode.CryptBlocks(chunk, blockSizeCipher)
91		out.Write(chunk)
92	}
93
94	// encrypt key with ECB
95	fileKey = padBytesLength(fileKey, block.BlockSize())
96	encryptedFileKey := make([]byte, len(fileKey))
97	if err = encryptECB(encryptedFileKey, fileKey, decodedKey); err != nil {
98		return nil, err
99	}
100
101	matDesc := materialDescriptor{
102		strconv.Itoa(int(sfe.SMKID)),
103		sfe.QueryID,
104		strconv.Itoa(keySize * 8),
105	}
106
107	return &encryptMetadata{
108		base64.StdEncoding.EncodeToString(encryptedFileKey),
109		base64.StdEncoding.EncodeToString(ivData),
110		matdescToUnicode(matDesc),
111	}, nil
112}
113
// encryptECB encrypts fileKey with AES in ECB mode under decodedKey and
// writes the ciphertext into encrypted.
//
// fileKey must consist of whole AES blocks and encrypted must be at least as
// long as fileKey. An error is returned when decodedKey is not a valid AES
// key or when either length requirement is violated.
func encryptECB(encrypted []byte, fileKey []byte, decodedKey []byte) error {
	block, err := aes.NewCipher(decodedKey)
	if err != nil {
		return err
	}
	blockSize := block.BlockSize()
	if len(fileKey)%blockSize != 0 {
		return fmt.Errorf("input not full of blocks")
	}
	if len(encrypted) < len(fileKey) {
		return fmt.Errorf("output length is smaller than input length")
	}
	// ECB: each block is encrypted independently with the same key.
	for len(fileKey) > 0 {
		block.Encrypt(encrypted, fileKey[:blockSize])
		encrypted = encrypted[blockSize:]
		fileKey = fileKey[blockSize:]
	}
	return nil
}
129
// decryptECB decrypts keyBytes with AES in ECB mode under decodedKey and
// writes the plaintext into decrypted.
//
// keyBytes must consist of whole AES blocks and decrypted must be at least as
// long as keyBytes. An error is returned when decodedKey is not a valid AES
// key or when either length requirement is violated.
func decryptECB(decrypted []byte, keyBytes []byte, decodedKey []byte) error {
	block, err := aes.NewCipher(decodedKey)
	if err != nil {
		return err
	}
	blockSize := block.BlockSize()
	if len(keyBytes)%blockSize != 0 {
		return fmt.Errorf("input not full of blocks")
	}
	if len(decrypted) < len(keyBytes) {
		return fmt.Errorf("output length is smaller than input length")
	}
	// ECB: each block is decrypted independently with the same key.
	for len(keyBytes) > 0 {
		block.Decrypt(decrypted, keyBytes[:blockSize])
		keyBytes = keyBytes[blockSize:]
		decrypted = decrypted[blockSize:]
	}
	return nil
}
145
146func encryptFile(
147	sfe *snowflakeFileEncryption,
148	filename string,
149	chunkSize int,
150	tmpDir string) (
151	*encryptMetadata, string, error) {
152	if chunkSize == 0 {
153		chunkSize = aes.BlockSize * 4 * 1024
154	}
155	tmpOutputFile, _ := ioutil.TempFile(tmpDir, baseName(filename)+"#")
156	infile, err := os.OpenFile(filename, os.O_CREATE|os.O_RDONLY, os.ModePerm)
157	if err != nil {
158		return nil, "", err
159	}
160	meta, err := encryptStream(sfe, infile, tmpOutputFile, chunkSize)
161	if err != nil {
162		return nil, "", err
163	}
164	return meta, tmpOutputFile.Name(), nil
165}
166
167func decryptFile(
168	metadata *encryptMetadata,
169	sfe *snowflakeFileEncryption,
170	filename string,
171	chunkSize int,
172	tmpDir string) (
173	string, error) {
174	if chunkSize == 0 {
175		chunkSize = aes.BlockSize * 4 * 1024
176	}
177	decodedKey, _ := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey)
178	keyBytes, _ := base64.StdEncoding.DecodeString(metadata.key) // encrypted file key
179	ivBytes, _ := base64.StdEncoding.DecodeString(metadata.iv)
180
181	// decrypt file key
182	decryptedKey := make([]byte, len(keyBytes))
183	if err := decryptECB(decryptedKey, keyBytes, decodedKey); err != nil {
184		return "", err
185	}
186	decryptedKey = paddingTrim(decryptedKey)
187
188	// decrypt file
189	block, _ := aes.NewCipher(decryptedKey)
190	mode := cipher.NewCBCDecrypter(block, ivBytes)
191
192	tmpOutputFile, err := ioutil.TempFile(tmpDir, baseName(filename)+"#")
193	if err != nil {
194		return "", err
195	}
196	defer tmpOutputFile.Close()
197	infile, err := os.OpenFile(filename, os.O_RDONLY, os.ModePerm)
198	if err != nil {
199		return "", err
200	}
201	defer infile.Close()
202	var totalFileSize int
203	var prevChunk []byte
204	for {
205		chunk := make([]byte, chunkSize)
206		n, err := infile.Read(chunk)
207		if n == 0 || err != nil {
208			break
209		}
210		totalFileSize += n
211		chunk = chunk[:n]
212		mode.CryptBlocks(chunk, chunk)
213		tmpOutputFile.Write(chunk)
214		prevChunk = chunk
215	}
216	if err != nil {
217		return "", err
218	}
219	if prevChunk != nil {
220		totalFileSize -= paddingOffset(prevChunk)
221	}
222	tmpOutputFile.Truncate(int64(totalFileSize))
223	return tmpOutputFile.Name(), nil
224}
225
// materialDescriptor is the JSON document describing the key material used
// to encrypt a file. All values are carried as strings: SmkID and QueryID
// identify the master key and query, and KeySize is the key length in bits.
// The field order matters: the struct is built with positional literals
// elsewhere in this file.
type materialDescriptor struct {
	SmkID   string `json:"smkId"`
	QueryID string `json:"queryId"`
	KeySize string `json:"keySize"`
}
231
232func matdescToUnicode(matdesc materialDescriptor) string {
233	s, _ := json.Marshal(&matdesc)
234	return string(s)
235}
236
// getSecureRandom returns byteLength bytes read from the cryptographically
// secure random source in crypto/rand.
//
// It panics when the random source fails. The results are used as key and IV
// material, so continuing with a partially filled (zeroed) buffer would
// silently produce predictable keys.
func getSecureRandom(byteLength int) []byte {
	token := make([]byte, byteLength)
	if _, err := rand.Read(token); err != nil {
		panic("gosnowflake: reading secure random bytes: " + err.Error())
	}
	return token
}
242
// padBytesLength appends PKCS#7-style padding to src so that its length
// becomes a multiple of blockSize: every added byte holds the number of bytes
// added. Input that is already aligned receives a full extra block.
//
// The padding is appended to src itself, so it is written into src's backing
// array when src has enough spare capacity.
func padBytesLength(src []byte, blockSize int) []byte {
	fill := blockSize - (len(src) % blockSize)
	return append(src, bytes.Repeat([]byte{byte(fill)}, fill)...)
}
248
// paddingTrim removes trailing PKCS#7-style padding from src, using the last
// byte as the number of bytes to drop.
//
// src is returned unchanged when it is empty or when the last byte announces
// more padding than src holds, because there is no valid padding to remove.
func paddingTrim(src []byte) []byte {
	if len(src) == 0 {
		return src
	}
	padLen := int(src[len(src)-1])
	if padLen > len(src) {
		return src
	}
	return src[:len(src)-padLen]
}
253
// paddingOffset returns the value of the last byte of src, which for padded
// data is the number of padding bytes. It returns 0 for empty input.
func paddingOffset(src []byte) int {
	if len(src) == 0 {
		return 0
	}
	return int(src[len(src)-1])
}
258
// contentKey is the JSON description of a wrapped content key: its ID, the
// key value and the algorithm name. Note that EncryptionKey is serialized
// under the JSON name "EncryptedKey". Empty fields are omitted.
type contentKey struct {
	KeyID         string `json:"KeyId,omitempty"`
	EncryptionKey string `json:"EncryptedKey,omitempty"`
	Algorithm     string `json:"Algorithm,omitempty"`
}
264
// encryptionAgent is the JSON description of the encrypting agent: the
// protocol identifier and the name of the encryption algorithm. Empty fields
// are omitted.
type encryptionAgent struct {
	Protocol            string `json:"Protocol,omitempty"`
	EncryptionAlgorithm string `json:"EncryptionAlgorithm,omitempty"`
}
269
// keyMetadata is the JSON key-wrapping metadata; it names the encryption
// library. The field is omitted when empty.
type keyMetadata struct {
	EncryptionLibrary string `json:"EncryptionLibrary,omitempty"`
}
273
274type encryptionData struct {
275	EncryptionMode      string          `json:"EncryptionMode,omitempty"`
276	WrappedContentKey   contentKey      `json:"WrappedContentKey,omitempty"`
277	EncryptionAgent     encryptionAgent `json:"EncryptionAgent,omitempty"`
278	ContentEncryptionIV string          `json:"ContentEncryptionIV,omitempty"`
279	KeyWrappingMetadata keyMetadata     `json:"KeyWrappingMetadata,omitempty"`
280}
281