#ifndef SIMPLE_WEB_UTILITY_HPP
#define SIMPLE_WEB_UTILITY_HPP

// Nothing in this header references status_code.hpp; it is included so that
// files including this header also get it. When the compiler can tell that the
// file is not reachable, skip it so this header can still be used on its own.
#if defined(__has_include)
#if __has_include("status_code.hpp")
#include "status_code.hpp"
#endif
#else
#include "status_code.hpp"
#endif
#include <algorithm>
#include <atomic>
#include <cctype>
#include <chrono>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

#if __cplusplus > 201402L || _MSVC_LANG > 201402L
#include <string_view>
namespace SimpleWeb {
  using string_view = std::string_view;
}
#elif !defined(USE_STANDALONE_ASIO)
#include <boost/utility/string_ref.hpp>
namespace SimpleWeb {
  using string_view = boost::string_ref;
}
#else
namespace SimpleWeb {
  using string_view = const std::string &;
}
#endif

namespace SimpleWeb {
  /// Returns true if str1 and str2 have the same length and compare equal
  /// character by character after std::tolower.
  inline bool case_insensitive_equal(const std::string &str1, const std::string &str2) noexcept {
    return str1.size() == str2.size() &&
           std::equal(str1.begin(), str1.end(), str2.begin(), [](char a, char b) {
             // std::tolower requires a value representable as unsigned char (or EOF);
             // passing a negative char is undefined behaviour.
             return std::tolower(static_cast<unsigned char>(a)) == std::tolower(static_cast<unsigned char>(b));
           });
  }

  /// Equality functor wrapping case_insensitive_equal(), used as the key-equal of CaseInsensitiveMultimap.
  class CaseInsensitiveEqual {
  public:
    bool operator()(const std::string &str1, const std::string &str2) const noexcept {
      return case_insensitive_equal(str1, str2);
    }
  };

  /// Hash functor that lower-cases every character before combining, so that keys
  /// that are equal under CaseInsensitiveEqual get the same hash.
  // Based on https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x/2595226#2595226
  class CaseInsensitiveHash {
  public:
    std::size_t operator()(const std::string &str) const noexcept {
      std::size_t h = 0;
      std::hash<int> hash;
      for(auto c : str)
        h ^= hash(std::tolower(static_cast<unsigned char>(c))) + 0x9e3779b9 + (h << 6) + (h >> 2);
      return h;
    }
  };

  using CaseInsensitiveMultimap = std::unordered_multimap<std::string, std::string, CaseInsensitiveHash, CaseInsensitiveEqual>;

  /// Percent encoding and decoding
  class Percent {
    /// Returns the numeric value of a hexadecimal digit, or -1 if chr is not one.
    static int hex_value(char chr) noexcept {
      if(chr >= '0' && chr <= '9')
        return chr - '0';
      if(chr >= 'A' && chr <= 'F')
        return chr - 'A' + 10;
      if(chr >= 'a' && chr <= 'f')
        return chr - 'a' + 10;
      return -1;
    }

  public:
    /// Returns percent-encoded string.
    /// ASCII letters, digits and '-', '.', '_', '~' are copied unchanged; every other
    /// byte is written as '%' followed by two upper-case hexadecimal digits.
    static std::string encode(const std::string &value) noexcept {
      static const char *const hex_chars = "0123456789ABCDEF";

      std::string result;
      result.reserve(value.size()); // Minimum size of result

      for(const char chr : value) {
        const bool unreserved = (chr >= '0' && chr <= '9') || (chr >= 'A' && chr <= 'Z') || (chr >= 'a' && chr <= 'z') ||
                                chr == '-' || chr == '.' || chr == '_' || chr == '~';
        if(unreserved)
          result += chr;
        else {
          const auto byte = static_cast<unsigned char>(chr);
          result += '%';
          result += hex_chars[byte >> 4];
          result += hex_chars[byte & 15];
        }
      }

      return result;
    }

    /// Returns percent-decoded string.
    /// "%XY" with two hexadecimal digits becomes the corresponding byte and '+' becomes a space.
    /// A '%' that is not followed by two hexadecimal digits is copied through unchanged.
    static std::string decode(const std::string &value) noexcept {
      std::string result;
      result.reserve(value.size() / 3 + (value.size() % 3)); // Minimum size of result

      for(std::size_t i = 0; i < value.size(); ++i) {
        const char chr = value[i];
        if(chr == '%' && i + 2 < value.size()) {
          const int high = hex_value(value[i + 1]);
          const int low = hex_value(value[i + 2]);
          if(high >= 0 && low >= 0) {
            result += static_cast<char>((high << 4) | low);
            i += 2;
          }
          else
            result += chr; // Malformed escape: keep the '%' instead of emitting a bogus byte
        }
        else if(chr == '+')
          result += ' ';
        else
          result += chr;
      }

      return result;
    }
  };

