1 #ifndef SIMPLE_WEB_UTILITY_HPP
2 #define SIMPLE_WEB_UTILITY_HPP
3 
#include "status_code.hpp"
#include <algorithm>
#include <atomic>
#include <cctype>
#include <chrono>
#include <cstdlib>
#include <ctime>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
14 
// Select the type behind SimpleWeb::string_view:
//   - std::string_view when the language level is newer than C++14,
//   - const std::string & when USE_STANDALONE_ASIO is defined (the Boost header is not included then),
//   - boost::string_ref otherwise.
#if(defined(__cplusplus) && __cplusplus > 201402L) || (defined(_MSVC_LANG) && _MSVC_LANG > 201402L)
#include <string_view>
namespace SimpleWeb {
  using string_view = std::string_view;
}
#elif defined(USE_STANDALONE_ASIO)
namespace SimpleWeb {
  using string_view = const std::string &;
}
#else
#include <boost/utility/string_ref.hpp>
namespace SimpleWeb {
  using string_view = boost::string_ref;
}
#endif
30 
31 namespace SimpleWeb {
case_insensitive_equal(const std::string & str1,const std::string & str2)32   inline bool case_insensitive_equal(const std::string &str1, const std::string &str2) noexcept {
33     return str1.size() == str2.size() &&
34            std::equal(str1.begin(), str1.end(), str2.begin(), [](char a, char b) {
35              return tolower(a) == tolower(b);
36            });
37   }
38   class CaseInsensitiveEqual {
39   public:
operator ()(const std::string & str1,const std::string & str2) const40     bool operator()(const std::string &str1, const std::string &str2) const noexcept {
41       return case_insensitive_equal(str1, str2);
42     }
43   };
44   // Based on https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x/2595226#2595226
45   class CaseInsensitiveHash {
46   public:
operator ()(const std::string & str) const47     std::size_t operator()(const std::string &str) const noexcept {
48       std::size_t h = 0;
49       std::hash<int> hash;
50       for(auto c : str)
51         h ^= hash(tolower(c)) + 0x9e3779b9 + (h << 6) + (h >> 2);
52       return h;
53     }
54   };
55 
  /// Unordered multimap from string keys to string values whose key hashing and key
  /// comparison are performed by CaseInsensitiveHash and CaseInsensitiveEqual.
  using CaseInsensitiveMultimap = std::unordered_multimap<std::string, std::string, CaseInsensitiveHash, CaseInsensitiveEqual>;
57 
58   /// Percent encoding and decoding
59   class Percent {
60   public:
61     /// Returns percent-encoded string
encode(const std::string & value)62     static std::string encode(const std::string &value) noexcept {
63       static auto hex_chars = "0123456789ABCDEF";
64 
65       std::string result;
66       result.reserve(value.size()); // Minimum size of result
67 
68       for(auto &chr : value) {
69         if(!((chr >= '0' && chr <= '9') || (chr >= 'A' && chr <= 'Z') || (chr >= 'a' && chr <= 'z') || chr == '-' || chr == '.' || chr == '_' || chr == '~'))
70           result += std::string("%") + hex_chars[static_cast<unsigned char>(chr) >> 4] + hex_chars[static_cast<unsigned char>(chr) & 15];
71         else
72           result += chr;
73       }
74 
75       return result;
76     }
77 
78     /// Returns percent-decoded string
decode(const std::string & value)79     static std::string decode(const std::string &value) noexcept {
80       std::string result;
81       result.reserve(value.size() / 3 + (value.size() % 3)); // Minimum size of result
82 
83       for(std::size_t i = 0; i < value.size(); ++i) {
84         auto &chr = value[i];
85         if(chr == '%' && i + 2 < value.size()) {
86           auto hex = value.substr(i + 1, 2);
87           auto decoded_chr = static_cast<char>(std::strtol(hex.c_str(), nullptr, 16));
88           result += decoded_chr;
89           i += 2;
90         }
91         else if(chr == '+')
92           result += ' ';
93         else
94           result += chr;
95       }
96 
97       return result;
98     }
99   };
100 
101   /// Query string creation and parsing
102   class QueryString {
103   public:
104     /// Returns query string created from given field names and values
create(const CaseInsensitiveMultimap & fields)105     static std::string create(const CaseInsensitiveMultimap &fields) noexcept {
106       std::string result;
107 
108       bool first = true;
109       for(auto &field : fields) {
110         result += (!first ? "&" : "") + field.first + '=' + Percent::encode(field.second);
111         first = false;
112       }
113 
114       return result;
115     }
116 
117     /// Returns query keys with percent-decoded values.
parse(const std::string & query_string)118     static CaseInsensitiveMultimap parse(const std::string &query_string) noexcept {
119       CaseInsensitiveMultimap result;
120 
121       if(query_string.empty())
122         return result;
123 
124       std::size_t name_pos = 0;
125       auto name_end_pos = std::string::npos;
126       auto value_pos = std::string::npos;
127       for(std::size_t c = 0; c < query_string.size(); ++c) {
128         if(query_string[c] == '&') {
129           auto name = query_string.substr(name_pos, (name_end_pos == std::string::npos ? c : name_end_pos) - name_pos);
130           if(!name.empty()) {
131             auto value = value_pos == std::string::npos ? std::string() : query_string.substr(value_pos, c - value_pos);
132             result.emplace(std::move(name), Percent::decode(value));
133           }
134           name_pos = c + 1;
135           name_end_pos = std::string::npos;
136           value_pos = std::string::npos;
137         }
138         else if(query_string[c] == '=') {
139           name_end_pos = c;
140           value_pos = c + 1;
141         }
142       }
143       if(name_pos < query_string.size()) {
144         auto name = query_string.substr(name_pos, name_end_pos - name_pos);
145         if(!name.empty()) {
146           auto value = value_pos >= query_string.size() ? std::string() : query_string.substr(value_pos);
147           result.emplace(std::move(name), Percent::decode(value));
148         }
149       }
150 
151       return result;
152     }
153   };
154 
155   class HttpHeader {
156   public:
157     /// Parse header fields from stream
parse(std::istream & stream)158     static CaseInsensitiveMultimap parse(std::istream &stream) noexcept {
159       CaseInsensitiveMultimap result;
160       std::string line;
161       std::size_t param_end;
162       while(getline(stream, line) && (param_end = line.find(':')) != std::string::npos) {
163         std::size_t value_start = param_end + 1;
164         while(value_start + 1 < line.size() && line[value_start] == ' ')
165           ++value_start;
166         if(value_start < line.size())
167           result.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - (line.back() == '\r' ? 1 : 0)));
168       }
169       return result;
170     }
171 
172     class FieldValue {
173     public:
174       class SemicolonSeparatedAttributes {
175       public:
176         /// Parse Set-Cookie or Content-Disposition from given header field value.
177         /// Attribute values are percent-decoded.
parse(const std::string & value)178         static CaseInsensitiveMultimap parse(const std::string &value) {
179           CaseInsensitiveMultimap result;
180 
181           std::size_t name_start_pos = std::string::npos;
182           std::size_t name_end_pos = std::string::npos;
183           std::size_t value_start_pos = std::string::npos;
184           for(std::size_t c = 0; c < value.size(); ++c) {
185             if(name_start_pos == std::string::npos) {
186               if(value[c] != ' ' && value[c] != ';')
187                 name_start_pos = c;
188             }
189             else {
190               if(name_end_pos == std::string::npos) {
191                 if(value[c] == ';') {
192                   result.emplace(value.substr(name_start_pos, c - name_start_pos), std::string());
193                   name_start_pos = std::string::npos;
194                 }
195                 else if(value[c] == '=')
196                   name_end_pos = c;
197               }
198               else {
199                 if(value_start_pos == std::string::npos) {
200                   if(value[c] == '"' && c + 1 < value.size())
201                     value_start_pos = c + 1;
202                   else
203                     value_start_pos = c;
204                 }
205                 else if(value[c] == '"' || value[c] == ';') {
206                   result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos, c - value_start_pos)));
207                   name_start_pos = std::string::npos;
208                   name_end_pos = std::string::npos;
209                   value_start_pos = std::string::npos;
210                 }
211               }
212             }
213           }
214           if(name_start_pos != std::string::npos) {
215             if(name_end_pos == std::string::npos)
216               result.emplace(value.substr(name_start_pos), std::string());
217             else if(value_start_pos != std::string::npos) {
218               if(value.back() == '"')
219                 result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos, value.size() - 1)));
220               else
221                 result.emplace(value.substr(name_start_pos, name_end_pos - name_start_pos), Percent::decode(value.substr(value_start_pos)));
222             }
223           }
224 
225           return result;
226         }
227       };
228     };
229   };
230 
231   class RequestMessage {
232   public:
233     /** Parse request line and header fields from a request stream.
234      *
235      * @param[in]  stream       Stream to parse.
236      * @param[out] method       HTTP method.
237      * @param[out] path         Path from request URI.
238      * @param[out] query_string Query string from request URI.
239      * @param[out] version      HTTP version.
240      * @param[out] header       Header fields.
241      *
242      * @return True if stream is parsed successfully, false if not.
243      */
parse(std::istream & stream,std::string & method,std::string & path,std::string & query_string,std::string & version,CaseInsensitiveMultimap & header)244     static bool parse(std::istream &stream, std::string &method, std::string &path, std::string &query_string, std::string &version, CaseInsensitiveMultimap &header) noexcept {
245       std::string line;
246       std::size_t method_end;
247       if(getline(stream, line) && (method_end = line.find(' ')) != std::string::npos) {
248         method = line.substr(0, method_end);
249 
250         std::size_t query_start = std::string::npos;
251         std::size_t path_and_query_string_end = std::string::npos;
252         for(std::size_t i = method_end + 1; i < line.size(); ++i) {
253           if(line[i] == '?' && (i + 1) < line.size())
254             query_start = i + 1;
255           else if(line[i] == ' ') {
256             path_and_query_string_end = i;
257             break;
258           }
259         }
260         if(path_and_query_string_end != std::string::npos) {
261           if(query_start != std::string::npos) {
262             path = line.substr(method_end + 1, query_start - method_end - 2);
263             query_string = line.substr(query_start, path_and_query_string_end - query_start);
264           }
265           else
266             path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1);
267 
268           std::size_t protocol_end;
269           if((protocol_end = line.find('/', path_and_query_string_end + 1)) != std::string::npos) {
270             if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0)
271               return false;
272             version = line.substr(protocol_end + 1, line.size() - protocol_end - 2);
273           }
274           else
275             return false;
276 
277           header = HttpHeader::parse(stream);
278         }
279         else
280           return false;
281       }
282       else
283         return false;
284       return true;
285     }
286   };
287 
288   class ResponseMessage {
289   public:
290     /** Parse status line and header fields from a response stream.
291      *
292      * @param[in]  stream      Stream to parse.
293      * @param[out] version     HTTP version.
294      * @param[out] status_code HTTP status code.
295      * @param[out] header      Header fields.
296      *
297      * @return True if stream is parsed successfully, false if not.
298      */
parse(std::istream & stream,std::string & version,std::string & status_code,CaseInsensitiveMultimap & header)299     static bool parse(std::istream &stream, std::string &version, std::string &status_code, CaseInsensitiveMultimap &header) noexcept {
300       std::string line;
301       std::size_t version_end;
302       if(getline(stream, line) && (version_end = line.find(' ')) != std::string::npos) {
303         if(5 < line.size())
304           version = line.substr(5, version_end - 5);
305         else
306           return false;
307         if((version_end + 1) < line.size())
308           status_code = line.substr(version_end + 1, line.size() - (version_end + 1) - (line.back() == '\r' ? 1 : 0));
309         else
310           return false;
311 
312         header = HttpHeader::parse(stream);
313       }
314       else
315         return false;
316       return true;
317     }
318   };
319 
320   /// Date class working with formats specified in RFC 7231 Date/Time Formats
321   class Date {
322   public:
323     /// Returns the given std::chrono::system_clock::time_point as a string with the following format: Wed, 31 Jul 2019 11:34:23 GMT.
324     /// Warning: while this function is thread safe with other Date::to_string() calls,
325     /// it is not thread safe with other functions that include calls to std::gmtime.
to_string(const std::chrono::system_clock::time_point time_point)326     static std::string to_string(const std::chrono::system_clock::time_point time_point) noexcept {
327       static std::string result_cache;
328       static std::chrono::system_clock::time_point last_time_point;
329 
330       static std::mutex mutex;
331       std::lock_guard<std::mutex> lock(mutex);
332 
333       if(std::chrono::duration_cast<std::chrono::seconds>(time_point - last_time_point).count() == 0 && !result_cache.empty())
334         return result_cache;
335 
336       last_time_point = time_point;
337 
338       std::string result;
339       result.reserve(29);
340 
341       auto time = std::chrono::system_clock::to_time_t(time_point);
342       auto gmtime = std::gmtime(&time);
343 
344       switch(gmtime->tm_wday) {
345       case 0: result += "Sun, "; break;
346       case 1: result += "Mon, "; break;
347       case 2: result += "Tue, "; break;
348       case 3: result += "Wed, "; break;
349       case 4: result += "Thu, "; break;
350       case 5: result += "Fri, "; break;
351       case 6: result += "Sat, "; break;
352       }
353 
354       result += gmtime->tm_mday < 10 ? '0' : static_cast<char>(gmtime->tm_mday / 10 + 48);
355       result += static_cast<char>(gmtime->tm_mday % 10 + 48);
356 
357       switch(gmtime->tm_mon) {
358       case 0: result += " Jan "; break;
359       case 1: result += " Feb "; break;
360       case 2: result += " Mar "; break;
361       case 3: result += " Apr "; break;
362       case 4: result += " May "; break;
363       case 5: result += " Jun "; break;
364       case 6: result += " Jul "; break;
365       case 7: result += " Aug "; break;
366       case 8: result += " Sep "; break;
367       case 9: result += " Oct "; break;
368       case 10: result += " Nov "; break;
369       case 11: result += " Dec "; break;
370       }
371 
372       auto year = gmtime->tm_year + 1900;
373       result += static_cast<char>(year / 1000 + 48);
374       result += static_cast<char>((year / 100) % 10 + 48);
375       result += static_cast<char>((year / 10) % 10 + 48);
376       result += static_cast<char>(year % 10 + 48);
377       result += ' ';
378 
379       result += gmtime->tm_hour < 10 ? '0' : static_cast<char>(gmtime->tm_hour / 10 + 48);
380       result += static_cast<char>(gmtime->tm_hour % 10 + 48);
381       result += ':';
382 
383       result += gmtime->tm_min < 10 ? '0' : static_cast<char>(gmtime->tm_min / 10 + 48);
384       result += static_cast<char>(gmtime->tm_min % 10 + 48);
385       result += ':';
386 
387       result += gmtime->tm_sec < 10 ? '0' : static_cast<char>(gmtime->tm_sec / 10 + 48);
388       result += static_cast<char>(gmtime->tm_sec % 10 + 48);
389 
390       result += " GMT";
391 
392       result_cache = result;
393       return result;
394     }
395   };
396 } // namespace SimpleWeb
397 
// Pull in the header that declares _mm_pause() on the platforms where it is used below.
#if defined(__SSE2__)
#include <emmintrin.h>
// TODO: need verification that the following checks are correct:
#elif defined(_MSC_VER) && _MSC_VER >= 1800 && (defined(_M_X64) || defined(_M_IX86))
#include <intrin.h>
#endif

