1 // This file is part of Desktop App Toolkit,
2 // a set of libraries for developing nice desktop applications.
3 //
4 // For license and copyright information please follow this link:
5 // https://github.com/desktop-app/legal/blob/master/LEGAL
6 //
7 #pragma once
8 
9 #include "base/bytes.h"
10 #include "base/algorithm.h"
11 #include "base/basic_types.h"
12 
13 extern "C" {
14 #include <openssl/bn.h>
15 #include <openssl/sha.h>
16 #include <openssl/aes.h>
17 #include <openssl/modes.h>
18 #include <openssl/crypto.h>
19 #include <openssl/evp.h>
20 #include <openssl/hmac.h>
21 #include <openssl/rsa.h>
22 #include <openssl/pem.h>
23 #include <openssl/err.h>
24 } // extern "C"
25 
26 #ifdef small
27 #undef small
28 #endif // small
29 
30 namespace openssl {
31 
32 class Context {
33 public:
Context()34 	Context() : _data(BN_CTX_new()) {
35 	}
36 	Context(const Context &other) = delete;
Context(Context && other)37 	Context(Context &&other) : _data(base::take(other._data)) {
38 	}
39 	Context &operator=(const Context &other) = delete;
40 	Context &operator=(Context &&other) {
41 		_data = base::take(other._data);
42 		return *this;
43 	}
~Context()44 	~Context() {
45 		if (_data) {
46 			BN_CTX_free(_data);
47 		}
48 	}
49 
raw()50 	BN_CTX *raw() const {
51 		return _data;
52 	}
53 
54 private:
55 	BN_CTX *_data = nullptr;
56 
57 };
58 
59 class BigNum {
60 public:
61 	BigNum() = default;
BigNum(const BigNum & other)62 	BigNum(const BigNum &other)
63 	: _data((other.failed() || other.isZero())
64 		? nullptr
65 		: BN_dup(other.raw()))
66 	, _failed(other._failed) {
67 	}
BigNum(BigNum && other)68 	BigNum(BigNum &&other)
69 	: _data(std::exchange(other._data, nullptr))
70 	, _failed(std::exchange(other._failed, false)) {
71 	}
72 	BigNum &operator=(const BigNum &other) {
73 		if (other.failed()) {
74 			_failed = true;
75 		} else if (other.isZero()) {
76 			clear();
77 			_failed = false;
78 		} else if (!_data) {
79 			_data = BN_dup(other.raw());
80 			_failed = false;
81 		} else {
82 			_failed = !BN_copy(raw(), other.raw());
83 		}
84 		return *this;
85 	}
86 	BigNum &operator=(BigNum &&other) {
87 		std::swap(_data, other._data);
88 		std::swap(_failed, other._failed);
89 		return *this;
90 	}
~BigNum()91 	~BigNum() {
92 		clear();
93 	}
94 
BigNum(unsigned int word)95 	explicit BigNum(unsigned int word) : BigNum() {
96 		setWord(word);
97 	}
BigNum(bytes::const_span bytes)98 	explicit BigNum(bytes::const_span bytes) : BigNum() {
99 		setBytes(bytes);
100 	}
101 
setWord(unsigned int word)102 	BigNum &setWord(unsigned int word) {
103 		if (!word) {
104 			clear();
105 			_failed = false;
106 		} else {
107 			_failed = !BN_set_word(raw(), word);
108 		}
109 		return *this;
110 	}
setBytes(bytes::const_span bytes)111 	BigNum &setBytes(bytes::const_span bytes) {
112 		if (bytes.empty()) {
113 			clear();
114 			_failed = false;
115 		} else {
116 			_failed = !BN_bin2bn(
117 				reinterpret_cast<const unsigned char*>(bytes.data()),
118 				bytes.size(),
119 				raw());
120 		}
121 		return *this;
122 	}
123 
setAdd(const BigNum & a,const BigNum & b)124 	BigNum &setAdd(const BigNum &a, const BigNum &b) {
125 		if (a.failed() || b.failed()) {
126 			_failed = true;
127 		} else {
128 			_failed = !BN_add(raw(), a.raw(), b.raw());
129 		}
130 		return *this;
131 	}
setSub(const BigNum & a,const BigNum & b)132 	BigNum &setSub(const BigNum &a, const BigNum &b) {
133 		if (a.failed() || b.failed()) {
134 			_failed = true;
135 		} else {
136 			_failed = !BN_sub(raw(), a.raw(), b.raw());
137 		}
138 		return *this;
139 	}
140 	BigNum &setMul(
141 			const BigNum &a,
142 			const BigNum &b,
143 			const Context &context = Context()) {
144 		if (a.failed() || b.failed()) {
145 			_failed = true;
146 		} else {
147 			_failed = !BN_mul(raw(), a.raw(), b.raw(), context.raw());
148 		}
149 		return *this;
150 	}
151 	BigNum &setModAdd(
152 			const BigNum &a,
153 			const BigNum &b,
154 			const BigNum &m,
155 			const Context &context = Context()) {
156 		if (a.failed() || b.failed() || m.failed()) {
157 			_failed = true;
158 		} else if (a.isNegative() || b.isNegative() || m.isNegative()) {
159 			_failed = true;
160 		} else if (!BN_mod_add(raw(), a.raw(), b.raw(), m.raw(), context.raw())) {
161 			_failed = true;
162 		} else if (isNegative()) {
163 			_failed = true;
164 		} else {
165 			_failed = false;
166 		}
167 		return *this;
168 	}
169 	BigNum &setModSub(
170 			const BigNum &a,
171 			const BigNum &b,
172 			const BigNum &m,
173 			const Context &context = Context()) {
174 		if (a.failed() || b.failed() || m.failed()) {
175 			_failed = true;
176 		} else if (a.isNegative() || b.isNegative() || m.isNegative()) {
177 			_failed = true;
178 		} else if (!BN_mod_sub(raw(), a.raw(), b.raw(), m.raw(), context.raw())) {
179 			_failed = true;
180 		} else if (isNegative()) {
181 			_failed = true;
182 		} else {
183 			_failed = false;
184 		}
185 		return *this;
186 	}
187 	BigNum &setModMul(
188 			const BigNum &a,
189 			const BigNum &b,
190 			const BigNum &m,
191 			const Context &context = Context()) {
192 		if (a.failed() || b.failed() || m.failed()) {
193 			_failed = true;
194 		} else if (a.isNegative() || b.isNegative() || m.isNegative()) {
195 			_failed = true;
196 		} else if (!BN_mod_mul(raw(), a.raw(), b.raw(), m.raw(), context.raw())) {
197 			_failed = true;
198 		} else if (isNegative()) {
199 			_failed = true;
200 		} else {
201 			_failed = false;
202 		}
203 		return *this;
204 	}
205 	BigNum &setModInverse(
206 			const BigNum &a,
207 			const BigNum &m,
208 			const Context &context = Context()) {
209 		if (a.failed() || m.failed()) {
210 			_failed = true;
211 		} else if (a.isNegative() || m.isNegative()) {
212 			_failed = true;
213 		} else if (!BN_mod_inverse(raw(), a.raw(), m.raw(), context.raw())) {
214 			_failed = true;
215 		} else if (isNegative()) {
216 			_failed = true;
217 		} else {
218 			_failed = false;
219 		}
220 		return *this;
221 	}
222 	BigNum &setModExp(
223 			const BigNum &base,
224 			const BigNum &power,
225 			const BigNum &m,
226 			const Context &context = Context()) {
227 		if (base.failed() || power.failed() || m.failed()) {
228 			_failed = true;
229 		} else if (base.isNegative() || power.isNegative() || m.isNegative()) {
230 			_failed = true;
231 		} else if (!BN_mod_exp(raw(), base.raw(), power.raw(), m.raw(), context.raw())) {
232 			_failed = true;
233 		} else if (isNegative()) {
234 			_failed = true;
235 		} else {
236 			_failed = false;
237 		}
238 		return *this;
239 	}
240 	BigNum &setGcd(
241 			const BigNum &a,
242 			const BigNum &b,
243 			const Context &context = Context()) {
244 		if (a.failed() || b.failed()) {
245 			_failed = true;
246 		} else if (a.isNegative() || b.isNegative()) {
247 			_failed = true;
248 		} else if (!BN_gcd(raw(), a.raw(), b.raw(), context.raw())) {
249 			_failed = true;
250 		} else if (isNegative()) {
251 			_failed = true;
252 		} else {
253 			_failed = false;
254 		}
255 		return *this;
256 	}
257 
isZero()258 	[[nodiscard]] bool isZero() const {
259 		return !failed() && (!_data || BN_is_zero(raw()));
260 	}
261 
isOne()262 	[[nodiscard]] bool isOne() const {
263 		return !failed() && _data && BN_is_one(raw());
264 	}
265 
isNegative()266 	[[nodiscard]] bool isNegative() const {
267 		return !failed() && _data && BN_is_negative(raw());
268 	}
269 
270 	[[nodiscard]] bool isPrime(const Context &context = Context()) const {
271 		if (failed() || !_data) {
272 			return false;
273 		}
274 		constexpr auto kMillerRabinIterationCount = 30;
275 		const auto result = BN_is_prime_ex(
276 			raw(),
277 			kMillerRabinIterationCount,
278 			context.raw(),
279 			nullptr);
280 		if (result == 1) {
281 			return true;
282 		} else if (result != 0) {
283 			_failed = true;
284 		}
285 		return false;
286 	}
287 
subWord(unsigned int word)288 	BigNum &subWord(unsigned int word) {
289 		if (failed()) {
290 			return *this;
291 		} else if (!BN_sub_word(raw(), word)) {
292 			_failed = true;
293 		}
294 		return *this;
295 	}
296 	BigNum &divWord(BN_ULONG word, BN_ULONG *mod = nullptr) {
297 		Expects(word != 0);
298 
299 		const auto result = failed()
300 			? (BN_ULONG)-1
301 			: BN_div_word(raw(), word);
302 		if (result == (BN_ULONG)-1) {
303 			_failed = true;
304 		}
305 		if (mod) {
306 			*mod = result;
307 		}
308 		return *this;
309 	}
countModWord(BN_ULONG word)310 	[[nodiscard]] BN_ULONG countModWord(BN_ULONG word) const {
311 		Expects(word != 0);
312 
313 		return failed() ? (BN_ULONG)-1 : BN_mod_word(raw(), word);
314 	}
315 
bitsSize()316 	[[nodiscard]] int bitsSize() const {
317 		return failed() ? 0 : BN_num_bits(raw());
318 	}
bytesSize()319 	[[nodiscard]] int bytesSize() const {
320 		return failed() ? 0 : BN_num_bytes(raw());
321 	}
322 
getBytes()323 	[[nodiscard]] bytes::vector getBytes() const {
324 		if (failed()) {
325 			return {};
326 		}
327 		auto length = BN_num_bytes(raw());
328 		auto result = bytes::vector(length);
329 		auto resultSize = BN_bn2bin(
330 			raw(),
331 			reinterpret_cast<unsigned char*>(result.data()));
332 		Assert(resultSize == length);
333 		return result;
334 	}
335 
raw()336 	[[nodiscard]] BIGNUM *raw() {
337 		if (!_data) _data = BN_new();
338 		return _data;
339 	}
raw()340 	[[nodiscard]] const BIGNUM *raw() const {
341 		if (!_data) _data = BN_new();
342 		return _data;
343 	}
takeRaw()344 	[[nodiscard]] BIGNUM *takeRaw() {
345 		return _failed
346 			? nullptr
347 			: _data
348 			? std::exchange(_data, nullptr)
349 			: BN_new();
350 	}
351 
failed()352 	[[nodiscard]] bool failed() const {
353 		return _failed;
354 	}
355 
Add(const BigNum & a,const BigNum & b)356 	[[nodiscard]] static BigNum Add(const BigNum &a, const BigNum &b) {
357 		return BigNum().setAdd(a, b);
358 	}
Sub(const BigNum & a,const BigNum & b)359 	[[nodiscard]] static BigNum Sub(const BigNum &a, const BigNum &b) {
360 		return BigNum().setSub(a, b);
361 	}
362 	[[nodiscard]] static BigNum Mul(
363 			const BigNum &a,
364 			const BigNum &b,
365 			const Context &context = Context()) {
366 		return BigNum().setMul(a, b, context);
367 	}
368 	[[nodiscard]] static BigNum ModAdd(
369 			const BigNum &a,
370 			const BigNum &b,
371 			const BigNum &mod,
372 			const Context &context = Context()) {
373 		return BigNum().setModAdd(a, b, mod, context);
374 	}
375 	[[nodiscard]] static BigNum ModSub(
376 			const BigNum &a,
377 			const BigNum &b,
378 			const BigNum &mod,
379 			const Context &context = Context()) {
380 		return BigNum().setModSub(a, b, mod, context);
381 	}
382 	[[nodiscard]] static BigNum ModMul(
383 			const BigNum &a,
384 			const BigNum &b,
385 			const BigNum &mod,
386 			const Context &context = Context()) {
387 		return BigNum().setModMul(a, b, mod, context);
388 	}
389 	[[nodiscard]] static BigNum ModInverse(
390 			const BigNum &a,
391 			const BigNum &mod,
392 			const Context &context = Context()) {
393 		return BigNum().setModInverse(a, mod, context);
394 	}
395 	[[nodiscard]] static BigNum ModExp(
396 			const BigNum &base,
397 			const BigNum &power,
398 			const BigNum &mod,
399 			const Context &context = Context()) {
400 		return BigNum().setModExp(base, power, mod, context);
401 	}
Compare(const BigNum & a,const BigNum & b)402 	[[nodiscard]] static int Compare(const BigNum &a, const BigNum &b) {
403 		return a.failed() ? -1 : b.failed() ? 1 : BN_cmp(a.raw(), b.raw());
404 	}
405 	static void Div(
406 			BigNum *dv,
407 			BigNum *rem,
408 			const BigNum &a,
409 			const BigNum &b,
410 			const Context &context = Context()) {
411 		if (!dv && !rem) {
412 			return;
413 		} else if (a.failed()
414 			|| b.failed()
415 			|| !BN_div(
416 				dv ? dv->raw() : nullptr,
417 				rem ? rem->raw() : nullptr,
418 				a.raw(),
419 				b.raw(),
420 				context.raw())) {
421 			if (dv) {
422 				dv->_failed = true;
423 			}
424 			if (rem) {
425 				rem->_failed = true;
426 			}
427 		} else {
428 			if (dv) {
429 				dv->_failed = false;
430 			}
431 			if (rem) {
432 				rem->_failed = false;
433 			}
434 		}
435 	}
Failed()436 	[[nodiscard]] static BigNum Failed() {
437 		auto result = BigNum();
438 		result._failed = true;
439 		return result;
440 	}
441 
442 private:
clear()443 	void clear() {
444 		BN_clear_free(std::exchange(_data, nullptr));
445 	}
446 
447 	mutable BIGNUM *_data = nullptr;
448 	mutable bool _failed = false;
449 
450 };
451 
452 namespace details {
453 
454 template <typename Context, typename Method, typename Arg>
ShaUpdate(Context context,Method method,Arg && arg)455 inline void ShaUpdate(Context context, Method method, Arg &&arg) {
456 	const auto span = bytes::make_span(arg);
457 	method(context, span.data(), span.size());
458 }
459 
// Feeds every argument, left to right, into the hash `context`. Each one
// is handled by the single-argument overload above.
template <typename Context, typename Method, typename Arg, typename ...Args>
inline void ShaUpdate(Context context, Method method, Arg &&arg, Args &&...args) {
	ShaUpdate(context, method, arg);
	(ShaUpdate(context, method, args), ...);
}
466 
467 template <size_type Size, typename Method>
Sha(bytes::span dst,Method method,bytes::const_span data)468 inline void Sha(
469 		bytes::span dst,
470 		Method method,
471 		bytes::const_span data) {
472 	Expects(dst.size() >= Size);
473 
474 	method(
475 		reinterpret_cast<const unsigned char*>(data.data()),
476 		data.size(),
477 		reinterpret_cast<unsigned char*>(dst.data()));
478 }
479 
480 template <size_type Size, typename Method>
Sha(Method method,bytes::const_span data)481 [[nodiscard]] inline bytes::vector Sha(
482 		Method method,
483 		bytes::const_span data) {
484 	auto bytes = bytes::vector(Size);
485 	Sha<Size>(bytes, method, data);
486 	return bytes;
487 }
488 
489 template <
490 	size_type Size,
491 	typename Context,
492 	typename Init,
493 	typename Update,
494 	typename Finalize,
495 	typename ...Args,
496 	typename = std::enable_if_t<(sizeof...(Args) > 1)>>
Sha(Context context,Init init,Update update,Finalize finalize,Args &&...args)497 [[nodiscard]] bytes::vector Sha(
498 		Context context,
499 		Init init,
500 		Update update,
501 		Finalize finalize,
502 		Args &&...args) {
503 	auto bytes = bytes::vector(Size);
504 
505 	init(&context);
506 	ShaUpdate(&context, update, args...);
507 	finalize(reinterpret_cast<unsigned char*>(bytes.data()), &context);
508 
509 	return bytes;
510 }
511 
512 template <
513 	size_type Size,
514 	typename Evp>
Pbkdf2(bytes::const_span password,bytes::const_span salt,int iterations,Evp evp)515 [[nodiscard]] bytes::vector Pbkdf2(
516 		bytes::const_span password,
517 		bytes::const_span salt,
518 		int iterations,
519 		Evp evp) {
520 	auto result = bytes::vector(Size);
521 	PKCS5_PBKDF2_HMAC(
522 		reinterpret_cast<const char*>(password.data()),
523 		password.size(),
524 		reinterpret_cast<const unsigned char*>(salt.data()),
525 		salt.size(),
526 		iterations,
527 		evp,
528 		result.size(),
529 		reinterpret_cast<unsigned char*>(result.data()));
530 	return result;
531 }
532 
533 } // namespace details
534 
535 constexpr auto kSha1Size = size_type(SHA_DIGEST_LENGTH);
536 constexpr auto kSha256Size = size_type(SHA256_DIGEST_LENGTH);
537 constexpr auto kSha512Size = size_type(SHA512_DIGEST_LENGTH);
538 
Sha1(bytes::const_span data)539 [[nodiscard]] inline bytes::vector Sha1(bytes::const_span data) {
540 	return details::Sha<kSha1Size>(SHA1, data);
541 }
542 
Sha1To(bytes::span dst,bytes::const_span data)543 inline void Sha1To(bytes::span dst, bytes::const_span data) {
544 	details::Sha<kSha1Size>(dst, SHA1, data);
545 }
546 
547 template <
548 	typename ...Args,
549 	typename = std::enable_if_t<(sizeof...(Args) > 1)>>
Sha1(Args &&...args)550 [[nodiscard]] inline bytes::vector Sha1(Args &&...args) {
551 	return details::Sha<kSha1Size>(
552 		SHA_CTX(),
553 		SHA1_Init,
554 		SHA1_Update,
555 		SHA1_Final,
556 		args...);
557 }
558 
Sha256(bytes::const_span data)559 [[nodiscard]] inline bytes::vector Sha256(bytes::const_span data) {
560 	return details::Sha<kSha256Size>(SHA256, data);
561 }
562 
Sha256To(bytes::span dst,bytes::const_span data)563 inline void Sha256To(bytes::span dst, bytes::const_span data) {
564 	details::Sha<kSha256Size>(dst, SHA256, data);
565 }
566 
567 template <
568 	typename ...Args,
569 	typename = std::enable_if_t<(sizeof...(Args) > 1)>>
Sha256(Args &&...args)570 [[nodiscard]] inline bytes::vector Sha256(Args &&...args) {
571 	return details::Sha<kSha256Size>(
572 		SHA256_CTX(),
573 		SHA256_Init,
574 		SHA256_Update,
575 		SHA256_Final,
576 		args...);
577 }
578 
Sha512(bytes::const_span data)579 [[nodiscard]] inline bytes::vector Sha512(bytes::const_span data) {
580 	return details::Sha<kSha512Size>(SHA512, data);
581 }
582 
Sha512To(bytes::span dst,bytes::const_span data)583 inline void Sha512To(bytes::span dst, bytes::const_span data) {
584 	details::Sha<kSha512Size>(dst, SHA512, data);
585 }
586 
587 template <
588 	typename ...Args,
589 	typename = std::enable_if_t<(sizeof...(Args) > 1)>>
Sha512(Args &&...args)590 [[nodiscard]] inline bytes::vector Sha512(Args &&...args) {
591 	return details::Sha<kSha512Size>(
592 		SHA512_CTX(),
593 		SHA512_Init,
594 		SHA512_Update,
595 		SHA512_Final,
596 		args...);
597 }
598 
Pbkdf2Sha512(bytes::const_span password,bytes::const_span salt,int iterations)599 inline bytes::vector Pbkdf2Sha512(
600 		bytes::const_span password,
601 		bytes::const_span salt,
602 		int iterations) {
603 	return details::Pbkdf2<kSha512Size>(
604 		password,
605 		salt,
606 		iterations,
607 		EVP_sha512());
608 }
609 
HmacSha256(bytes::const_span key,bytes::const_span data)610 inline bytes::vector HmacSha256(
611 		bytes::const_span key,
612 		bytes::const_span data) {
613 	auto result = bytes::vector(kSha256Size);
614 	auto length = (unsigned int)kSha256Size;
615 
616 	HMAC(
617 		EVP_sha256(),
618 		key.data(),
619 		key.size(),
620 		reinterpret_cast<const unsigned char*>(data.data()),
621 		data.size(),
622 		reinterpret_cast<unsigned char*>(result.data()),
623 		&length);
624 
625 	return result;
626 }
627 
628 } // namespace openssl
629