1// Copyright (c) 2021 Snowflake Computing Inc. All right reserved. 2 3package gosnowflake 4 5import ( 6 "bytes" 7 "crypto/aes" 8 "crypto/cipher" 9 "crypto/rand" 10 "encoding/base64" 11 "encoding/json" 12 "fmt" 13 "io" 14 "io/ioutil" 15 "os" 16 "strconv" 17) 18 19type snowflakeFileEncryption struct { 20 QueryStageMasterKey string `json:"queryStageMasterKey,omitempty"` 21 QueryID string `json:"queryId,omitempty"` 22 SMKID int64 `json:"smkId,omitempty"` 23} 24 25// PUT requests return a single encryptionMaterial object whereas GET requests 26// return a slice (array) of encryptionMaterial objects, both under the field 27// 'encryptionMaterial' 28type encryptionWrapper struct { 29 snowflakeFileEncryption 30 EncryptionMaterials []snowflakeFileEncryption 31} 32 33// override default behavior for wrapper 34func (ew *encryptionWrapper) UnmarshalJSON(data []byte) error { 35 // if GET, unmarshal slice of encryptionMaterial 36 if err := json.Unmarshal(data, &ew.EncryptionMaterials); err == nil { 37 return err 38 } 39 // else (if PUT), unmarshal the encryptionMaterial itself 40 return json.Unmarshal(data, &ew.snowflakeFileEncryption) 41} 42 43type encryptMetadata struct { 44 key string 45 iv string 46 matdesc string 47} 48 49// encryptStream encrypts a stream buffer using AES128 block cipher in CBC mode 50// with PKCS5 padding 51func encryptStream( 52 sfe *snowflakeFileEncryption, 53 src io.Reader, 54 out io.Writer, 55 chunkSize int) (*encryptMetadata, error) { 56 if chunkSize == 0 { 57 chunkSize = aes.BlockSize * 4 * 1024 58 } 59 decodedKey, _ := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey) 60 keySize := len(decodedKey) 61 62 fileKey := getSecureRandom(keySize) 63 block, _ := aes.NewCipher(fileKey) 64 ivData := getSecureRandom(block.BlockSize()) 65 66 mode := cipher.NewCBCEncrypter(block, ivData) 67 cipherText := make([]byte, chunkSize) 68 69 // encrypt file with CBC 70 var err error 71 padded := false 72 for { 73 chunk := make([]byte, chunkSize) 74 n, err := src.Read(chunk) 75 if n == 0 || err != nil { 76 break 77 } else if n%aes.BlockSize != 0 { 78 chunk = padBytesLength(chunk[:n], aes.BlockSize) 79 padded = true 80 } 81 mode.CryptBlocks(cipherText, chunk) 82 out.Write(cipherText[:len(chunk)]) 83 } 84 if err != nil { 85 return nil, err 86 } 87 if !padded { 88 blockSizeCipher := bytes.Repeat([]byte{byte(aes.BlockSize)}, aes.BlockSize) 89 chunk := make([]byte, aes.BlockSize) 90 mode.CryptBlocks(chunk, blockSizeCipher) 91 out.Write(chunk) 92 } 93 94 // encrypt key with ECB 95 fileKey = padBytesLength(fileKey, block.BlockSize()) 96 encryptedFileKey := make([]byte, len(fileKey)) 97 if err = encryptECB(encryptedFileKey, fileKey, decodedKey); err != nil { 98 return nil, err 99 } 100 101 matDesc := materialDescriptor{ 102 strconv.Itoa(int(sfe.SMKID)), 103 sfe.QueryID, 104 strconv.Itoa(keySize * 8), 105 } 106 107 return &encryptMetadata{ 108 base64.StdEncoding.EncodeToString(encryptedFileKey), 109 base64.StdEncoding.EncodeToString(ivData), 110 matdescToUnicode(matDesc), 111 }, nil 112} 113 114func encryptECB(encrypted []byte, fileKey []byte, decodedKey []byte) error { 115 block, _ := aes.NewCipher(decodedKey) 116 if len(fileKey)%block.BlockSize() != 0 { 117 return fmt.Errorf("input not full of blocks") 118 } 119 if len(encrypted) < len(fileKey) { 120 return fmt.Errorf("output length is smaller than input length") 121 } 122 for len(fileKey) > 0 { 123 block.Encrypt(encrypted, fileKey[:block.BlockSize()]) 124 encrypted = encrypted[block.BlockSize():] 125 fileKey = fileKey[block.BlockSize():] 126 } 127 return nil 128} 129 130func decryptECB(decrypted []byte, keyBytes []byte, decodedKey []byte) error { 131 block, _ := aes.NewCipher(decodedKey) 132 if len(keyBytes)%block.BlockSize() != 0 { 133 return fmt.Errorf("input not full of blocks") 134 } 135 if len(decrypted) < len(keyBytes) { 136 return fmt.Errorf("output length is smaller than input length") 137 } 138 for len(keyBytes) > 0 { 139 block.Decrypt(decrypted, keyBytes[:block.BlockSize()]) 140 keyBytes = keyBytes[block.BlockSize():] 141 decrypted = decrypted[block.BlockSize():] 142 } 143 return nil 144} 145 146func encryptFile( 147 sfe *snowflakeFileEncryption, 148 filename string, 149 chunkSize int, 150 tmpDir string) ( 151 *encryptMetadata, string, error) { 152 if chunkSize == 0 { 153 chunkSize = aes.BlockSize * 4 * 1024 154 } 155 tmpOutputFile, _ := ioutil.TempFile(tmpDir, baseName(filename)+"#") 156 infile, err := os.OpenFile(filename, os.O_CREATE|os.O_RDONLY, os.ModePerm) 157 if err != nil { 158 return nil, "", err 159 } 160 meta, err := encryptStream(sfe, infile, tmpOutputFile, chunkSize) 161 if err != nil { 162 return nil, "", err 163 } 164 return meta, tmpOutputFile.Name(), nil 165} 166 167func decryptFile( 168 metadata *encryptMetadata, 169 sfe *snowflakeFileEncryption, 170 filename string, 171 chunkSize int, 172 tmpDir string) ( 173 string, error) { 174 if chunkSize == 0 { 175 chunkSize = aes.BlockSize * 4 * 1024 176 } 177 decodedKey, _ := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey) 178 keyBytes, _ := base64.StdEncoding.DecodeString(metadata.key) // encrypted file key 179 ivBytes, _ := base64.StdEncoding.DecodeString(metadata.iv) 180 181 // decrypt file key 182 decryptedKey := make([]byte, len(keyBytes)) 183 if err := decryptECB(decryptedKey, keyBytes, decodedKey); err != nil { 184 return "", err 185 } 186 decryptedKey = paddingTrim(decryptedKey) 187 188 // decrypt file 189 block, _ := aes.NewCipher(decryptedKey) 190 mode := cipher.NewCBCDecrypter(block, ivBytes) 191 192 tmpOutputFile, err := ioutil.TempFile(tmpDir, baseName(filename)+"#") 193 if err != nil { 194 return "", err 195 } 196 defer tmpOutputFile.Close() 197 infile, err := os.OpenFile(filename, os.O_RDONLY, os.ModePerm) 198 if err != nil { 199 return "", err 200 } 201 defer infile.Close() 202 var totalFileSize int 203 var prevChunk []byte 204 for { 205 chunk := make([]byte, chunkSize) 206 n, err := infile.Read(chunk) 207 if n == 0 || err != nil { 208 break 209 } 210 totalFileSize += n 211 chunk = chunk[:n] 212 mode.CryptBlocks(chunk, chunk) 213 tmpOutputFile.Write(chunk) 214 prevChunk = chunk 215 } 216 if err != nil { 217 return "", err 218 } 219 if prevChunk != nil { 220 totalFileSize -= paddingOffset(prevChunk) 221 } 222 tmpOutputFile.Truncate(int64(totalFileSize)) 223 return tmpOutputFile.Name(), nil 224} 225 226type materialDescriptor struct { 227 SmkID string `json:"smkId"` 228 QueryID string `json:"queryId"` 229 KeySize string `json:"keySize"` 230} 231 232func matdescToUnicode(matdesc materialDescriptor) string { 233 s, _ := json.Marshal(&matdesc) 234 return string(s) 235} 236 237func getSecureRandom(byteLength int) []byte { 238 token := make([]byte, byteLength) 239 rand.Read(token) 240 return token 241} 242 243func padBytesLength(src []byte, blockSize int) []byte { 244 padLength := blockSize - len(src)%blockSize 245 padText := bytes.Repeat([]byte{byte(padLength)}, padLength) 246 return append(src, padText...) 247} 248 249func paddingTrim(src []byte) []byte { 250 unpadding := src[len(src)-1] 251 return src[:len(src)-int(unpadding)] 252} 253 254func paddingOffset(src []byte) int { 255 length := len(src) 256 return int(src[length-1]) 257} 258 259type contentKey struct { 260 KeyID string `json:"KeyId,omitempty"` 261 EncryptionKey string `json:"EncryptedKey,omitempty"` 262 Algorithm string `json:"Algorithm,omitempty"` 263} 264 265type encryptionAgent struct { 266 Protocol string `json:"Protocol,omitempty"` 267 EncryptionAlgorithm string `json:"EncryptionAlgorithm,omitempty"` 268} 269 270type keyMetadata struct { 271 EncryptionLibrary string `json:"EncryptionLibrary,omitempty"` 272} 273 274type encryptionData struct { 275 EncryptionMode string `json:"EncryptionMode,omitempty"` 276 WrappedContentKey contentKey `json:"WrappedContentKey,omitempty"` 277 EncryptionAgent encryptionAgent `json:"EncryptionAgent,omitempty"` 278 ContentEncryptionIV string `json:"ContentEncryptionIV,omitempty"` 279 KeyWrappingMetadata keyMetadata `json:"KeyWrappingMetadata,omitempty"` 280} 281