  /// Query string creation and parsing
  class QueryString {
  public:
    /// Returns query string created from given field names and values.
    /// Fields are joined with '&'; values are percent-encoded, names are written as given.
    static std::string create(const CaseInsensitiveMultimap &fields) noexcept {
      std::string result;

      bool first = true;
      for(const auto &field : fields) {
        if(!first)
          result += '&';
        result += field.first;
        result += '=';
        result += Percent::encode(field.second);
        first = false;
      }

      return result;
    }

    /// Returns query keys with percent-decoded values.
    /// Pairs are separated by '&'; within a pair the name ends at the last '=' seen.
    /// Pairs with an empty name are skipped, and a name without '=' gets an empty value.
    static CaseInsensitiveMultimap parse(const std::string &query_string) noexcept {
      CaseInsensitiveMultimap result;

      if(query_string.empty())
        return result;

      std::size_t name_pos = 0;
      auto name_end_pos = std::string::npos;
      auto value_pos = std::string::npos;
      for(std::size_t c = 0; c < query_string.size(); ++c) {
        if(query_string[c] == '&') {
          // End of a pair: the name runs to the '=' if there was one, otherwise to this '&'
          auto name = query_string.substr(name_pos, (name_end_pos == std::string::npos ? c : name_end_pos) - name_pos);
          if(!name.empty()) {
            auto value = value_pos == std::string::npos ? std::string() : query_string.substr(value_pos, c - value_pos);
            result.emplace(std::move(name), Percent::decode(value));
          }
          name_pos = c + 1;
          name_end_pos = std::string::npos;
          value_pos = std::string::npos;
        }
        else if(query_string[c] == '=') {
          name_end_pos = c;
          value_pos = c + 1;
        }
      }
      // Last pair, which is not terminated by '&'
      if(name_pos < query_string.size()) {
        // If no '=' was seen the length wraps around and substr() clamps it to the end of the string
        auto name = query_string.substr(name_pos, name_end_pos - name_pos);
        if(!name.empty()) {
          auto value = value_pos >= query_string.size() ? std::string() : query_string.substr(value_pos);
          result.emplace(std::move(name), Percent::decode(value));
        }
      }

      return result;
    }
  };

