1 // This file is part of Desktop App Toolkit,
2 // a set of libraries for developing nice desktop applications.
3 //
4 // For license and copyright information please follow this link:
5 // https://github.com/desktop-app/legal/blob/master/LEGAL
6 //
7 #pragma once
8
9 #include "base/bytes.h"
10 #include "base/algorithm.h"
11 #include "base/basic_types.h"
12
13 extern "C" {
14 #include <openssl/bn.h>
15 #include <openssl/sha.h>
16 #include <openssl/aes.h>
17 #include <openssl/modes.h>
18 #include <openssl/crypto.h>
19 #include <openssl/evp.h>
20 #include <openssl/hmac.h>
21 #include <openssl/rsa.h>
22 #include <openssl/pem.h>
23 #include <openssl/err.h>
24 } // extern "C"
25
26 #ifdef small
27 #undef small
28 #endif // small
29
30 namespace openssl {
31
32 class Context {
33 public:
Context()34 Context() : _data(BN_CTX_new()) {
35 }
36 Context(const Context &other) = delete;
Context(Context && other)37 Context(Context &&other) : _data(base::take(other._data)) {
38 }
39 Context &operator=(const Context &other) = delete;
40 Context &operator=(Context &&other) {
41 _data = base::take(other._data);
42 return *this;
43 }
~Context()44 ~Context() {
45 if (_data) {
46 BN_CTX_free(_data);
47 }
48 }
49
raw()50 BN_CTX *raw() const {
51 return _data;
52 }
53
54 private:
55 BN_CTX *_data = nullptr;
56
57 };
58
59 class BigNum {
60 public:
61 BigNum() = default;
BigNum(const BigNum & other)62 BigNum(const BigNum &other)
63 : _data((other.failed() || other.isZero())
64 ? nullptr
65 : BN_dup(other.raw()))
66 , _failed(other._failed) {
67 }
BigNum(BigNum && other)68 BigNum(BigNum &&other)
69 : _data(std::exchange(other._data, nullptr))
70 , _failed(std::exchange(other._failed, false)) {
71 }
72 BigNum &operator=(const BigNum &other) {
73 if (other.failed()) {
74 _failed = true;
75 } else if (other.isZero()) {
76 clear();
77 _failed = false;
78 } else if (!_data) {
79 _data = BN_dup(other.raw());
80 _failed = false;
81 } else {
82 _failed = !BN_copy(raw(), other.raw());
83 }
84 return *this;
85 }
86 BigNum &operator=(BigNum &&other) {
87 std::swap(_data, other._data);
88 std::swap(_failed, other._failed);
89 return *this;
90 }
~BigNum()91 ~BigNum() {
92 clear();
93 }
94
BigNum(unsigned int word)95 explicit BigNum(unsigned int word) : BigNum() {
96 setWord(word);
97 }
BigNum(bytes::const_span bytes)98 explicit BigNum(bytes::const_span bytes) : BigNum() {
99 setBytes(bytes);
100 }
101
setWord(unsigned int word)102 BigNum &setWord(unsigned int word) {
103 if (!word) {
104 clear();
105 _failed = false;
106 } else {
107 _failed = !BN_set_word(raw(), word);
108 }
109 return *this;
110 }
setBytes(bytes::const_span bytes)111 BigNum &setBytes(bytes::const_span bytes) {
112 if (bytes.empty()) {
113 clear();
114 _failed = false;
115 } else {
116 _failed = !BN_bin2bn(
117 reinterpret_cast<const unsigned char*>(bytes.data()),
118 bytes.size(),
119 raw());
120 }
121 return *this;
122 }
123
setAdd(const BigNum & a,const BigNum & b)124 BigNum &setAdd(const BigNum &a, const BigNum &b) {
125 if (a.failed() || b.failed()) {
126 _failed = true;
127 } else {
128 _failed = !BN_add(raw(), a.raw(), b.raw());
129 }
130 return *this;
131 }
setSub(const BigNum & a,const BigNum & b)132 BigNum &setSub(const BigNum &a, const BigNum &b) {
133 if (a.failed() || b.failed()) {
134 _failed = true;
135 } else {
136 _failed = !BN_sub(raw(), a.raw(), b.raw());
137 }
138 return *this;
139 }
140 BigNum &setMul(
141 const BigNum &a,
142 const BigNum &b,
143 const Context &context = Context()) {
144 if (a.failed() || b.failed()) {
145 _failed = true;
146 } else {
147 _failed = !BN_mul(raw(), a.raw(), b.raw(), context.raw());
148 }
149 return *this;
150 }
151 BigNum &setModAdd(
152 const BigNum &a,
153 const BigNum &b,
154 const BigNum &m,
155 const Context &context = Context()) {
156 if (a.failed() || b.failed() || m.failed()) {
157 _failed = true;
158 } else if (a.isNegative() || b.isNegative() || m.isNegative()) {
159 _failed = true;
160 } else if (!BN_mod_add(raw(), a.raw(), b.raw(), m.raw(), context.raw())) {
161 _failed = true;
162 } else if (isNegative()) {
163 _failed = true;
164 } else {
165 _failed = false;
166 }
167 return *this;
168 }
169 BigNum &setModSub(
170 const BigNum &a,
171 const BigNum &b,
172 const BigNum &m,
173 const Context &context = Context()) {
174 if (a.failed() || b.failed() || m.failed()) {
175 _failed = true;
176 } else if (a.isNegative() || b.isNegative() || m.isNegative()) {
177 _failed = true;
178 } else if (!BN_mod_sub(raw(), a.raw(), b.raw(), m.raw(), context.raw())) {
179 _failed = true;
180 } else if (isNegative()) {
181 _failed = true;
182 } else {
183 _failed = false;
184 }
185 return *this;
186 }
187 BigNum &setModMul(
188 const BigNum &a,
189 const BigNum &b,
190 const BigNum &m,
191 const Context &context = Context()) {
192 if (a.failed() || b.failed() || m.failed()) {
193 _failed = true;
194 } else if (a.isNegative() || b.isNegative() || m.isNegative()) {
195 _failed = true;
196 } else if (!BN_mod_mul(raw(), a.raw(), b.raw(), m.raw(), context.raw())) {
197 _failed = true;
198 } else if (isNegative()) {
199 _failed = true;
200 } else {
201 _failed = false;
202 }
203 return *this;
204 }
205 BigNum &setModInverse(
206 const BigNum &a,
207 const BigNum &m,
208 const Context &context = Context()) {
209 if (a.failed() || m.failed()) {
210 _failed = true;
211 } else if (a.isNegative() || m.isNegative()) {
212 _failed = true;
213 } else if (!BN_mod_inverse(raw(), a.raw(), m.raw(), context.raw())) {
214 _failed = true;
215 } else if (isNegative()) {
216 _failed = true;
217 } else {
218 _failed = false;
219 }
220 return *this;
221 }
222 BigNum &setModExp(
223 const BigNum &base,
224 const BigNum &power,
225 const BigNum &m,
226 const Context &context = Context()) {
227 if (base.failed() || power.failed() || m.failed()) {
228 _failed = true;
229 } else if (base.isNegative() || power.isNegative() || m.isNegative()) {
230 _failed = true;
231 } else if (!BN_mod_exp(raw(), base.raw(), power.raw(), m.raw(), context.raw())) {
232 _failed = true;
233 } else if (isNegative()) {
234 _failed = true;
235 } else {
236 _failed = false;
237 }
238 return *this;
239 }
240 BigNum &setGcd(
241 const BigNum &a,
242 const BigNum &b,
243 const Context &context = Context()) {
244 if (a.failed() || b.failed()) {
245 _failed = true;
246 } else if (a.isNegative() || b.isNegative()) {
247 _failed = true;
248 } else if (!BN_gcd(raw(), a.raw(), b.raw(), context.raw())) {
249 _failed = true;
250 } else if (isNegative()) {
251 _failed = true;
252 } else {
253 _failed = false;
254 }
255 return *this;
256 }
257
isZero()258 [[nodiscard]] bool isZero() const {
259 return !failed() && (!_data || BN_is_zero(raw()));
260 }
261
isOne()262 [[nodiscard]] bool isOne() const {
263 return !failed() && _data && BN_is_one(raw());
264 }
265
isNegative()266 [[nodiscard]] bool isNegative() const {
267 return !failed() && _data && BN_is_negative(raw());
268 }
269
270 [[nodiscard]] bool isPrime(const Context &context = Context()) const {
271 if (failed() || !_data) {
272 return false;
273 }
274 constexpr auto kMillerRabinIterationCount = 30;
275 const auto result = BN_is_prime_ex(
276 raw(),
277 kMillerRabinIterationCount,
278 context.raw(),
279 nullptr);
280 if (result == 1) {
281 return true;
282 } else if (result != 0) {
283 _failed = true;
284 }
285 return false;
286 }
287
subWord(unsigned int word)288 BigNum &subWord(unsigned int word) {
289 if (failed()) {
290 return *this;
291 } else if (!BN_sub_word(raw(), word)) {
292 _failed = true;
293 }
294 return *this;
295 }
296 BigNum &divWord(BN_ULONG word, BN_ULONG *mod = nullptr) {
297 Expects(word != 0);
298
299 const auto result = failed()
300 ? (BN_ULONG)-1
301 : BN_div_word(raw(), word);
302 if (result == (BN_ULONG)-1) {
303 _failed = true;
304 }
305 if (mod) {
306 *mod = result;
307 }
308 return *this;
309 }
countModWord(BN_ULONG word)310 [[nodiscard]] BN_ULONG countModWord(BN_ULONG word) const {
311 Expects(word != 0);
312
313 return failed() ? (BN_ULONG)-1 : BN_mod_word(raw(), word);
314 }
315
bitsSize()316 [[nodiscard]] int bitsSize() const {
317 return failed() ? 0 : BN_num_bits(raw());
318 }
bytesSize()319 [[nodiscard]] int bytesSize() const {
320 return failed() ? 0 : BN_num_bytes(raw());
321 }
322
getBytes()323 [[nodiscard]] bytes::vector getBytes() const {
324 if (failed()) {
325 return {};
326 }
327 auto length = BN_num_bytes(raw());
328 auto result = bytes::vector(length);
329 auto resultSize = BN_bn2bin(
330 raw(),
331 reinterpret_cast<unsigned char*>(result.data()));
332 Assert(resultSize == length);
333 return result;
334 }
335
raw()336 [[nodiscard]] BIGNUM *raw() {
337 if (!_data) _data = BN_new();
338 return _data;
339 }
raw()340 [[nodiscard]] const BIGNUM *raw() const {
341 if (!_data) _data = BN_new();
342 return _data;
343 }
takeRaw()344 [[nodiscard]] BIGNUM *takeRaw() {
345 return _failed
346 ? nullptr
347 : _data
348 ? std::exchange(_data, nullptr)
349 : BN_new();
350 }
351
failed()352 [[nodiscard]] bool failed() const {
353 return _failed;
354 }
355
Add(const BigNum & a,const BigNum & b)356 [[nodiscard]] static BigNum Add(const BigNum &a, const BigNum &b) {
357 return BigNum().setAdd(a, b);
358 }
Sub(const BigNum & a,const BigNum & b)359 [[nodiscard]] static BigNum Sub(const BigNum &a, const BigNum &b) {
360 return BigNum().setSub(a, b);
361 }
362 [[nodiscard]] static BigNum Mul(
363 const BigNum &a,
364 const BigNum &b,
365 const Context &context = Context()) {
366 return BigNum().setMul(a, b, context);
367 }
368 [[nodiscard]] static BigNum ModAdd(
369 const BigNum &a,
370 const BigNum &b,
371 const BigNum &mod,
372 const Context &context = Context()) {
373 return BigNum().setModAdd(a, b, mod, context);
374 }
375 [[nodiscard]] static BigNum ModSub(
376 const BigNum &a,
377 const BigNum &b,
378 const BigNum &mod,
379 const Context &context = Context()) {
380 return BigNum().setModSub(a, b, mod, context);
381 }
382 [[nodiscard]] static BigNum ModMul(
383 const BigNum &a,
384 const BigNum &b,
385 const BigNum &mod,
386 const Context &context = Context()) {
387 return BigNum().setModMul(a, b, mod, context);
388 }
389 [[nodiscard]] static BigNum ModInverse(
390 const BigNum &a,
391 const BigNum &mod,
392 const Context &context = Context()) {
393 return BigNum().setModInverse(a, mod, context);
394 }
395 [[nodiscard]] static BigNum ModExp(
396 const BigNum &base,
397 const BigNum &power,
398 const BigNum &mod,
399 const Context &context = Context()) {
400 return BigNum().setModExp(base, power, mod, context);
401 }
Compare(const BigNum & a,const BigNum & b)402 [[nodiscard]] static int Compare(const BigNum &a, const BigNum &b) {
403 return a.failed() ? -1 : b.failed() ? 1 : BN_cmp(a.raw(), b.raw());
404 }
405 static void Div(
406 BigNum *dv,
407 BigNum *rem,
408 const BigNum &a,
409 const BigNum &b,
410 const Context &context = Context()) {
411 if (!dv && !rem) {
412 return;
413 } else if (a.failed()
414 || b.failed()
415 || !BN_div(
416 dv ? dv->raw() : nullptr,
417 rem ? rem->raw() : nullptr,
418 a.raw(),
419 b.raw(),
420 context.raw())) {
421 if (dv) {
422 dv->_failed = true;
423 }
424 if (rem) {
425 rem->_failed = true;
426 }
427 } else {
428 if (dv) {
429 dv->_failed = false;
430 }
431 if (rem) {
432 rem->_failed = false;
433 }
434 }
435 }
Failed()436 [[nodiscard]] static BigNum Failed() {
437 auto result = BigNum();
438 result._failed = true;
439 return result;
440 }
441
442 private:
clear()443 void clear() {
444 BN_clear_free(std::exchange(_data, nullptr));
445 }
446
447 mutable BIGNUM *_data = nullptr;
448 mutable bool _failed = false;
449
450 };
451
452 namespace details {
453
454 template <typename Context, typename Method, typename Arg>
ShaUpdate(Context context,Method method,Arg && arg)455 inline void ShaUpdate(Context context, Method method, Arg &&arg) {
456 const auto span = bytes::make_span(arg);
457 method(context, span.data(), span.size());
458 }
459
460 template <typename Context, typename Method, typename Arg, typename ...Args>
ShaUpdate(Context context,Method method,Arg && arg,Args &&...args)461 inline void ShaUpdate(Context context, Method method, Arg &&arg, Args &&...args) {
462 const auto span = bytes::make_span(arg);
463 method(context, span.data(), span.size());
464 ShaUpdate(context, method, args...);
465 }
466
467 template <size_type Size, typename Method>
Sha(bytes::span dst,Method method,bytes::const_span data)468 inline void Sha(
469 bytes::span dst,
470 Method method,
471 bytes::const_span data) {
472 Expects(dst.size() >= Size);
473
474 method(
475 reinterpret_cast<const unsigned char*>(data.data()),
476 data.size(),
477 reinterpret_cast<unsigned char*>(dst.data()));
478 }
479
480 template <size_type Size, typename Method>
Sha(Method method,bytes::const_span data)481 [[nodiscard]] inline bytes::vector Sha(
482 Method method,
483 bytes::const_span data) {
484 auto bytes = bytes::vector(Size);
485 Sha<Size>(bytes, method, data);
486 return bytes;
487 }
488
489 template <
490 size_type Size,
491 typename Context,
492 typename Init,
493 typename Update,
494 typename Finalize,
495 typename ...Args,
496 typename = std::enable_if_t<(sizeof...(Args) > 1)>>
Sha(Context context,Init init,Update update,Finalize finalize,Args &&...args)497 [[nodiscard]] bytes::vector Sha(
498 Context context,
499 Init init,
500 Update update,
501 Finalize finalize,
502 Args &&...args) {
503 auto bytes = bytes::vector(Size);
504
505 init(&context);
506 ShaUpdate(&context, update, args...);
507 finalize(reinterpret_cast<unsigned char*>(bytes.data()), &context);
508
509 return bytes;
510 }
511
512 template <
513 size_type Size,
514 typename Evp>
Pbkdf2(bytes::const_span password,bytes::const_span salt,int iterations,Evp evp)515 [[nodiscard]] bytes::vector Pbkdf2(
516 bytes::const_span password,
517 bytes::const_span salt,
518 int iterations,
519 Evp evp) {
520 auto result = bytes::vector(Size);
521 PKCS5_PBKDF2_HMAC(
522 reinterpret_cast<const char*>(password.data()),
523 password.size(),
524 reinterpret_cast<const unsigned char*>(salt.data()),
525 salt.size(),
526 iterations,
527 evp,
528 result.size(),
529 reinterpret_cast<unsigned char*>(result.data()));
530 return result;
531 }
532
533 } // namespace details
534
535 constexpr auto kSha1Size = size_type(SHA_DIGEST_LENGTH);
536 constexpr auto kSha256Size = size_type(SHA256_DIGEST_LENGTH);
537 constexpr auto kSha512Size = size_type(SHA512_DIGEST_LENGTH);
538
Sha1(bytes::const_span data)539 [[nodiscard]] inline bytes::vector Sha1(bytes::const_span data) {
540 return details::Sha<kSha1Size>(SHA1, data);
541 }
542
Sha1To(bytes::span dst,bytes::const_span data)543 inline void Sha1To(bytes::span dst, bytes::const_span data) {
544 details::Sha<kSha1Size>(dst, SHA1, data);
545 }
546
547 template <
548 typename ...Args,
549 typename = std::enable_if_t<(sizeof...(Args) > 1)>>
Sha1(Args &&...args)550 [[nodiscard]] inline bytes::vector Sha1(Args &&...args) {
551 return details::Sha<kSha1Size>(
552 SHA_CTX(),
553 SHA1_Init,
554 SHA1_Update,
555 SHA1_Final,
556 args...);
557 }
558
Sha256(bytes::const_span data)559 [[nodiscard]] inline bytes::vector Sha256(bytes::const_span data) {
560 return details::Sha<kSha256Size>(SHA256, data);
561 }
562
Sha256To(bytes::span dst,bytes::const_span data)563 inline void Sha256To(bytes::span dst, bytes::const_span data) {
564 details::Sha<kSha256Size>(dst, SHA256, data);
565 }
566
567 template <
568 typename ...Args,
569 typename = std::enable_if_t<(sizeof...(Args) > 1)>>
Sha256(Args &&...args)570 [[nodiscard]] inline bytes::vector Sha256(Args &&...args) {
571 return details::Sha<kSha256Size>(
572 SHA256_CTX(),
573 SHA256_Init,
574 SHA256_Update,
575 SHA256_Final,
576 args...);
577 }
578
Sha512(bytes::const_span data)579 [[nodiscard]] inline bytes::vector Sha512(bytes::const_span data) {
580 return details::Sha<kSha512Size>(SHA512, data);
581 }
582
Sha512To(bytes::span dst,bytes::const_span data)583 inline void Sha512To(bytes::span dst, bytes::const_span data) {
584 details::Sha<kSha512Size>(dst, SHA512, data);
585 }
586
587 template <
588 typename ...Args,
589 typename = std::enable_if_t<(sizeof...(Args) > 1)>>
Sha512(Args &&...args)590 [[nodiscard]] inline bytes::vector Sha512(Args &&...args) {
591 return details::Sha<kSha512Size>(
592 SHA512_CTX(),
593 SHA512_Init,
594 SHA512_Update,
595 SHA512_Final,
596 args...);
597 }
598
Pbkdf2Sha512(bytes::const_span password,bytes::const_span salt,int iterations)599 inline bytes::vector Pbkdf2Sha512(
600 bytes::const_span password,
601 bytes::const_span salt,
602 int iterations) {
603 return details::Pbkdf2<kSha512Size>(
604 password,
605 salt,
606 iterations,
607 EVP_sha512());
608 }
609
HmacSha256(bytes::const_span key,bytes::const_span data)610 inline bytes::vector HmacSha256(
611 bytes::const_span key,
612 bytes::const_span data) {
613 auto result = bytes::vector(kSha256Size);
614 auto length = (unsigned int)kSha256Size;
615
616 HMAC(
617 EVP_sha256(),
618 key.data(),
619 key.size(),
620 reinterpret_cast<const unsigned char*>(data.data()),
621 data.size(),
622 reinterpret_cast<unsigned char*>(result.data()),
623 &length);
624
625 return result;
626 }
627
628 } // namespace openssl
629