#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <functional>
#include <iomanip>
#include <iostream>
#include <memory>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
14 
15 namespace {
/**
 * Decodes a base32 string (alphabet A-Z and 2-7) into raw bytes.
 *
 * Letters are accepted in upper or lower case. Decoding stops at the first '=' padding
 * character, and trailing bits that do not make up a whole byte are discarded.
 *
 * @param encoded the base32 text to decode
 * @return the decoded bytes (empty for an empty input)
 * @throws std::runtime_error if a character outside the base32 alphabet is found before any
 *         padding
 */
std::vector<uint8_t> base32Decode(const std::string& encoded) {
    std::vector<uint8_t> ret;
    ret.reserve(encoded.size() * 5 / 8);

    unsigned int curByte = 0;
    int bits = 0;

    for (const char ch : encoded) {
        unsigned int value = 0;
        if (ch >= 'A' && ch <= 'Z') {
            value = static_cast<unsigned int>(ch - 'A');
        } else if (ch >= 'a' && ch <= 'z') {
            value = static_cast<unsigned int>(ch - 'a');
        } else if (ch >= '2' && ch <= '7') {
            value = static_cast<unsigned int>(ch - '2') + 26;
        } else if (ch == '=') {
            // Padding carries no data; everything after it is padding too.
            break;
        } else {
            throw std::runtime_error("Invalid character in base32 string");
        }

        // At most 7 bits are pending before this symbol, so 12 bits is all we ever need to keep.
        curByte = ((curByte << 5) | value) & 0xfff;
        bits += 5;

        if (bits >= 8) {
            ret.push_back(static_cast<uint8_t>((curByte >> (bits - 8)) & 255));
            bits -= 8;
        }
    }

    return ret;
}
41 
42 class HMACWrapper {
43 private:
__anon2bf57c1e0202(HMAC_CTX* ptr) 44     std::function<void(HMAC_CTX*)> hmacFreeFn = [](HMAC_CTX* ptr) {
45 #if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER)
46         abort();
47 #else
48         HMAC_CTX_free(ptr);
49 #endif
50     };
51 
52 public:
HMACWrapper(const EVP_MD * digest,const std::vector<uint8_t> & key)53     HMACWrapper(const EVP_MD* digest, const std::vector<uint8_t>& key) {
54 #if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER)
55         ctx = std::unique_ptr<HMAC_CTX>(new HMAC_CTX);
56         HMAC_CTX_init(ctx.get());
57 #else
58         ctx = std::unique_ptr<HMAC_CTX, decltype(hmacFreeFn)>(HMAC_CTX_new(), hmacFreeFn);
59 #endif
60         HMAC_Init_ex(ctx.get(), key.data(), key.size(), digest, nullptr);
61     }
62 
update(const std::vector<uint8_t> & s)63     void update(const std::vector<uint8_t>& s) {
64         HMAC_Update(ctx.get(),
65                     reinterpret_cast<const unsigned char*>(s.data()),
66                     static_cast<unsigned int>(s.size()));
67     }
68 
finalize()69     std::vector<uint8_t> finalize() {
70         std::vector<uint8_t> finalHash(HMAC_size(ctx.get()));
71         unsigned int size;
72         HMAC_Final(ctx.get(), finalHash.data(), &size);
73         if (size != finalHash.size()) {
74             throw std::runtime_error("Overflow while finalizing HMAC");
75         }
76         return std::move(finalHash);
77     }
78 
79 private:
80     std::unique_ptr<HMAC_CTX, decltype(hmacFreeFn)> ctx;
81 };
82 
calculateTOTPInternal(const EVP_MD * algo,uint64_t counter,int digits,const std::vector<uint8_t> & key)83 std::string calculateTOTPInternal(const EVP_MD* algo,
84                                   uint64_t counter,
85                                   int digits,
86                                   const std::vector<uint8_t>& key) {
87     std::vector<uint8_t> counterArr(8);
88     for (int i = 7; i >= 0; i--) {
89         counterArr[i] = counter & 0xff;
90         counter >>= 8;
91     }
92 
93     HMACWrapper hmac(algo, key);
94     hmac.update(counterArr);
95     auto finalHmac = hmac.finalize();
96 
97     const auto offset = finalHmac[19] & 0xf;
98     uint32_t truncated = (finalHmac[offset] & 0x7f) << 24 | (finalHmac[offset + 1] & 0xff) << 16 |
99         (finalHmac[offset + 2] & 0xff) << 8 | (finalHmac[offset + 3] & 0xff);
100 
101     std::stringstream ss;
102     const int digitsPow = pow(10, digits);
103     ss << std::setw(digits) << std::setfill('0') << truncated % digitsPow;
104     return ss.str();
105 }
106 
// Pattern for a complete otpauth:// URI. Capture group 1 is the OTP type ("totp" or "hotp");
// capture group 2 is the query string following the '?', made of one or more key=value pairs.
const char* const kURIPattern =
    "otpauth://(totp|hotp)/(?:[^\\?]+)\\?((?:(?:[^=]+)=(?:[^\\&]+)\\&?)+)";
const std::regex kURIRegex(kURIPattern, std::regex::ECMAScript | std::regex::optimize);

// Pattern for one key=value pair of the query string, with an optional trailing '&'. Capture
// group 1 is the key and capture group 2 is the value.
const char* const kQueryComponentPattern = "([^=]+)=([^\\&]+)&?";
const std::regex kQueryComponentRegex(kQueryComponentPattern,
                                      std::regex::ECMAScript | std::regex::optimize);
112 
113 }  // namespace
114 
isTOTPURI(const std::string & uri)115 bool isTOTPURI(const std::string& uri) {
116     std::smatch match;
117     return std::regex_match(uri, match, kURIRegex);
118 }
119 
calculateTOTP(const std::string & uri)120 std::string calculateTOTP(const std::string& uri) {
121     std::smatch match;
122     if (!std::regex_match(uri, match, kURIRegex)) {
123         throw std::runtime_error("Error parsing OTP URI");
124     }
125 
126     const auto type = match[1].str();
127     if (type != "totp") {
128         std::stringstream ss;
129         ss << "Unsupported OTP format: " << type;
130         throw std::runtime_error(ss.str());
131     }
132 
133     const auto query = match[2].str();
134     auto begin = std::sregex_iterator(query.begin(), query.end(), kQueryComponentRegex);
135     const auto end = std::sregex_iterator();
136     std::unordered_map<std::string, std::string> params;
137     for (auto it = begin; it != end; ++it) {
138         auto curParam = *it;
139         params.emplace(curParam[1], curParam[2]);
140     }
141 
142     if (params.find("secret") == params.end()) {
143         throw std::runtime_error("OTP URI is missing a key");
144     }
145 
146     const EVP_MD* algo = EVP_sha1();
147     if (params.find("algorithm") != params.end()) {
148         const auto& algoStr = params["algorithm"];
149         if (algoStr == "SHA256") {
150             algo = EVP_sha256();
151         } else if (algoStr == "SHA512") {
152             algo = EVP_sha512();
153         } else if (algoStr != "SHA1") {
154             std::stringstream ss;
155             ss << "Unsurpported algorithm in OTP URI: " << algoStr;
156             throw std::runtime_error(ss.str());
157         }
158     }
159 
160     auto now = std::time(nullptr);
161     int period = 30;
162     if (params.find("period") != params.end()) {
163         period = std::stoi(params["period"]);
164     }
165 
166     const auto key = base32Decode(params["secret"]);
167     uint64_t counter = now / period;
168 
169     int digits = 6;
170     if (params.find("digits") != params.end()) {
171         digits = std::stoi(params["digits"]);
172     }
173 
174     return calculateTOTPInternal(algo, counter, digits, key);
175 }
176