  class HttpHeader {
  public:
    /// Parse header fields from stream.
    /// Reads lines until the stream fails or a line without ':' is found (that line is consumed).
    /// Spaces after the ':' are skipped and one trailing '\r' is removed from the value.
    static CaseInsensitiveMultimap parse(std::istream &stream) noexcept {
      CaseInsensitiveMultimap result;
      std::string line;
      std::size_t param_end;
      while(getline(stream, line) && (param_end = line.find(':')) != std::string::npos) {
        std::size_t value_start = param_end + 1;
        while(value_start + 1 < line.size() && line[value_start] == ' ')
          ++value_start;
        if(value_start < line.size())
          result.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - (line.back() == '\r' ? 1 : 0)));
      }
      return result;
    }

    class FieldValue {
    public:
      class SemicolonSeparatedAttributes {
      public:
        /// Parse Set-Cookie or Content-Disposition from given header field value.
        /// Attribute values are percent-decoded.
        /// Attributes without '=' are stored with an empty value. A value may be wrapped in
        /// double quotes; a value ends at the next '"' or ';'.
        static CaseInsensitiveMultimap parse(const std::string &value) {
          CaseInsensitiveMultimap result;

          std::size_t name_start_pos = std::string::npos;
          std::size_t name_end_pos = std::string::npos;
          std::size_t value_start_pos = std::string::npos;
          for(std::size_t c = 0; c < value.size(); ++c) {
            if(name_start_pos == std::string::npos) {
              // Between attributes: skip separators until a name starts
              if(value[c] != ' ' && value[c] != ';')
                name_start_pos = c;
            }
            else {
              if(name_end_pos == std::string::npos) {
                // Inside a name
                if(value[c] == ';') {
                  result.emplace(value.substr(name_start_pos, c - name_start_pos), std::string());
                  name_start_pos = std::string::npos;
                }
                else if(value[c] == '=')
                  name_end_pos = c;
              }
              else {
                if(value_start_pos == std::string::npos) {
                  // First character after '=': step over an opening quote if more input follows
                  if(value[c] == '"' && c + 1 < value.size())
                    value_start_pos = c + 1;
                  else
                    value_start_pos = c;
                }
                else if(value[c] == '"' || value[c] == ';') {
                  result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos, c - value_start_pos)));
                  name_start_pos = std::string::npos;
                  name_end_pos = std::string::npos;
                  value_start_pos = std::string::npos;
                }
              }
            }
          }
          // Attribute still open when the input ended
          if(name_start_pos != std::string::npos) {
            if(name_end_pos == std::string::npos)
              result.emplace(value.substr(name_start_pos), std::string());
            else if(value_start_pos != std::string::npos) {
              if(value.back() == '"')
                // Drop the trailing quote: substr() takes a length, not an end position
                result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos, value.size() - 1 - value_start_pos)));
              else
                result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos)));
            }
          }

          return result;
        }
      };
    };
  };

  class RequestMessage {
  public:
    /** Parse request line and header fields from a request stream.
     *
     * The request line is expected as "<method> <path>[?<query>] HTTP/<version>".
     * The query string starts after the first '?' of the request target. A trailing
     * '\r' on the request line is not made part of the version.
     *
     * @param[in]  stream       Stream to parse.
     * @param[out] method       HTTP method.
     * @param[out] path         Path from request URI.
     * @param[out] query_string Query string from request URI. Left untouched if the URI has no query.
     * @param[out] version      HTTP version.
     * @param[out] header       Header fields.
     *
     * @return True if stream is parsed successfully, false if not.
     *         On failure, output parameters parsed before the error may already have been written.
     */
    static bool parse(std::istream &stream, std::string &method, std::string &path, std::string &query_string, std::string &version, CaseInsensitiveMultimap &header) noexcept {
      std::string line;
      std::size_t method_end;
      if(!getline(stream, line) || (method_end = line.find(' ')) == std::string::npos)
        return false;
      method = line.substr(0, method_end);

      std::size_t query_start = std::string::npos;
      std::size_t path_and_query_string_end = std::string::npos;
      for(std::size_t i = method_end + 1; i < line.size(); ++i) {
        // Only the first '?' separates path and query; later ones belong to the query
        if(line[i] == '?' && (i + 1) < line.size() && query_start == std::string::npos)
          query_start = i + 1;
        else if(line[i] == ' ') {
          path_and_query_string_end = i;
          break;
        }
      }
      if(path_and_query_string_end == std::string::npos)
        return false;

      if(query_start != std::string::npos) {
        // query_start is one past the '?', hence the extra -1 to exclude the '?' from the path
        path = line.substr(method_end + 1, query_start - method_end - 2);
        query_string = line.substr(query_start, path_and_query_string_end - query_start);
      }
      else
        path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1);

      const std::size_t protocol_end = line.find('/', path_and_query_string_end + 1);
      if(protocol_end == std::string::npos)
        return false;
      if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0)
        return false;
      // Strip the '\r' left by getline() only when it is actually there
      version = line.substr(protocol_end + 1, line.size() - protocol_end - 1 - (line.back() == '\r' ? 1 : 0));

      header = HttpHeader::parse(stream);
      return true;
    }
  };

  class ResponseMessage {
  public:
    /** Parse status line and header fields from a response stream.
     *
     * The version is read starting at offset 5 of the status line, up to the first space;
     * the 5 leading characters themselves are not validated. The status code is the rest
     * of the line after that space, without a trailing '\r'.
     *
     * @param[in]  stream      Stream to parse.
     * @param[out] version     HTTP version.
     * @param[out] status_code HTTP status code.
     * @param[out] header      Header fields.
     *
     * @return True if stream is parsed successfully, false if not.
     */
    static bool parse(std::istream &stream, std::string &version, std::string &status_code, CaseInsensitiveMultimap &header) noexcept {
      std::string line;
      std::size_t version_end;
      if(!getline(stream, line) || (version_end = line.find(' ')) == std::string::npos)
        return false;

      if(5 < line.size())
        version = line.substr(5, version_end - 5);
      else
        return false;

      if((version_end + 1) < line.size())
        status_code = line.substr(version_end + 1, line.size() - (version_end + 1) - (line.back() == '\r' ? 1 : 0));
      else
        return false;

      header = HttpHeader::parse(stream);
      return true;
    }
  };

  /// Date class working with formats specified in RFC 7231 Date/Time Formats
  class Date {
  public:
    /// Returns the given std::chrono::system_clock::time_point as a string with the following format: Wed, 31 Jul 2019 11:34:23 GMT.
    /// The string of the most recently formatted second is cached and reused for time points within that same second.
    /// Returns an empty string if std::gmtime cannot convert the time point.
    /// Warning: while this function is thread safe with other Date::to_string() calls,
    /// it is not thread safe with other functions that include calls to std::gmtime.
    static std::string to_string(const std::chrono::system_clock::time_point time_point) noexcept {
      static std::string result_cache;
      static std::time_t last_time = 0;

      static std::mutex mutex;
      std::lock_guard<std::mutex> lock(mutex);

      // Key the cache on the calendar second itself. Comparing a truncated difference of
      // two time points would treat e.g. 09.9 s and 10.1 s as the same second.
      const auto time = std::chrono::system_clock::to_time_t(time_point);
      if(time == last_time && !result_cache.empty())
        return result_cache;

      const auto gmtime = std::gmtime(&time);
      if(!gmtime)
        return std::string();

      static const char *const weekdays[] = {"Sun, ", "Mon, ", "Tue, ", "Wed, ", "Thu, ", "Fri, ", "Sat, "};
      static const char *const months[] = {" Jan ", " Feb ", " Mar ", " Apr ", " May ", " Jun ",
                                           " Jul ", " Aug ", " Sep ", " Oct ", " Nov ", " Dec "};

      std::string result;
      result.reserve(29);

      const auto append_two_digits = [&result](int number) {
        result += static_cast<char>('0' + number / 10);
        result += static_cast<char>('0' + number % 10);
      };

      if(gmtime->tm_wday >= 0 && gmtime->tm_wday <= 6)
        result += weekdays[gmtime->tm_wday];

      append_two_digits(gmtime->tm_mday);

      if(gmtime->tm_mon >= 0 && gmtime->tm_mon <= 11)
        result += months[gmtime->tm_mon];

      const auto year = gmtime->tm_year + 1900; // tm_year counts years since 1900
      result += static_cast<char>('0' + year / 1000);
      result += static_cast<char>('0' + (year / 100) % 10);
      result += static_cast<char>('0' + (year / 10) % 10);
      result += static_cast<char>('0' + year % 10);
      result += ' ';

      append_two_digits(gmtime->tm_hour);
      result += ':';
      append_two_digits(gmtime->tm_min);
      result += ':';
      append_two_digits(gmtime->tm_sec);

      result += " GMT";

      last_time = time;
      result_cache = result;
      return result;
    }
  };
} // namespace SimpleWeb

