1 #include <algorithm>
2 #include <pichi/common/asserts.hpp>
3 #include <pichi/common/literals.hpp>
4 #include <pichi/crypto/aead.hpp>
5 #include <pichi/crypto/hash.hpp>
6 #include <sodium/crypto_aead_chacha20poly1305.h>
7 #include <sodium/crypto_aead_xchacha20poly1305.h>
8 #include <sodium/randombytes.h>
9 #include <sodium/utils.h>
10 
11 using namespace std;
12 
13 namespace pichi::crypto {
14 
15 template <CryptoMethod method>
initialize(AeadContext<method> & ctx,ConstBuffer<uint8_t> ikm,ConstBuffer<uint8_t> salt)16 static void initialize(AeadContext<method>& ctx, ConstBuffer<uint8_t> ikm,
17                        ConstBuffer<uint8_t> salt)
18 {
19   suppressC4100(ctx);
20   assertTrue(ikm.size() == KEY_SIZE<method>, PichiError::CRYPTO_ERROR);
21   assertTrue(salt.size() == IV_SIZE<method>, PichiError::CRYPTO_ERROR);
22   if constexpr (detail::isGcm<method>()) {
23     auto skey = array<uint8_t, KEY_SIZE<method>>{};
24     hkdf<HashAlgorithm::SHA1>(skey, ikm, salt);
25     mbedtls_gcm_init(&ctx);
26     assertTrue(mbedtls_gcm_setkey(&ctx, MBEDTLS_CIPHER_ID_AES, skey.data(),
27                                   static_cast<unsigned int>(skey.size() * 8)) == 0,
28                PichiError::CRYPTO_ERROR);
29   }
30   else if constexpr (detail::isSodiumAead<method>()) {
31     hkdf<HashAlgorithm::SHA1>(ctx, ikm, salt);
32   }
33   else
34     static_assert(detail::DependentFalse<method>::value);
35 }
36 
release(AeadContext<method> & ctx)37 template <CryptoMethod method> static void release(AeadContext<method>& ctx)
38 {
39   suppressC4100(ctx);
40   if constexpr (detail::isGcm<method>())
41     mbedtls_gcm_free(&ctx);
42   else
43     static_assert(detail::isSodiumAead<method>());
44 }
45 
46 template <CryptoMethod method>
encrypt(AeadContext<method> & ctx,ConstBuffer<uint8_t> nonce,ConstBuffer<uint8_t> plain,MutableBuffer<uint8_t> cipher)47 static void encrypt(AeadContext<method>& ctx, ConstBuffer<uint8_t> nonce,
48                     ConstBuffer<uint8_t> plain, MutableBuffer<uint8_t> cipher)
49 {
50   suppressC4100(ctx);
51   assertTrue(nonce.size() == NONCE_SIZE<method>, PichiError::CRYPTO_ERROR);
52   assertTrue(cipher.size() >= plain.size() + TAG_SIZE<method>, PichiError::CRYPTO_ERROR);
53   if constexpr (detail::isGcm<method>()) {
54     assertTrue(mbedtls_gcm_crypt_and_tag(&ctx, MBEDTLS_GCM_ENCRYPT, plain.size(), nonce.data(),
55                                          nonce.size(), nullptr, 0, plain.data(), cipher.data(),
56                                          TAG_SIZE<method>, cipher.data() + plain.size()) == 0,
57                PichiError::CRYPTO_ERROR);
58   }
59   else if constexpr (method == CryptoMethod::CHACHA20_IETF_POLY1305) {
60     auto clen = static_cast<unsigned long long>(plain.size() + TAG_SIZE<method>);
61     assertTrue(crypto_aead_chacha20poly1305_ietf_encrypt(cipher.data(), &clen, plain.data(),
62                                                          plain.size(), nullptr, 0, nullptr,
63                                                          nonce.data(), ctx.data()) == 0,
64                PichiError::CRYPTO_ERROR);
65   }
66   else if constexpr (method == CryptoMethod::XCHACHA20_IETF_POLY1305) {
67     auto clen = static_cast<unsigned long long>(plain.size() + TAG_SIZE<method>);
68     assertTrue(crypto_aead_xchacha20poly1305_ietf_encrypt(cipher.data(), &clen, plain.data(),
69                                                           plain.size(), nullptr, 0, nullptr,
70                                                           nonce.data(), ctx.data()) == 0,
71                PichiError::CRYPTO_ERROR);
72   }
73   else
74     static_assert(detail::DependentFalse<method>::value);
75 }
76 
77 template <CryptoMethod method>
decrypt(AeadContext<method> & ctx,ConstBuffer<uint8_t> nonce,ConstBuffer<uint8_t> cipher,MutableBuffer<uint8_t> plain)78 static void decrypt(AeadContext<method>& ctx, ConstBuffer<uint8_t> nonce,
79                     ConstBuffer<uint8_t> cipher, MutableBuffer<uint8_t> plain)
80 {
81   suppressC4100(ctx);
82   assertTrue(nonce.size() == NONCE_SIZE<method>, PichiError::CRYPTO_ERROR);
83   assertTrue(plain.size() + TAG_SIZE<method> >= cipher.size(), PichiError::CRYPTO_ERROR);
84   if constexpr (detail::isGcm<method>()) {
85     assertTrue(mbedtls_gcm_auth_decrypt(&ctx, cipher.size() - TAG_SIZE<method>, nonce.data(),
86                                         nonce.size(), nullptr, 0,
87                                         cipher.data() + cipher.size() - TAG_SIZE<method>,
88                                         TAG_SIZE<method>, cipher.data(), plain.data()) == 0,
89                PichiError::CRYPTO_ERROR);
90   }
91   else if constexpr (method == CryptoMethod::CHACHA20_IETF_POLY1305) {
92     auto mlen = 0ull;
93     assertTrue(crypto_aead_chacha20poly1305_ietf_decrypt(plain.data(), &mlen, nullptr,
94                                                          cipher.data(), cipher.size(), nullptr, 0,
95                                                          nonce.data(), ctx.data()) == 0,
96                PichiError::CRYPTO_ERROR);
97   }
98   else if constexpr (method == CryptoMethod::XCHACHA20_IETF_POLY1305) {
99     auto mlen = 0ull;
100     assertTrue(crypto_aead_xchacha20poly1305_ietf_decrypt(plain.data(), &mlen, nullptr,
101                                                           cipher.data(), cipher.size(), nullptr, 0,
102                                                           nonce.data(), ctx.data()) == 0,
103                PichiError::CRYPTO_ERROR);
104   }
105   else
106     static_assert(detail::DependentFalse<method>::value);
107 }
108 
109 template <CryptoMethod method>
AeadEncryptor(ConstBuffer<uint8_t> key,ConstBuffer<uint8_t> salt)110 AeadEncryptor<method>::AeadEncryptor(ConstBuffer<uint8_t> key, ConstBuffer<uint8_t> salt)
111 {
112   if (salt.size() == 0) {
113     randombytes_buf(salt_.data(), IV_SIZE<method>);
114   }
115   else {
116     assertTrue(salt.size() == IV_SIZE<method>, PichiError::CRYPTO_ERROR);
117     copy_n(cbegin(salt), IV_SIZE<method>, begin(salt_));
118   }
119   fill_n(begin(nonce_), NONCE_SIZE<method>, 0_u8);
120   initialize<method>(ctx_, key, salt_);
121 }
122 
~AeadEncryptor()123 template <CryptoMethod method> AeadEncryptor<method>::~AeadEncryptor() { release<method>(ctx_); }
124 
getIv() const125 template <CryptoMethod method> ConstBuffer<uint8_t> AeadEncryptor<method>::getIv() const
126 {
127   return salt_;
128 }
129 
130 template <CryptoMethod method>
encrypt(ConstBuffer<uint8_t> plain,MutableBuffer<uint8_t> cipher)131 size_t AeadEncryptor<method>::encrypt(ConstBuffer<uint8_t> plain, MutableBuffer<uint8_t> cipher)
132 {
133   assertTrue(plain.size() <= 0x3fff, PichiError::CRYPTO_ERROR);
134   assertTrue(cipher.size() >= plain.size() + TAG_SIZE<method>, PichiError::CRYPTO_ERROR);
135   pichi::crypto::encrypt<method>(ctx_, nonce_, plain, cipher);
136   sodium_increment(nonce_.data(), NONCE_SIZE<method>);
137   return plain.size() + TAG_SIZE<method>;
138 }
139 
140 template class AeadEncryptor<CryptoMethod::AES_128_GCM>;
141 template class AeadEncryptor<CryptoMethod::AES_192_GCM>;
142 template class AeadEncryptor<CryptoMethod::AES_256_GCM>;
143 template class AeadEncryptor<CryptoMethod::CHACHA20_IETF_POLY1305>;
144 template class AeadEncryptor<CryptoMethod::XCHACHA20_IETF_POLY1305>;
145 
AeadDecryptor(ConstBuffer<uint8_t> key)146 template <CryptoMethod method> AeadDecryptor<method>::AeadDecryptor(ConstBuffer<uint8_t> key)
147 {
148   assertTrue(key.size() == KEY_SIZE<method>, PichiError::CRYPTO_ERROR);
149   copy_n(cbegin(key), KEY_SIZE<method>, begin(ikm_));
150   fill_n(begin(nonce_), NONCE_SIZE<method>, 0_u8);
151 }
152 
~AeadDecryptor()153 template <CryptoMethod method> AeadDecryptor<method>::~AeadDecryptor()
154 {
155   if (initialized_) release<method>(ctx_);
156 }
157 
getIvSize() const158 template <CryptoMethod method> size_t AeadDecryptor<method>::getIvSize() const
159 {
160   return IV_SIZE<method>;
161 }
162 
setIv(ConstBuffer<uint8_t> iv)163 template <CryptoMethod method> void AeadDecryptor<method>::setIv(ConstBuffer<uint8_t> iv)
164 {
165   assertTrue(iv.size() == IV_SIZE<method>, PichiError::CRYPTO_ERROR);
166   initialize<method>(ctx_, ikm_, iv);
167   initialized_ = true;
168 }
169 
170 template <CryptoMethod method>
decrypt(ConstBuffer<uint8_t> cipher,MutableBuffer<uint8_t> plain)171 size_t AeadDecryptor<method>::decrypt(ConstBuffer<uint8_t> cipher, MutableBuffer<uint8_t> plain)
172 {
173   assertTrue(cipher.size() > TAG_SIZE<method>, PichiError::CRYPTO_ERROR);
174   assertTrue(plain.size() >= cipher.size() - TAG_SIZE<method>, PichiError::CRYPTO_ERROR);
175   pichi::crypto::decrypt<method>(ctx_, nonce_, cipher, plain);
176   sodium_increment(nonce_.data(), NONCE_SIZE<method>);
177   return cipher.size() - TAG_SIZE<method>;
178 }
179 
180 template class AeadDecryptor<CryptoMethod::AES_128_GCM>;
181 template class AeadDecryptor<CryptoMethod::AES_192_GCM>;
182 template class AeadDecryptor<CryptoMethod::AES_256_GCM>;
183 template class AeadDecryptor<CryptoMethod::CHACHA20_IETF_POLY1305>;
184 template class AeadDecryptor<CryptoMethod::XCHACHA20_IETF_POLY1305>;
185 
186 }  // namespace pichi::crypto
187