#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
14
15 namespace {
/// Decodes an RFC 4648 base32 string (the encoding used for OTP secrets).
///
/// Letters are accepted in either case. '=' padding, spaces, tabs, CR/LF and
/// '-' separators are skipped. Trailing bits that do not complete a whole
/// byte are discarded.
///
/// @param encoded base32 text.
/// @return the decoded bytes.
/// @throws std::runtime_error if a character outside the base32 alphabet
///         (and not one of the skipped separators) is encountered.
std::vector<uint8_t> base32Decode(const std::string& encoded) {
    std::vector<uint8_t> ret;
    ret.reserve(encoded.size() * 5 / 8);

    // Only the low `bits` bits of `buffer` are pending output; older bits are
    // shifted out of the top (unsigned wrap-around is well defined).
    unsigned int buffer = 0;
    int bits = 0;

    for (const char ch : encoded) {
        unsigned int value = 0;
        if (ch >= 'A' && ch <= 'Z') {
            value = static_cast<unsigned int>(ch - 'A');
        } else if (ch >= 'a' && ch <= 'z') {
            value = static_cast<unsigned int>(ch - 'a');
        } else if (ch >= '2' && ch <= '7') {
            value = static_cast<unsigned int>(ch - '2') + 26;
        } else if (ch == '=' || ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' ||
                   ch == '-') {
            // Padding / visual separators carry no data.
            continue;
        } else {
            // Previously such characters were OR'ed into the buffer verbatim,
            // silently producing a wrong key.
            throw std::runtime_error("Invalid character in base32-encoded OTP secret");
        }

        buffer = (buffer << 5) | value;
        bits += 5;

        if (bits >= 8) {
            ret.push_back(static_cast<uint8_t>((buffer >> (bits - 8)) & 0xff));
            bits -= 8;
        }
    }

    return ret;
}
41
42 class HMACWrapper {
43 private:
__anon2bf57c1e0202(HMAC_CTX* ptr) 44 std::function<void(HMAC_CTX*)> hmacFreeFn = [](HMAC_CTX* ptr) {
45 #if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER)
46 abort();
47 #else
48 HMAC_CTX_free(ptr);
49 #endif
50 };
51
52 public:
HMACWrapper(const EVP_MD * digest,const std::vector<uint8_t> & key)53 HMACWrapper(const EVP_MD* digest, const std::vector<uint8_t>& key) {
54 #if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER)
55 ctx = std::unique_ptr<HMAC_CTX>(new HMAC_CTX);
56 HMAC_CTX_init(ctx.get());
57 #else
58 ctx = std::unique_ptr<HMAC_CTX, decltype(hmacFreeFn)>(HMAC_CTX_new(), hmacFreeFn);
59 #endif
60 HMAC_Init_ex(ctx.get(), key.data(), key.size(), digest, nullptr);
61 }
62
update(const std::vector<uint8_t> & s)63 void update(const std::vector<uint8_t>& s) {
64 HMAC_Update(ctx.get(),
65 reinterpret_cast<const unsigned char*>(s.data()),
66 static_cast<unsigned int>(s.size()));
67 }
68
finalize()69 std::vector<uint8_t> finalize() {
70 std::vector<uint8_t> finalHash(HMAC_size(ctx.get()));
71 unsigned int size;
72 HMAC_Final(ctx.get(), finalHash.data(), &size);
73 if (size != finalHash.size()) {
74 throw std::runtime_error("Overflow while finalizing HMAC");
75 }
76 return std::move(finalHash);
77 }
78
79 private:
80 std::unique_ptr<HMAC_CTX, decltype(hmacFreeFn)> ctx;
81 };
82
calculateTOTPInternal(const EVP_MD * algo,uint64_t counter,int digits,const std::vector<uint8_t> & key)83 std::string calculateTOTPInternal(const EVP_MD* algo,
84 uint64_t counter,
85 int digits,
86 const std::vector<uint8_t>& key) {
87 std::vector<uint8_t> counterArr(8);
88 for (int i = 7; i >= 0; i--) {
89 counterArr[i] = counter & 0xff;
90 counter >>= 8;
91 }
92
93 HMACWrapper hmac(algo, key);
94 hmac.update(counterArr);
95 auto finalHmac = hmac.finalize();
96
97 const auto offset = finalHmac[19] & 0xf;
98 uint32_t truncated = (finalHmac[offset] & 0x7f) << 24 | (finalHmac[offset + 1] & 0xff) << 16 |
99 (finalHmac[offset + 2] & 0xff) << 8 | (finalHmac[offset + 3] & 0xff);
100
101 std::stringstream ss;
102 const int digitsPow = pow(10, digits);
103 ss << std::setw(digits) << std::setfill('0') << truncated % digitsPow;
104 return ss.str();
105 }
106
// Full `otpauth://` key URI. Group 1 captures the OTP type ("totp" or "hotp").
// Group 2 captures the whole query string, which must be one or more
// `key=value` pairs joined by '&'. The label between the type and '?' must be
// non-empty but is not captured.
constexpr const char* kURIPattern =
    R"re(otpauth://(totp|hotp)/(?:[^\?]+)\?((?:(?:[^=]+)=(?:[^\&]+)\&?)+))re";
const std::regex kURIRegex(kURIPattern, std::regex::ECMAScript | std::regex::optimize);

// One `key=value` component of the query string: group 1 is the key, group 2
// the value. An optional trailing '&' separator is consumed with the match.
constexpr const char* kQueryComponentPattern = R"re(([^=]+)=([^\&]+)&?)re";
const std::regex kQueryComponentRegex(kQueryComponentPattern,
                                      std::regex::ECMAScript | std::regex::optimize);
112
113 } // namespace
114
isTOTPURI(const std::string & uri)115 bool isTOTPURI(const std::string& uri) {
116 std::smatch match;
117 return std::regex_match(uri, match, kURIRegex);
118 }
119
calculateTOTP(const std::string & uri)120 std::string calculateTOTP(const std::string& uri) {
121 std::smatch match;
122 if (!std::regex_match(uri, match, kURIRegex)) {
123 throw std::runtime_error("Error parsing OTP URI");
124 }
125
126 const auto type = match[1].str();
127 if (type != "totp") {
128 std::stringstream ss;
129 ss << "Unsupported OTP format: " << type;
130 throw std::runtime_error(ss.str());
131 }
132
133 const auto query = match[2].str();
134 auto begin = std::sregex_iterator(query.begin(), query.end(), kQueryComponentRegex);
135 const auto end = std::sregex_iterator();
136 std::unordered_map<std::string, std::string> params;
137 for (auto it = begin; it != end; ++it) {
138 auto curParam = *it;
139 params.emplace(curParam[1], curParam[2]);
140 }
141
142 if (params.find("secret") == params.end()) {
143 throw std::runtime_error("OTP URI is missing a key");
144 }
145
146 const EVP_MD* algo = EVP_sha1();
147 if (params.find("algorithm") != params.end()) {
148 const auto& algoStr = params["algorithm"];
149 if (algoStr == "SHA256") {
150 algo = EVP_sha256();
151 } else if (algoStr == "SHA512") {
152 algo = EVP_sha512();
153 } else if (algoStr != "SHA1") {
154 std::stringstream ss;
155 ss << "Unsurpported algorithm in OTP URI: " << algoStr;
156 throw std::runtime_error(ss.str());
157 }
158 }
159
160 auto now = std::time(nullptr);
161 int period = 30;
162 if (params.find("period") != params.end()) {
163 period = std::stoi(params["period"]);
164 }
165
166 const auto key = base32Decode(params["secret"]);
167 uint64_t counter = now / period;
168
169 int digits = 6;
170 if (params.find("digits") != params.end()) {
171 digits = std::stoi(params["digits"]);
172 }
173
174 return calculateTOTPInternal(algo, counter, digits, key);
175 }
176