#ifdef __SSE2__
#include <emmintrin.h>
namespace SimpleWeb {
  /// Hint issued inside spin-wait loops; maps to _mm_pause() where available, otherwise does nothing.
  inline void spin_loop_pause() noexcept { _mm_pause(); }
} // namespace SimpleWeb
// TODO: need verification that the following checks are correct:
#elif defined(_MSC_VER) && _MSC_VER >= 1800 && (defined(_M_X64) || defined(_M_IX86))
#include <intrin.h>
namespace SimpleWeb {
  /// Hint issued inside spin-wait loops; maps to _mm_pause() where available, otherwise does nothing.
  inline void spin_loop_pause() noexcept { _mm_pause(); }
} // namespace SimpleWeb
#else
namespace SimpleWeb {
  /// Hint issued inside spin-wait loops; maps to _mm_pause() where available, otherwise does nothing.
  inline void spin_loop_pause() noexcept {}
} // namespace SimpleWeb
#endif

namespace SimpleWeb {
  /// Makes it possible to for instance cancel Asio handlers without stopping asio::io_service.
  class ScopeRunner {
    /// Scope count that is set to -1 if scopes are to be canceled.
    std::atomic<long> count;

  public:
    /// Handle for one running scope; its destructor decrements the scope count again.
    /// Only ScopeRunner can create it, and it cannot be copied.
    class SharedLock {
      friend class ScopeRunner;
      std::atomic<long> &count;
      SharedLock(std::atomic<long> &count) noexcept : count(count) {}
      SharedLock &operator=(const SharedLock &) = delete;
      SharedLock(const SharedLock &) = delete;

    public:
      ~SharedLock() noexcept {
        count.fetch_sub(1);
      }
    };

    ScopeRunner() noexcept : count(0) {}

    /// Returns nullptr if scope should be exited, or a shared lock otherwise.
    /// The shared lock ensures that a potential destructor call is delayed until all locks are released.
    std::unique_ptr<SharedLock> continue_lock() noexcept {
      long expected = count;
      // Increment the count unless it has gone negative; a failed exchange reloads expected
      while(expected >= 0 && !count.compare_exchange_weak(expected, expected + 1))
        spin_loop_pause();

      if(expected < 0)
        return nullptr;
      else
        // SharedLock's constructor is private, so it is created with new from within this friend class
        return std::unique_ptr<SharedLock>(new SharedLock(count));
    }

    /// Blocks until all shared locks are released, then prevents future shared locks.
    /// Returns immediately if the scopes have already been canceled.
    void stop() noexcept {
      long expected = 0;
      while(!count.compare_exchange_weak(expected, -1)) {
        if(expected < 0)
          return;
        expected = 0;
        spin_loop_pause();
      }
    }
  };
} // namespace SimpleWeb

#endif // SIMPLE_WEB_UTILITY_HPP