1 #pragma once
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <openssl/err.h>
#include <openssl/evp.h>
6 
7 using EVPKey = std::array<uint8_t, EVP_MAX_KEY_LENGTH>;
8 using EVPIv = std::array<uint8_t, EVP_MAX_IV_LENGTH>;
9 
10 class EVPCipherException : public std::runtime_error {
11 public:
EVPCipherException()12     EVPCipherException()
13         : std::runtime_error("EVP cipher error"), err_code{ERR_peek_last_error()}, err_msg_buf{0} {}
14 
what()15     virtual const char* what() {
16         return ERR_error_string(err_code, err_msg_buf);
17     }
18 
19     unsigned long err_code;
20     char err_msg_buf[128];
21 };
22 
23 class EVPCipher {
24 private:
25     std::function<void(EVP_CIPHER_CTX*)> cipherFreeFn = [](EVP_CIPHER_CTX* ptr) {
26 #if OPENSSL_VERSION_NUMBER < 0x10100000L
27         delete ptr;
28 #else
29         EVP_CIPHER_CTX_free(ptr);
30 #endif
31     };
32 
33 public:
EVPCipher(const EVP_CIPHER * type,const EVPKey & key,const EVPIv & iv,bool encrypt)34     EVPCipher(const EVP_CIPHER* type, const EVPKey& key, const EVPIv& iv, bool encrypt) {
35 #if OPENSSL_VERSION_NUMBER < 0x10100000L
36         ctx = std::unique_ptr<EVP_CIPHER_CTX, decltype(cipherFreeFn)>(new EVP_CIPHER_CTX,
37                                                                       cipherFreeFn);
38 #else
39         ctx = std::unique_ptr<EVP_CIPHER_CTX, decltype(cipherFreeFn)>(EVP_CIPHER_CTX_new(),
40                                                                       cipherFreeFn);
41 #endif
42         EVP_CIPHER_CTX_init(ctx.get());
43         EVP_CipherInit_ex(ctx.get(), type, nullptr, key.data(), iv.data(), encrypt ? 1 : 0);
44     }
45 
~EVPCipher()46     ~EVPCipher() {
47         EVP_CIPHER_CTX_cleanup(ctx.get());
48     }
49 
expandAccumulator(size_t inputSize)50     void expandAccumulator(size_t inputSize) {
51         auto block_size = EVP_CIPHER_CTX_block_size(ctx.get());
52         const auto maxIncrease = (((inputSize / block_size) + 1) * block_size);
53         accumulator.resize(accumulator.size() + maxIncrease);
54     }
55 
update(const std::vector<uint8_t> & data)56     void update(const std::vector<uint8_t>& data) {
57         const auto oldSize = accumulator.size();
58         expandAccumulator(data.size());
59         int encryptedSize = 0;
60         if (EVP_CipherUpdate(ctx.get(),
61                              accumulator.data() + oldSize,
62                              &encryptedSize,
63                              data.data(),
64                              data.size()) != 1) {
65             throw EVPCipherException();
66         }
67 
68         accumulator.resize(oldSize + encryptedSize);
69     }
70 
update(const std::string & data)71     void update(const std::string& data) {
72         const auto oldSize = accumulator.size();
73         expandAccumulator(data.size());
74         int encryptedSize = 0;
75 
76         auto data_ptr = reinterpret_cast<const uint8_t*>(data.data());
77         if (EVP_CipherUpdate(
78                 ctx.get(), accumulator.data() + oldSize, &encryptedSize, data_ptr, data.size()) !=
79             1) {
80             throw EVPCipherException();
81         }
82 
83         accumulator.resize(oldSize + encryptedSize);
84     }
85 
finalize()86     void finalize() {
87         if (finalized)
88             return;
89         const auto oldSize = accumulator.size();
90         // Add one more block size to the accumulator;
91         expandAccumulator(1);
92         int encryptSize = oldSize;
93         if (EVP_CipherFinal_ex(ctx.get(), accumulator.data() + oldSize, &encryptSize) != 1) {
94             throw EVPCipherException();
95         }
96         accumulator.resize(oldSize + encryptSize);
97         finalized = true;
98     }
99 
begin()100     std::vector<uint8_t>::iterator begin() {
101         finalize();
102         return accumulator.begin();
103     }
104 
end()105     std::vector<uint8_t>::iterator end() {
106         finalize();
107         return accumulator.end();
108     }
109 
cbegin()110     std::vector<uint8_t>::const_iterator cbegin() {
111         finalize();
112         return accumulator.cbegin();
113     }
114 
cend()115     std::vector<uint8_t>::const_iterator cend() {
116         finalize();
117         return accumulator.cend();
118     }
119 
120     std::unique_ptr<EVP_CIPHER_CTX, decltype(cipherFreeFn)> ctx;
121     std::vector<uint8_t> accumulator;
122     bool finalized = false;
123 };
124