namespace SimpleWeb {
  /// Called between the iterations of a spin-wait loop.
  /// Executes _mm_pause() when one of the platform checks above matched, and does nothing otherwise.
  inline void spin_loop_pause() noexcept {
#if defined(__SSE2__) || (defined(_MSC_VER) && _MSC_VER >= 1800 && (defined(_M_X64) || defined(_M_IX86)))
    _mm_pause();
#endif
  }
} // namespace SimpleWeb
414 
415 namespace SimpleWeb {
416   /// Makes it possible to for instance cancel Asio handlers without stopping asio::io_service.
417   class ScopeRunner {
418     /// Scope count that is set to -1 if scopes are to be canceled.
419     std::atomic<long> count;
420 
421   public:
422     class SharedLock {
423       friend class ScopeRunner;
424       std::atomic<long> &count;
SharedLock(std::atomic<long> & count)425       SharedLock(std::atomic<long> &count) noexcept : count(count) {}
426       SharedLock &operator=(const SharedLock &) = delete;
427       SharedLock(const SharedLock &) = delete;
428 
429     public:
~SharedLock()430       ~SharedLock() noexcept {
431         count.fetch_sub(1);
432       }
433     };
434 
ScopeRunner()435     ScopeRunner() noexcept : count(0) {}
436 
437     /// Returns nullptr if scope should be exited, or a shared lock otherwise.
438     /// The shared lock ensures that a potential destructor call is delayed until all locks are released.
continue_lock()439     std::unique_ptr<SharedLock> continue_lock() noexcept {
440       long expected = count;
441       while(expected >= 0 && !count.compare_exchange_weak(expected, expected + 1))
442         spin_loop_pause();
443 
444       if(expected < 0)
445         return nullptr;
446       else
447         return std::unique_ptr<SharedLock>(new SharedLock(count));
448     }
449 
450     /// Blocks until all shared locks are released, then prevents future shared locks.
stop()451     void stop() noexcept {
452       long expected = 0;
453       while(!count.compare_exchange_weak(expected, -1)) {
454         if(expected < 0)
455           return;
456         expected = 0;
457         spin_loop_pause();
458       }
459     }
460   };
461 } // namespace SimpleWeb
462 
463 #endif // SIMPLE_WEB_UTILITY_HPP
464