1 //
2 //  httplib.h
3 //
4 //  Copyright (c) 2020 Yuji Hirose. All rights reserved.
5 //  MIT License
6 //
7 
8 #ifndef CPPHTTPLIB_HTTPLIB_H
9 #define CPPHTTPLIB_HTTPLIB_H
10 
11 /*
12  * Configuration
13  */
14 
// Every setting below is wrapped in #ifndef, so it can be overridden by
// defining the macro before this header is included.

// Keep-alive timeout, seconds part.
// NOTE(review): the code that consumes this value is outside this section.
#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND
#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5
#endif

// Keep-alive timeout, microseconds part (see note above).
#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND
#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0
#endif

// Initial value of Client::keep_alive_max_count_.
#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT
#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5
#endif

// Initial value of Client::read_timeout_sec_.
#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND
#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5
#endif

// Initial value of Client::read_timeout_usec_.
#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND
#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0
#endif

// Upper bound on the request URI length.
// NOTE(review): presumably in bytes; the check itself is not in this section.
#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192
#endif

// Initial value of Request::redirect_count.
#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT
#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20
#endif

// Payload size limit; the default is the largest size_t, i.e. effectively
// unlimited. The expression needs std::numeric_limits to be declared where the
// macro is expanded.
#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH
#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits<size_t>::max)())
#endif

// Receive buffer size as a size_t expression.
// NOTE(review): presumably the per-read socket buffer; confirm at the use site.
#ifndef CPPHTTPLIB_RECV_BUFSIZ
#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u)
#endif
50 
// Default number of worker threads: one fewer than the hardware thread count,
// but never less than one.
//
// std::thread::hardware_concurrency() is allowed to return 0 when the value is
// not computable. Subtracting 1 from that unsigned 0 would wrap around to
// UINT_MAX, so the subtraction is only performed for a non-zero result.
// (std::max) is parenthesised to stay immune to a function-like `max` macro.
#ifndef CPPHTTPLIB_THREAD_POOL_COUNT
#define CPPHTTPLIB_THREAD_POOL_COUNT                                           \
  ((std::max)(1u, std::thread::hardware_concurrency() > 0                      \
                      ? std::thread::hardware_concurrency() - 1                \
                      : 0u))
#endif
55 
56 /*
57  * Headers
58  */
59 
60 #ifdef _WIN32
61 #ifndef _CRT_SECURE_NO_WARNINGS
62 #define _CRT_SECURE_NO_WARNINGS
63 #endif //_CRT_SECURE_NO_WARNINGS
64 
65 #ifndef _CRT_NONSTDC_NO_DEPRECATE
66 #define _CRT_NONSTDC_NO_DEPRECATE
67 #endif //_CRT_NONSTDC_NO_DEPRECATE
68 
69 #if defined(_MSC_VER)
70 #ifdef _WIN64
71 using ssize_t = __int64;
72 #else
73 using ssize_t = int;
74 #endif
75 
76 #if _MSC_VER < 1900
77 #define snprintf _snprintf_s
78 #endif
79 #endif // _MSC_VER
80 
81 #ifndef S_ISREG
82 #define S_ISREG(m) (((m)&S_IFREG) == S_IFREG)
83 #endif // S_ISREG
84 
85 #ifndef S_ISDIR
86 #define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR)
87 #endif // S_ISDIR
88 
89 #ifndef NOMINMAX
90 #define NOMINMAX
91 #endif // NOMINMAX
92 
93 #include <io.h>
94 #include <winsock2.h>
95 #include <ws2tcpip.h>
96 
97 #ifndef WSA_FLAG_NO_HANDLE_INHERIT
98 #define WSA_FLAG_NO_HANDLE_INHERIT 0x80
99 #endif
100 
101 #ifdef _MSC_VER
102 #pragma comment(lib, "ws2_32.lib")
103 #endif
104 
105 #ifndef strcasecmp
106 #define strcasecmp _stricmp
107 #endif // strcasecmp
108 
109 using socket_t = SOCKET;
110 #ifdef CPPHTTPLIB_USE_POLL
111 #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout)
112 #endif
113 
114 #else // not _WIN32
115 
116 #include <sys/types.h>
117 #include <arpa/inet.h>
118 #include <cstring>
119 #include <ifaddrs.h>
120 #include <netdb.h>
121 #include <netinet/in.h>
122 #ifdef CPPHTTPLIB_USE_POLL
123 #include <poll.h>
124 #endif
125 #include <csignal>
126 #include <pthread.h>
127 #include <sys/select.h>
128 #include <sys/socket.h>
129 #include <unistd.h>
130 
131 using socket_t = int;
132 #define INVALID_SOCKET (-1)
133 #endif //_WIN32
134 
135 #include <array>
136 #include <atomic>
137 #include <cassert>
138 #include <climits>
139 #include <condition_variable>
140 #include <errno.h>
141 #include <fcntl.h>
142 #include <fstream>
143 #include <functional>
144 #include <list>
145 #include <map>
146 #include <memory>
147 #include <mutex>
148 #include <random>
149 #include <regex>
150 #include <string>
151 #include <sys/stat.h>
152 #include <thread>
153 
154 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
155 #include <openssl/err.h>
156 #include <openssl/md5.h>
157 #include <openssl/ssl.h>
158 #include <openssl/x509v3.h>
159 
160 #include <iomanip>
161 #include <iostream>
162 #include <sstream>
163 
164 // #if OPENSSL_VERSION_NUMBER < 0x1010100fL
165 // #error Sorry, OpenSSL versions prior to 1.1.1 are not supported
166 // #endif
167 
168 #if OPENSSL_VERSION_NUMBER < 0x10100000L
169 #include <openssl/crypto.h>
ASN1_STRING_get0_data(const ASN1_STRING * asn1)170 inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) {
171   return M_ASN1_STRING_data(asn1);
172 }
173 #endif
174 #endif
175 
176 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
177 #include <zlib.h>
178 #endif
179 /*
180  * Declaration
181  */
182 namespace httplib {
183 
184 namespace detail {
185 
// Case-insensitive "less than" for std::string, used as the ordering of the
// Headers multimap so that field names differing only in letter case compare
// equivalent.
struct ci {
  // Returns true when s1 orders before s2, comparing byte by byte after
  // lower-casing each byte.
  bool operator()(const std::string &s1, const std::string &s2) const {
    return std::lexicographical_compare(
        s1.begin(), s1.end(), s2.begin(), s2.end(), [](char c1, char c2) {
          // tolower() requires a value representable as unsigned char (or
          // EOF); passing a negative plain char is undefined behaviour, so
          // convert first.
          return ::tolower(static_cast<unsigned char>(c1)) <
                 ::tolower(static_cast<unsigned char>(c2));
        });
  }
};
193 
194 } // namespace detail
195 
// Header fields. A multimap, so one field name may occur several times; keys
// are ordered with detail::ci, i.e. compared without regard to letter case.
using Headers = std::multimap<std::string, std::string, detail::ci>;

// Name/value parameters. Duplicates are allowed; keys use the default
// (case-sensitive) std::string ordering.
using Params = std::multimap<std::string, std::string>;
// Regular-expression match results over a std::string.
using Match = std::smatch;

// Progress callback receiving (current, total).
// NOTE(review): presumably returning false cancels the transfer — confirm
// against the code that invokes it.
using Progress = std::function<bool(uint64_t current, uint64_t total)>;

struct Response;
// Callback that is handed a Response and returns bool.
// NOTE(review): the meaning of the return value is defined by the caller,
// which is outside this section.
using ResponseHandler = std::function<bool(const Response &response)>;
205 
// One part of a multipart form: plain aggregate of four strings.
struct MultipartFormData {
  std::string name;         // field name of the part
  std::string content;      // part body
  std::string filename;     // file name attached to the part, if any
  std::string content_type; // content type attached to the part, if any
};
// Ordered list of parts (used by the Client Post overloads that take items).
using MultipartFormDataItems = std::vector<MultipartFormData>;
// Parts indexed by a string key; duplicates allowed (type of Request::files).
using MultipartFormDataMap = std::multimap<std::string, MultipartFormData>;
214 
// Bundle of callbacks handed to a ContentProvider so it can emit data.
// The object is neither copyable nor movable, so it is always passed by
// reference.
class DataSink {
public:
  DataSink() = default;
  DataSink(const DataSink &) = delete;
  DataSink &operator=(const DataSink &) = delete;
  DataSink(DataSink &&) = delete;
  DataSink &operator=(DataSink &&) = delete;

  // Receives a chunk of data_len bytes starting at data.
  std::function<void(const char *data, size_t data_len)> write;
  // NOTE(review): presumably signals that the provider has finished — confirm
  // with the code that installs this callback.
  std::function<void()> done;
  // Returns a bool; by its name, whether more data can currently be written.
  std::function<bool()> is_writable;
};
227 
// Produces body data on demand: receives an offset, a length and the DataSink
// to write into.
using ContentProvider =
    std::function<void(size_t offset, size_t length, DataSink &sink)>;

// Consumes a chunk of data_length bytes starting at data and returns bool.
// NOTE(review): presumably false aborts the transfer — confirm at call sites.
using ContentReceiver =
    std::function<bool(const char *data, size_t data_length)>;

// Receives the MultipartFormData describing a multipart part and returns bool.
using MultipartContentHeader =
    std::function<bool(const MultipartFormData &file)>;
236 
237 class ContentReader {
238 public:
239   using Reader = std::function<bool(ContentReceiver receiver)>;
240   using MultipartReader = std::function<bool(MultipartContentHeader header,
241                                              ContentReceiver receiver)>;
242 
ContentReader(Reader reader,MultipartReader muitlpart_reader)243   ContentReader(Reader reader, MultipartReader muitlpart_reader)
244       : reader_(reader), muitlpart_reader_(muitlpart_reader) {}
245 
operator()246   bool operator()(MultipartContentHeader header,
247                   ContentReceiver receiver) const {
248     return muitlpart_reader_(header, receiver);
249   }
250 
operator()251   bool operator()(ContentReceiver receiver) const { return reader_(receiver); }
252 
253   Reader reader_;
254   MultipartReader muitlpart_reader_;
255 };
256 
// A pair of signed offsets.
// NOTE(review): presumably (first, last) byte positions of an HTTP range —
// confirm against the parser, which is outside this section.
using Range = std::pair<ssize_t, ssize_t>;
// List of ranges (type of Request::ranges).
using Ranges = std::vector<Range>;
259 
260 struct Request {
261   std::string method;
262   std::string path;
263   Headers headers;
264   std::string body;
265 
266   std::string remote_addr;
267   int remote_port = -1;
268 
269   // for server
270   std::string version;
271   std::string target;
272   Params params;
273   MultipartFormDataMap files;
274   Ranges ranges;
275   Match matches;
276 
277   // for client
278   size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT;
279   ResponseHandler response_handler;
280   ContentReceiver content_receiver;
281   Progress progress;
282 
283 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
284   const SSL *ssl;
285 #endif
286 
287   bool has_header(const char *key) const;
288   std::string get_header_value(const char *key, size_t id = 0) const;
289   size_t get_header_value_count(const char *key) const;
290   void set_header(const char *key, const char *val);
291   void set_header(const char *key, const std::string &val);
292 
293   bool has_param(const char *key) const;
294   std::string get_param_value(const char *key, size_t id = 0) const;
295   size_t get_param_value_count(const char *key) const;
296 
297   bool is_multipart_form_data() const;
298 
299   bool has_file(const char *key) const;
300   MultipartFormData get_file_value(const char *key) const;
301 
302   // private members...
303   size_t content_length;
304   ContentProvider content_provider;
305 };
306 
// An HTTP response: status line data, headers, body and, optionally, a
// content provider that produces the body on demand.
struct Response {
  std::string version;
  int status = -1; // -1 until a status has been assigned
  Headers headers;
  std::string body;

  // Header access. `id` selects among several values stored under one key.
  bool has_header(const char *key) const;
  std::string get_header_value(const char *key, size_t id = 0) const;
  size_t get_header_value_count(const char *key) const;
  void set_header(const char *key, const char *val);
  void set_header(const char *key, const std::string &val);

  // Redirect helper; the status argument defaults to 302.
  void set_redirect(const char *url, int status = 302);
  // Body setters: from a pointer plus length, or from a std::string.
  void set_content(const char *s, size_t n, const char *content_type);
  void set_content(std::string s, const char *content_type);

  // Installs a provider for a body of known `length`. `resource_releaser`
  // defaults to a no-op lambda.
  void set_content_provider(
      size_t length,
      std::function<void(size_t offset, size_t length, DataSink &sink)>
          provider,
      std::function<void()> resource_releaser = [] {});

  // Same as above for a body without a length argument (the provider only
  // receives an offset and the sink).
  void set_chunked_content_provider(
      std::function<void(size_t offset, DataSink &sink)> provider,
      std::function<void()> resource_releaser = [] {});

  Response() = default;
  Response(const Response &) = default;
  Response &operator=(const Response &) = default;
  Response(Response &&) = default;
  Response &operator=(Response &&) = default;
  // Invokes the resource releaser, if one is set.
  // NOTE(review): copy construction/assignment are defaulted, so a copy also
  // carries the releaser and each object's destructor will invoke it —
  // confirm that releasers tolerate being called more than once, or that
  // Responses holding a provider are never copied.
  ~Response() {
    if (content_provider_resource_releaser) {
      content_provider_resource_releaser();
    }
  }

  // private members...
  size_t content_length = 0;
  ContentProvider content_provider;
  std::function<void()> content_provider_resource_releaser;
};
349 
// Abstract byte stream used for request/response I/O. Concrete streams
// implement the five pure virtual functions; the remaining write helpers are
// declared here and defined elsewhere in the file.
class Stream {
public:
  virtual ~Stream() = default;

  virtual bool is_readable() const = 0;
  virtual bool is_writable() const = 0;

  // Read into / write from a buffer of `size` bytes; the result is a signed
  // size. NOTE(review): presumably the byte count, negative on error —
  // confirm against the implementations.
  virtual ssize_t read(char *ptr, size_t size) = 0;
  virtual ssize_t write(const char *ptr, size_t size) = 0;
  // Reports the peer's address through the two output parameters.
  virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0;

  // Convenience writers: a format string with arguments, a C string, and a
  // std::string.
  template <typename... Args>
  ssize_t write_format(const char *fmt, const Args &... args);
  ssize_t write(const char *ptr);
  ssize_t write(const std::string &s);
};
366 
// Interface for something that accepts units of work. Implementations must
// provide enqueue() and shutdown(); on_idle() is an optional hook.
class TaskQueue {
public:
  TaskQueue() = default;
  virtual ~TaskQueue() = default;

  // Hands a task to the queue.
  virtual void enqueue(std::function<void()> fn) = 0;

  // Stops the queue.
  virtual void shutdown() = 0;

  // Optional hook; the base implementation does nothing.
  virtual void on_idle() {}
};
377 
378 class ThreadPool : public TaskQueue {
379 public:
ThreadPool(size_t n)380   explicit ThreadPool(size_t n) : shutdown_(false) {
381     while (n) {
382       threads_.emplace_back(worker(*this));
383       n--;
384     }
385   }
386 
387   ThreadPool(const ThreadPool &) = delete;
388   ~ThreadPool() override = default;
389 
enqueue(std::function<void ()> fn)390   void enqueue(std::function<void()> fn) override {
391     std::unique_lock<std::mutex> lock(mutex_);
392     jobs_.push_back(fn);
393     cond_.notify_one();
394   }
395 
shutdown()396   void shutdown() override {
397     // Stop all worker threads...
398     {
399       std::unique_lock<std::mutex> lock(mutex_);
400       shutdown_ = true;
401     }
402 
403     cond_.notify_all();
404 
405     // Join...
406     for (auto &t : threads_) {
407       t.join();
408     }
409   }
410 
411 private:
412   struct worker {
workerworker413     explicit worker(ThreadPool &pool) : pool_(pool) {}
414 
operatorworker415     void operator()() {
416       for (;;) {
417         std::function<void()> fn;
418         {
419           std::unique_lock<std::mutex> lock(pool_.mutex_);
420 
421           pool_.cond_.wait(
422               lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; });
423 
424           if (pool_.shutdown_ && pool_.jobs_.empty()) { break; }
425 
426           fn = pool_.jobs_.front();
427           pool_.jobs_.pop_front();
428         }
429 
430         assert(true == static_cast<bool>(fn));
431         fn();
432       }
433     }
434 
435     ThreadPool &pool_;
436   };
437   friend struct worker;
438 
439   std::vector<std::thread> threads_;
440   std::list<std::function<void()>> jobs_;
441 
442   bool shutdown_;
443 
444   std::condition_variable cond_;
445   std::mutex mutex_;
446 };
447 
// Logging callback that receives a request together with its response.
using Logger = std::function<void(const Request &, const Response &)>;
449 
// HTTP server. This is the declaration only; every member function is defined
// elsewhere in the file, so the comments below describe what the signatures
// and data members show and flag anything that is an assumption.
class Server {
public:
  // Request handler: reads the Request and fills in the Response.
  using Handler = std::function<void(const Request &, Response &)>;
  // Handler variant that additionally receives a ContentReader for the body.
  using HandlerWithContentReader = std::function<void(
      const Request &, Response &, const ContentReader &content_reader)>;
  // Handler for "Expect: 100-continue"; returns an int.
  // NOTE(review): presumably the status code to reply with — confirm.
  using Expect100ContinueHandler =
      std::function<int(const Request &, Response &)>;

  Server();

  virtual ~Server();

  virtual bool is_valid() const;

  // Route registration, one function per HTTP method. Each returns *this's
  // type by reference so calls can be chained. The handler tables below store
  // std::regex objects, so `pattern` is presumably compiled as a regular
  // expression — confirm in the definitions.
  Server &Get(const char *pattern, Handler handler);
  Server &Post(const char *pattern, Handler handler);
  Server &Post(const char *pattern, HandlerWithContentReader handler);
  Server &Put(const char *pattern, Handler handler);
  Server &Put(const char *pattern, HandlerWithContentReader handler);
  Server &Patch(const char *pattern, Handler handler);
  Server &Patch(const char *pattern, HandlerWithContentReader handler);
  Server &Delete(const char *pattern, Handler handler);
  Server &Delete(const char *pattern, HandlerWithContentReader handler);
  Server &Options(const char *pattern, Handler handler);

  // Static file serving configuration. set_base_dir is marked deprecated;
  // note that its argument order (dir, mount_point) is the reverse of
  // set_mount_point (mount_point, dir).
  [[deprecated]] bool set_base_dir(const char *dir,
                                   const char *mount_point = nullptr);
  bool set_mount_point(const char *mount_point, const char *dir);
  bool remove_mount_point(const char *mount_point);
  void set_file_extension_and_mimetype_mapping(const char *ext,
                                               const char *mime);
  void set_file_request_handler(Handler handler);

  void set_error_handler(Handler handler);
  void set_logger(Logger logger);

  void set_expect_100_continue_handler(Expect100ContinueHandler handler);

  // Setters for the protected tuning members declared further down.
  void set_keep_alive_max_count(size_t count);
  void set_read_timeout(time_t sec, time_t usec);
  void set_payload_max_length(size_t length);

  // Binding and listening. bind_to_any_port returns an int.
  // NOTE(review): presumably the port that was chosen — confirm.
  bool bind_to_port(const char *host, int port, int socket_flags = 0);
  int bind_to_any_port(const char *host, int socket_flags = 0);
  bool listen_after_bind();

  bool listen(const char *host, int port, int socket_flags = 0);

  bool is_running() const;
  void stop();

  // Factory for the task queue used by the server. It returns a raw pointer.
  // NOTE(review): ownership of the returned object is not visible here —
  // confirm who deletes it.
  std::function<TaskQueue *(void)> new_task_queue;

protected:
  // Handles one request on `strm`; `connection_close` is an output flag and
  // `setup_request` is a hook that receives the Request being built.
  bool process_request(Stream &strm, bool last_connection,
                       bool &connection_close,
                       const std::function<void(Request &)> &setup_request);

  size_t keep_alive_max_count_;
  time_t read_timeout_sec_;
  time_t read_timeout_usec_;
  size_t payload_max_length_;

private:
  // Route tables: (compiled pattern, handler) pairs kept in registration
  // order.
  using Handlers = std::vector<std::pair<std::regex, Handler>>;
  using HandlersForContentReader =
      std::vector<std::pair<std::regex, HandlerWithContentReader>>;

  socket_t create_server_socket(const char *host, int port,
                                int socket_flags) const;
  int bind_internal(const char *host, int port, int socket_flags);
  bool listen_internal();

  // Request dispatch helpers.
  bool routing(Request &req, Response &res, Stream &strm);
  bool handle_file_request(Request &req, Response &res, bool head = false);
  bool dispatch_request(Request &req, Response &res, Handlers &handlers);
  bool dispatch_request_for_content_reader(Request &req, Response &res,
                                           ContentReader content_reader,
                                           HandlersForContentReader &handlers);

  // Parsing, response writing and body reading helpers.
  bool parse_request_line(const char *s, Request &req);
  bool write_response(Stream &strm, bool last_connection, const Request &req,
                      Response &res);
  bool write_content_with_provider(Stream &strm, const Request &req,
                                   Response &res, const std::string &boundary,
                                   const std::string &content_type);
  bool read_content(Stream &strm, Request &req, Response &res);
  bool
  read_content_with_content_receiver(Stream &strm, Request &req, Response &res,
                                     ContentReceiver receiver,
                                     MultipartContentHeader multipart_header,
                                     ContentReceiver multipart_receiver);
  bool read_content_core(Stream &strm, Request &req, Response &res,
                         ContentReceiver receiver,
                         MultipartContentHeader mulitpart_header,
                         ContentReceiver multipart_receiver);

  // Virtual so that a derived server (SSLServer when OpenSSL support is
  // enabled) can supply its own connection handling.
  virtual bool process_and_close_socket(socket_t sock);

  // Both are atomics, so they can be read and written from several threads
  // without a lock.
  std::atomic<bool> is_running_;
  std::atomic<socket_t> svr_sock_;
  std::vector<std::pair<std::string, std::string>> base_dirs_;
  std::map<std::string, std::string> file_extension_and_mimetype_map_;
  Handler file_request_handler_;
  // One route table per registration function above.
  Handlers get_handlers_;
  Handlers post_handlers_;
  HandlersForContentReader post_handlers_for_content_reader_;
  Handlers put_handlers_;
  HandlersForContentReader put_handlers_for_content_reader_;
  Handlers patch_handlers_;
  HandlersForContentReader patch_handlers_for_content_reader_;
  Handlers delete_handlers_;
  HandlersForContentReader delete_handlers_for_content_reader_;
  Handlers options_handlers_;
  Handler error_handler_;
  Logger logger_;
  Expect100ContinueHandler expect_100_continue_handler_;
};
568 
569 class Client {
570 public:
571   explicit Client(const std::string &host, int port = 80,
572                   const std::string &client_cert_path = std::string(),
573                   const std::string &client_key_path = std::string());
574 
575   virtual ~Client();
576 
577   virtual bool is_valid() const;
578 
579   std::shared_ptr<Response> Get(const char *path);
580 
581   std::shared_ptr<Response> Get(const char *path, const Headers &headers);
582 
583   std::shared_ptr<Response> Get(const char *path, Progress progress);
584 
585   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
586                                 Progress progress);
587 
588   std::shared_ptr<Response> Get(const char *path,
589                                 ContentReceiver content_receiver);
590 
591   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
592                                 ContentReceiver content_receiver);
593 
594   std::shared_ptr<Response>
595   Get(const char *path, ContentReceiver content_receiver, Progress progress);
596 
597   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
598                                 ContentReceiver content_receiver,
599                                 Progress progress);
600 
601   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
602                                 ResponseHandler response_handler,
603                                 ContentReceiver content_receiver);
604 
605   std::shared_ptr<Response> Get(const char *path, const Headers &headers,
606                                 ResponseHandler response_handler,
607                                 ContentReceiver content_receiver,
608                                 Progress progress);
609 
610   std::shared_ptr<Response> Head(const char *path);
611 
612   std::shared_ptr<Response> Head(const char *path, const Headers &headers);
613 
614   std::shared_ptr<Response> Post(const char *path);
615 
616   std::shared_ptr<Response> Post(const char *path, const std::string &body,
617                                  const char *content_type);
618 
619   std::shared_ptr<Response> Post(const char *path, const Headers &headers,
620                                  const std::string &body,
621                                  const char *content_type);
622 
623   std::shared_ptr<Response> Post(const char *path, size_t content_length,
624                                  ContentProvider content_provider,
625                                  const char *content_type);
626 
627   std::shared_ptr<Response> Post(const char *path, const Headers &headers,
628                                  size_t content_length,
629                                  ContentProvider content_provider,
630                                  const char *content_type);
631 
632   std::shared_ptr<Response> Post(const char *path, const Params &params);
633 
634   std::shared_ptr<Response> Post(const char *path, const Headers &headers,
635                                  const Params &params);
636 
637   std::shared_ptr<Response> Post(const char *path,
638                                  const MultipartFormDataItems &items);
639 
640   std::shared_ptr<Response> Post(const char *path, const Headers &headers,
641                                  const MultipartFormDataItems &items);
642 
643   std::shared_ptr<Response> Put(const char *path);
644 
645   std::shared_ptr<Response> Put(const char *path, const std::string &body,
646                                 const char *content_type);
647 
648   std::shared_ptr<Response> Put(const char *path, const Headers &headers,
649                                 const std::string &body,
650                                 const char *content_type);
651 
652   std::shared_ptr<Response> Put(const char *path, size_t content_length,
653                                 ContentProvider content_provider,
654                                 const char *content_type);
655 
656   std::shared_ptr<Response> Put(const char *path, const Headers &headers,
657                                 size_t content_length,
658                                 ContentProvider content_provider,
659                                 const char *content_type);
660 
661   std::shared_ptr<Response> Put(const char *path, const Params &params);
662 
663   std::shared_ptr<Response> Put(const char *path, const Headers &headers,
664                                 const Params &params);
665 
666   std::shared_ptr<Response> Patch(const char *path, const std::string &body,
667                                   const char *content_type);
668 
669   std::shared_ptr<Response> Patch(const char *path, const Headers &headers,
670                                   const std::string &body,
671                                   const char *content_type);
672 
673   std::shared_ptr<Response> Patch(const char *path, size_t content_length,
674                                   ContentProvider content_provider,
675                                   const char *content_type);
676 
677   std::shared_ptr<Response> Patch(const char *path, const Headers &headers,
678                                   size_t content_length,
679                                   ContentProvider content_provider,
680                                   const char *content_type);
681 
682   std::shared_ptr<Response> Delete(const char *path);
683 
684   std::shared_ptr<Response> Delete(const char *path, const std::string &body,
685                                    const char *content_type);
686 
687   std::shared_ptr<Response> Delete(const char *path, const Headers &headers);
688 
689   std::shared_ptr<Response> Delete(const char *path, const Headers &headers,
690                                    const std::string &body,
691                                    const char *content_type);
692 
693   std::shared_ptr<Response> Options(const char *path);
694 
695   std::shared_ptr<Response> Options(const char *path, const Headers &headers);
696 
697   bool send(const Request &req, Response &res);
698 
699   bool send(const std::vector<Request> &requests,
700             std::vector<Response> &responses);
701 
702   void stop();
703 
704   void set_timeout_sec(time_t timeout_sec);
705 
706   void set_read_timeout(time_t sec, time_t usec);
707 
708   void set_keep_alive_max_count(size_t count);
709 
710   void set_basic_auth(const char *username, const char *password);
711 
712 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
713   void set_digest_auth(const char *username, const char *password);
714 #endif
715 
716   void set_follow_location(bool on);
717 
718   void set_compress(bool on);
719 
720   void set_interface(const char *intf);
721 
722   void set_proxy(const char *host, int port);
723 
724   void set_proxy_basic_auth(const char *username, const char *password);
725 
726 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
727   void set_proxy_digest_auth(const char *username, const char *password);
728 #endif
729 
730   void set_logger(Logger logger);
731 
732 protected:
733   bool process_request(Stream &strm, const Request &req, Response &res,
734                        bool last_connection, bool &connection_close);
735 
736   std::atomic<socket_t> sock_;
737 
738   const std::string host_;
739   const int port_;
740   const std::string host_and_port_;
741 
742   // Settings
743   std::string client_cert_path_;
744   std::string client_key_path_;
745 
746   time_t timeout_sec_ = 300;
747   time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND;
748   time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND;
749 
750   size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
751 
752   std::string basic_auth_username_;
753   std::string basic_auth_password_;
754 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
755   std::string digest_auth_username_;
756   std::string digest_auth_password_;
757 #endif
758 
759   bool follow_location_ = false;
760 
761   bool compress_ = false;
762 
763   std::string interface_;
764 
765   std::string proxy_host_;
766   int proxy_port_;
767 
768   std::string proxy_basic_auth_username_;
769   std::string proxy_basic_auth_password_;
770 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
771   std::string proxy_digest_auth_username_;
772   std::string proxy_digest_auth_password_;
773 #endif
774 
775   Logger logger_;
776 
copy_settings(const Client & rhs)777   void copy_settings(const Client &rhs) {
778     client_cert_path_ = rhs.client_cert_path_;
779     client_key_path_ = rhs.client_key_path_;
780     timeout_sec_ = rhs.timeout_sec_;
781     read_timeout_sec_ = rhs.read_timeout_sec_;
782     read_timeout_usec_ = rhs.read_timeout_usec_;
783     keep_alive_max_count_ = rhs.keep_alive_max_count_;
784     basic_auth_username_ = rhs.basic_auth_username_;
785     basic_auth_password_ = rhs.basic_auth_password_;
786 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
787     digest_auth_username_ = rhs.digest_auth_username_;
788     digest_auth_password_ = rhs.digest_auth_password_;
789 #endif
790     follow_location_ = rhs.follow_location_;
791     compress_ = rhs.compress_;
792     interface_ = rhs.interface_;
793     proxy_host_ = rhs.proxy_host_;
794     proxy_port_ = rhs.proxy_port_;
795     proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_;
796     proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_;
797 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
798     proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_;
799     proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_;
800 #endif
801     logger_ = rhs.logger_;
802   }
803 
804 private:
805   socket_t create_client_socket() const;
806   bool read_response_line(Stream &strm, Response &res);
807   bool write_request(Stream &strm, const Request &req, bool last_connection);
808   bool redirect(const Request &req, Response &res);
809   bool handle_request(Stream &strm, const Request &req, Response &res,
810                       bool last_connection, bool &connection_close);
811 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
812   bool connect(socket_t sock, Response &res, bool &error);
813 #endif
814 
815   std::shared_ptr<Response> send_with_content_provider(
816       const char *method, const char *path, const Headers &headers,
817       const std::string &body, size_t content_length,
818       ContentProvider content_provider, const char *content_type);
819 
820   virtual bool process_and_close_socket(
821       socket_t sock, size_t request_count,
822       std::function<bool(Stream &strm, bool last_connection,
823                          bool &connection_close)>
824           callback);
825 
826   virtual bool is_ssl() const;
827 };
828 
Get(std::vector<Request> & requests,const char * path,const Headers & headers)829 inline void Get(std::vector<Request> &requests, const char *path,
830                 const Headers &headers) {
831   Request req;
832   req.method = "GET";
833   req.path = path;
834   req.headers = headers;
835   requests.emplace_back(std::move(req));
836 }
837 
Get(std::vector<Request> & requests,const char * path)838 inline void Get(std::vector<Request> &requests, const char *path) {
839   Get(requests, path, Headers());
840 }
841 
Post(std::vector<Request> & requests,const char * path,const Headers & headers,const std::string & body,const char * content_type)842 inline void Post(std::vector<Request> &requests, const char *path,
843                  const Headers &headers, const std::string &body,
844                  const char *content_type) {
845   Request req;
846   req.method = "POST";
847   req.path = path;
848   req.headers = headers;
849   if (content_type) { req.headers.emplace("Content-Type", content_type); }
850   req.body = body;
851   requests.emplace_back(std::move(req));
852 }
853 
Post(std::vector<Request> & requests,const char * path,const std::string & body,const char * content_type)854 inline void Post(std::vector<Request> &requests, const char *path,
855                  const std::string &body, const char *content_type) {
856   Post(requests, path, Headers(), body, content_type);
857 }
858 
859 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
// Server variant available when CPPHTTPLIB_OPENSSL_SUPPORT is defined.
// Declaration only; the member functions are defined elsewhere in the file.
class SSLServer : public Server {
public:
  // Configures the server from certificate / private key file paths. The two
  // client CA arguments are optional (default nullptr).
  SSLServer(const char *cert_path, const char *private_key_path,
            const char *client_ca_cert_file_path = nullptr,
            const char *client_ca_cert_dir_path = nullptr);

  // Configures the server from already loaded OpenSSL objects; the client CA
  // store is optional (default nullptr).
  // NOTE(review): ownership of the passed pointers is not visible here.
  SSLServer(X509 *cert, EVP_PKEY *private_key,
            X509_STORE *client_ca_cert_store = nullptr);

  ~SSLServer() override;

  bool is_valid() const override;

private:
  // Replaces Server's connection handling.
  bool process_and_close_socket(socket_t sock) override;

  SSL_CTX *ctx_;
  // NOTE(review): presumably serialises access to ctx_ — confirm in the
  // definitions.
  std::mutex ctx_mutex_;
};
879 
// Client variant available when CPPHTTPLIB_OPENSSL_SUPPORT is defined.
// Declaration only; the member functions are defined elsewhere in the file.
class SSLClient : public Client {
public:
  // Same shape as Client's constructor, but `port` defaults to 443.
  explicit SSLClient(const std::string &host, int port = 443,
                     const std::string &client_cert_path = std::string(),
                     const std::string &client_key_path = std::string());

  // Takes the client certificate and key as already loaded OpenSSL objects.
  // NOTE(review): ownership of the passed pointers is not visible here.
  SSLClient(const std::string &host, int port, X509 *client_cert,
            EVP_PKEY *client_key);

  ~SSLClient() override;

  bool is_valid() const override;

  // CA certificate configuration: by file/directory path, or by store.
  void set_ca_cert_path(const char *ca_ceert_file_path,
                        const char *ca_cert_dir_path = nullptr);

  void set_ca_cert_store(X509_STORE *ca_cert_store);

  // Server certificate verification is off by default (see
  // server_certificate_verification_ below).
  void enable_server_certificate_verification(bool enabled);

  long get_openssl_verify_result() const;

  SSL_CTX *ssl_context() const;

private:
  // Overrides of Client's virtual hooks.
  bool process_and_close_socket(
      socket_t sock, size_t request_count,
      std::function<bool(Stream &strm, bool last_connection,
                         bool &connection_close)>
          callback) override;
  bool is_ssl() const override;

  // Host name verification helpers for a server certificate.
  bool verify_host(X509 *server_cert) const;
  bool verify_host_with_subject_alt_name(X509 *server_cert) const;
  bool verify_host_with_common_name(X509 *server_cert) const;
  bool check_host_name(const char *pattern, size_t pattern_len) const;

  SSL_CTX *ctx_;
  // NOTE(review): presumably serialises access to ctx_ — confirm in the
  // definitions.
  std::mutex ctx_mutex_;
  // NOTE(review): presumably the host name split into its labels for
  // check_host_name — confirm in the constructor.
  std::vector<std::string> host_components_;

  std::string ca_cert_file_path_;
  std::string ca_cert_dir_path_;
  X509_STORE *ca_cert_store_ = nullptr;
  bool server_certificate_verification_ = false;
  long verify_result_ = 0;
};
927 #endif
928 
929 // ----------------------------------------------------------------------------
930 
931 /*
932  * Implementation
933  */
934 
935 namespace detail {
936 
// Converts a single hexadecimal digit character to its numeric value.
// On success the value (0-15) is written to `v` and true is returned;
// otherwise `v` is left untouched and false is returned.
inline bool is_hex(char c, int &v) {
  int digit = -1;

  // The 0x20 lower bound keeps negative and control characters away from
  // isdigit(), whose argument must be representable as unsigned char.
  if (0x20 <= c && isdigit(c)) {
    digit = c - '0';
  } else if (c >= 'A' && c <= 'F') {
    digit = 10 + (c - 'A');
  } else if (c >= 'a' && c <= 'f') {
    digit = 10 + (c - 'a');
  }

  if (digit < 0) { return false; }

  v = digit;
  return true;
}
950 
from_hex_to_i(const std::string & s,size_t i,size_t cnt,int & val)951 inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt,
952                           int &val) {
953   if (i >= s.size()) { return false; }
954 
955   val = 0;
956   for (; cnt; i++, cnt--) {
957     if (!s[i]) { return false; }
958     int v = 0;
959     if (is_hex(s[i], v)) {
960       val = val * 16 + v;
961     } else {
962       return false;
963     }
964   }
965   return true;
966 }
967 
// Formats `n` as lowercase hexadecimal without prefix or leading zeros
// (zero itself yields "0").
inline std::string from_i_to_hex(size_t n) {
  const char digits[] = "0123456789abcdef";
  std::string hex;
  for (;;) {
    // Least-significant nibble first, so each digit is prepended.
    hex.insert(hex.begin(), digits[n % 16]);
    n /= 16;
    if (n == 0) { break; }
  }
  return hex;
}
977 
// Encodes the code point `code` as UTF-8 into `buff` and returns the number
// of bytes written (1-4). Returns 0, writing nothing, for the surrogate range
// U+D800..U+DFFF and for values of 0x110000 and above. `buff` must have room
// for 4 bytes; no terminator is written.
inline size_t to_utf8(int code, char *buff) {
  // One byte: 0xxxxxxx
  if (code < 0x0080) {
    buff[0] = static_cast<char>(code & 0x7F);
    return 1;
  }

  // Two bytes: 110xxxxx 10xxxxxx
  if (code < 0x0800) {
    buff[0] = static_cast<char>(0xC0 | ((code >> 6) & 0x1F));
    buff[1] = static_cast<char>(0x80 | (code & 0x3F));
    return 2;
  }

  // D800 - DFFF is invalid...
  if (0xD800 <= code && code < 0xE000) { return 0; }

  // Three bytes: 1110xxxx 10xxxxxx 10xxxxxx
  if (code < 0x10000) {
    buff[0] = static_cast<char>(0xE0 | ((code >> 12) & 0xF));
    buff[1] = static_cast<char>(0x80 | ((code >> 6) & 0x3F));
    buff[2] = static_cast<char>(0x80 | (code & 0x3F));
    return 3;
  }

  // Four bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
  if (code < 0x110000) {
    buff[0] = static_cast<char>(0xF0 | ((code >> 18) & 0x7));
    buff[1] = static_cast<char>(0x80 | ((code >> 12) & 0x3F));
    buff[2] = static_cast<char>(0x80 | ((code >> 6) & 0x3F));
    buff[3] = static_cast<char>(0x80 | (code & 0x3F));
    return 4;
  }

  // Outside the Unicode range.
  return 0;
}
1009 
// NOTE: This code is based on the following Stack Overflow post:
// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c
//
// Encodes the bytes of `in` as Base64 using the standard alphabet
// (A-Z a-z 0-9 + /) and '=' padding, and returns the encoded text.
inline std::string base64_encode(const std::string &in) {
  static const auto lookup =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

  std::string out;
  // Every group of up to 3 input bytes produces 4 output characters.
  out.reserve(((in.size() + 2) / 3) * 4);

  // Bit accumulator. It is unsigned on purpose: bytes are shifted in without
  // masking, so for inputs longer than 3 bytes the high bits fall off the
  // top, which is well defined for unsigned types but not for `int`. Only
  // the low 14 bits are ever read back.
  uint32_t val = 0;
  // Number of not-yet-emitted bits in `val`, minus 6.
  int valb = -6;

  for (auto c : in) {
    val = (val << 8) + static_cast<uint8_t>(c);
    valb += 8;
    while (valb >= 0) {
      out.push_back(lookup[(val >> valb) & 0x3F]);
      valb -= 6;
    }
  }

  // Flush the remaining 2 or 4 bits, zero-padded on the right to a sextet.
  if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); }

  while (out.size() % 4) {
    out.push_back('=');
  }

  return out;
}
1039 
// Returns true when stat() succeeds on `path` and reports a regular file.
inline bool is_file(const std::string &path) {
  struct stat info;
  if (stat(path.c_str(), &info) < 0) { return false; }
  return S_ISREG(info.st_mode);
}
1044 
// Returns true when stat() succeeds on `path` and reports a directory.
inline bool is_dir(const std::string &path) {
  struct stat info;
  if (stat(path.c_str(), &info) < 0) { return false; }
  return S_ISDIR(info.st_mode);
}
1049 
// Returns false when `path` contains a ".." component that would climb above
// its starting directory, true otherwise. Repeated slashes are ignored and
// "." components do not change the depth.
inline bool is_valid_path(const std::string &path) {
  const auto length = path.size();
  size_t depth = 0;
  size_t pos = 0;

  const auto skip_slashes = [&]() {
    while (pos < length && path[pos] == '/') {
      ++pos;
    }
  };

  skip_slashes();

  while (pos < length) {
    // Read one component: [start, pos).
    const auto start = pos;
    while (pos < length && path[pos] != '/') {
      ++pos;
    }

    const auto len = pos - start;
    assert(len > 0);

    const auto component_is = [&](const char *name) {
      return path.compare(start, len, name) == 0;
    };

    if (component_is("..")) {
      // Cannot go above the starting level.
      if (depth == 0) { return false; }
      --depth;
    } else if (!component_is(".")) {
      ++depth;
    }

    skip_slashes();
  }

  return true;
}
1086 
// Reads the whole file at `path` into `out` in binary mode.
// If the file cannot be opened or its size cannot be determined, `out` is
// left empty. If fewer bytes than expected can be read, `out` is trimmed to
// the bytes actually read.
inline void read_file(const std::string &path, std::string &out) {
  out.clear();

  std::ifstream fs(path, std::ios_base::binary);
  if (!fs) { return; }

  fs.seekg(0, std::ios_base::end);
  // tellg() yields -1 on failure; converting that to size_t would request an
  // enormous resize and throw.
  const auto size = static_cast<std::streamoff>(fs.tellg());
  if (size <= 0) { return; }

  fs.seekg(0);
  out.resize(static_cast<size_t>(size));
  fs.read(&out[0], static_cast<std::streamsize>(size));

  // A short read sets failbit; keep only what was really read.
  if (!fs) { out.resize(static_cast<size_t>(fs.gcount())); }
}
1095 
// Returns the trailing extension of `path` (the run of ASCII letters/digits
// after the last '.', without the dot), or an empty string when `path` does
// not end in such an extension.
inline std::string file_extension(const std::string &path) {
  static auto re = std::regex("\\.([a-zA-Z0-9]+)$");

  std::smatch match;
  if (!std::regex_search(path, match, re)) { return std::string(); }
  return match[1].str();
}
1102 
// Splits the character range starting at `b` on the delimiter `d` and calls
// `fn(token_begin, token_end)` for every token, including empty ones.
// The range ends at `e`, or at the first NUL when `e` is null. Nothing is
// reported for an empty range.
template <class Fn> void split(const char *b, const char *e, char d, Fn fn) {
  const char *token = b;
  const char *cur = b;

  const auto at_end = [&]() { return e ? (cur == e) : (*cur == '\0'); };

  while (!at_end()) {
    if (*cur == d) {
      fn(token, cur);
      token = cur + 1;
    }
    ++cur;
  }

  // Trailing token (possibly empty), unless the input itself was empty.
  if (cur != b) { fn(token, cur); }
}
1117 
1118 // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer`
1119 // to store data. The call can set memory on stack for performance.
1120 class stream_line_reader {
1121 public:
stream_line_reader(Stream & strm,char * fixed_buffer,size_t fixed_buffer_size)1122   stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size)
1123       : strm_(strm), fixed_buffer_(fixed_buffer),
1124         fixed_buffer_size_(fixed_buffer_size) {}
1125 
ptr()1126   const char *ptr() const {
1127     if (glowable_buffer_.empty()) {
1128       return fixed_buffer_;
1129     } else {
1130       return glowable_buffer_.data();
1131     }
1132   }
1133 
size()1134   size_t size() const {
1135     if (glowable_buffer_.empty()) {
1136       return fixed_buffer_used_size_;
1137     } else {
1138       return glowable_buffer_.size();
1139     }
1140   }
1141 
end_with_crlf()1142   bool end_with_crlf() const {
1143     auto end = ptr() + size();
1144     return size() >= 2 && end[-2] == '\r' && end[-1] == '\n';
1145   }
1146 
getline()1147   bool getline() {
1148     fixed_buffer_used_size_ = 0;
1149     glowable_buffer_.clear();
1150 
1151     for (size_t i = 0;; i++) {
1152       char byte;
1153       auto n = strm_.read(&byte, 1);
1154 
1155       if (n < 0) {
1156         return false;
1157       } else if (n == 0) {
1158         if (i == 0) {
1159           return false;
1160         } else {
1161           break;
1162         }
1163       }
1164 
1165       append(byte);
1166 
1167       if (byte == '\n') { break; }
1168     }
1169 
1170     return true;
1171   }
1172 
1173 private:
append(char c)1174   void append(char c) {
1175     if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) {
1176       fixed_buffer_[fixed_buffer_used_size_++] = c;
1177       fixed_buffer_[fixed_buffer_used_size_] = '\0';
1178     } else {
1179       if (glowable_buffer_.empty()) {
1180         assert(fixed_buffer_[fixed_buffer_used_size_] == '\0');
1181         glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_);
1182       }
1183       glowable_buffer_ += c;
1184     }
1185   }
1186 
1187   Stream &strm_;
1188   char *fixed_buffer_;
1189   const size_t fixed_buffer_size_;
1190   size_t fixed_buffer_used_size_ = 0;
1191   std::string glowable_buffer_;
1192 };
1193 
close_socket(socket_t sock)1194 inline int close_socket(socket_t sock) {
1195 #ifdef _WIN32
1196   return closesocket(sock);
1197 #else
1198   return close(sock);
1199 #endif
1200 }
1201 
select_read(socket_t sock,time_t sec,time_t usec)1202 inline int select_read(socket_t sock, time_t sec, time_t usec) {
1203 #ifdef CPPHTTPLIB_USE_POLL
1204   struct pollfd pfd_read;
1205   pfd_read.fd = sock;
1206   pfd_read.events = POLLIN;
1207 
1208   auto timeout = static_cast<int>(sec * 1000 + usec / 1000);
1209 
1210   return poll(&pfd_read, 1, timeout);
1211 #else
1212   fd_set fds;
1213   FD_ZERO(&fds);
1214   FD_SET(sock, &fds);
1215 
1216   timeval tv;
1217   tv.tv_sec = static_cast<long>(sec);
1218   tv.tv_usec = static_cast<decltype(tv.tv_usec)>(usec);
1219 
1220   return select(static_cast<int>(sock + 1), &fds, nullptr, nullptr, &tv);
1221 #endif
1222 }
1223 
select_write(socket_t sock,time_t sec,time_t usec)1224 inline int select_write(socket_t sock, time_t sec, time_t usec) {
1225 #ifdef CPPHTTPLIB_USE_POLL
1226   struct pollfd pfd_read;
1227   pfd_read.fd = sock;
1228   pfd_read.events = POLLOUT;
1229 
1230   auto timeout = static_cast<int>(sec * 1000 + usec / 1000);
1231 
1232   return poll(&pfd_read, 1, timeout);
1233 #else
1234   fd_set fds;
1235   FD_ZERO(&fds);
1236   FD_SET(sock, &fds);
1237 
1238   timeval tv;
1239   tv.tv_sec = static_cast<long>(sec);
1240   tv.tv_usec = static_cast<decltype(tv.tv_usec)>(usec);
1241 
1242   return select(static_cast<int>(sock + 1), nullptr, &fds, nullptr, &tv);
1243 #endif
1244 }
1245 
wait_until_socket_is_ready(socket_t sock,time_t sec,time_t usec)1246 inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) {
1247 #ifdef CPPHTTPLIB_USE_POLL
1248   struct pollfd pfd_read;
1249   pfd_read.fd = sock;
1250   pfd_read.events = POLLIN | POLLOUT;
1251 
1252   auto timeout = static_cast<int>(sec * 1000 + usec / 1000);
1253 
1254   if (poll(&pfd_read, 1, timeout) > 0 &&
1255       pfd_read.revents & (POLLIN | POLLOUT)) {
1256     int error = 0;
1257     socklen_t len = sizeof(error);
1258     return getsockopt(sock, SOL_SOCKET, SO_ERROR,
1259                       reinterpret_cast<char *>(&error), &len) >= 0 &&
1260            !error;
1261   }
1262   return false;
1263 #else
1264   fd_set fdsr;
1265   FD_ZERO(&fdsr);
1266   FD_SET(sock, &fdsr);
1267 
1268   auto fdsw = fdsr;
1269   auto fdse = fdsr;
1270 
1271   timeval tv;
1272   tv.tv_sec = static_cast<long>(sec);
1273   tv.tv_usec = static_cast<decltype(tv.tv_usec)>(usec);
1274 
1275   if (select(static_cast<int>(sock + 1), &fdsr, &fdsw, &fdse, &tv) > 0 &&
1276       (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) {
1277     int error = 0;
1278     socklen_t len = sizeof(error);
1279     return getsockopt(sock, SOL_SOCKET, SO_ERROR,
1280                       reinterpret_cast<char *>(&error), &len) >= 0 &&
1281            !error;
1282   }
1283   return false;
1284 #endif
1285 }
1286 
// Stream implementation declared over a socket handle. Only the declaration
// is visible here; the member functions are defined elsewhere in this header.
class SocketStream : public Stream {
public:
  // Takes the socket and a read timeout given as seconds + microseconds.
  SocketStream(socket_t sock, time_t read_timeout_sec,
               time_t read_timeout_usec);
  ~SocketStream() override;

  // Stream interface.
  bool is_readable() const override;
  bool is_writable() const override;
  ssize_t read(char *ptr, size_t size) override;
  ssize_t write(const char *ptr, size_t size) override;
  void get_remote_ip_and_port(std::string &ip, int &port) const override;

private:
  socket_t sock_;
  // Read timeout, split into seconds and microseconds.
  time_t read_timeout_sec_;
  time_t read_timeout_usec_;
};
1304 
1305 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
// Stream implementation declared over a socket handle plus an OpenSSL `SSL`
// object. Only the declaration is visible here; the member functions are
// defined elsewhere in this header.
class SSLSocketStream : public Stream {
public:
  // Takes the socket, the SSL object and a read timeout given as
  // seconds + microseconds.
  // NOTE(review): ownership of `ssl` is not visible from this declaration.
  SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec,
                  time_t read_timeout_usec);
  ~SSLSocketStream() override;

  // Stream interface.
  bool is_readable() const override;
  bool is_writable() const override;
  ssize_t read(char *ptr, size_t size) override;
  ssize_t write(const char *ptr, size_t size) override;
  void get_remote_ip_and_port(std::string &ip, int &port) const override;

private:
  socket_t sock_;
  SSL *ssl_;
  // Read timeout, split into seconds and microseconds.
  time_t read_timeout_sec_;
  time_t read_timeout_usec_;
};
1324 #endif
1325 
// Stream implementation declared over an in-memory std::string buffer. Only
// the declaration is visible here; the member functions are defined elsewhere
// in this header.
class BufferStream : public Stream {
public:
  BufferStream() = default;
  ~BufferStream() override = default;

  // Stream interface.
  bool is_readable() const override;
  bool is_writable() const override;
  ssize_t read(char *ptr, size_t size) override;
  ssize_t write(const char *ptr, size_t size) override;
  void get_remote_ip_and_port(std::string &ip, int &port) const override;

  // Read-only access to a std::string (presumably the `buffer` member).
  const std::string &get_buffer() const;

private:
  std::string buffer;
  // NOTE(review): presumably the current read offset into `buffer` --
  // confirm against the read() definition.
  size_t position = 0;
};
1343 
1344 template <typename T>
process_socket(bool is_client_request,socket_t sock,size_t keep_alive_max_count,time_t read_timeout_sec,time_t read_timeout_usec,T callback)1345 inline bool process_socket(bool is_client_request, socket_t sock,
1346                            size_t keep_alive_max_count, time_t read_timeout_sec,
1347                            time_t read_timeout_usec, T callback) {
1348   assert(keep_alive_max_count > 0);
1349 
1350   auto ret = false;
1351 
1352   if (keep_alive_max_count > 1) {
1353     auto count = keep_alive_max_count;
1354     while (count > 0 &&
1355            (is_client_request ||
1356             select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
1357                         CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
1358       SocketStream strm(sock, read_timeout_sec, read_timeout_usec);
1359       auto last_connection = count == 1;
1360       auto connection_close = false;
1361 
1362       ret = callback(strm, last_connection, connection_close);
1363       if (!ret || connection_close) { break; }
1364 
1365       count--;
1366     }
1367   } else { // keep_alive_max_count  is 0 or 1
1368     SocketStream strm(sock, read_timeout_sec, read_timeout_usec);
1369     auto dummy_connection_close = false;
1370     ret = callback(strm, true, dummy_connection_close);
1371   }
1372 
1373   return ret;
1374 }
1375 
1376 template <typename T>
process_and_close_socket(bool is_client_request,socket_t sock,size_t keep_alive_max_count,time_t read_timeout_sec,time_t read_timeout_usec,T callback)1377 inline bool process_and_close_socket(bool is_client_request, socket_t sock,
1378                                      size_t keep_alive_max_count,
1379                                      time_t read_timeout_sec,
1380                                      time_t read_timeout_usec, T callback) {
1381   auto ret = process_socket(is_client_request, sock, keep_alive_max_count,
1382                             read_timeout_sec, read_timeout_usec, callback);
1383   close_socket(sock);
1384   return ret;
1385 }
1386 
shutdown_socket(socket_t sock)1387 inline int shutdown_socket(socket_t sock) {
1388 #ifdef _WIN32
1389   return shutdown(sock, SD_BOTH);
1390 #else
1391   return shutdown(sock, SHUT_RDWR);
1392 #endif
1393 }
1394 
1395 template <typename Fn>
1396 socket_t create_socket(const char *host, int port, Fn fn,
1397                        int socket_flags = 0) {
1398 #ifdef _WIN32
1399 #define SO_SYNCHRONOUS_NONALERT 0x20
1400 #define SO_OPENTYPE 0x7008
1401 
1402   int opt = SO_SYNCHRONOUS_NONALERT;
1403   setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char *)&opt,
1404              sizeof(opt));
1405 #endif
1406 
1407   // Get address info
1408   struct addrinfo hints;
1409   struct addrinfo *result;
1410 
1411   memset(&hints, 0, sizeof(struct addrinfo));
1412   hints.ai_family = AF_UNSPEC;
1413   hints.ai_socktype = SOCK_STREAM;
1414   hints.ai_flags = socket_flags;
1415   hints.ai_protocol = 0;
1416 
1417   auto service = std::to_string(port);
1418 
1419   if (getaddrinfo(host, service.c_str(), &hints, &result)) {
1420     return INVALID_SOCKET;
1421   }
1422 
1423   for (auto rp = result; rp; rp = rp->ai_next) {
1424     // Create a socket
1425 #ifdef _WIN32
1426     auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol,
1427                            nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT);
1428     /**
1429      * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1
1430      * and above the socket creation fails on older Windows Systems.
1431      *
1432      * Let's try to create a socket the old way in this case.
1433      *
1434      * Reference:
1435      * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa
1436      *
1437      * WSA_FLAG_NO_HANDLE_INHERIT:
1438      * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with
1439      * SP1, and later
1440      *
1441      */
1442     if (sock == INVALID_SOCKET) {
1443       sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
1444     }
1445 #else
1446     auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
1447 #endif
1448     if (sock == INVALID_SOCKET) { continue; }
1449 
1450 #ifndef _WIN32
1451     if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; }
1452 #endif
1453 
1454     // Make 'reuse address' option available
1455     int yes = 1;
1456     setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char *>(&yes),
1457                sizeof(yes));
1458 #ifdef SO_REUSEPORT
1459     setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast<char *>(&yes),
1460                sizeof(yes));
1461 #endif
1462 
1463     // bind or connect
1464     if (fn(sock, *rp)) {
1465       freeaddrinfo(result);
1466       return sock;
1467     }
1468 
1469     close_socket(sock);
1470   }
1471 
1472   freeaddrinfo(result);
1473   return INVALID_SOCKET;
1474 }
1475 
set_nonblocking(socket_t sock,bool nonblocking)1476 inline void set_nonblocking(socket_t sock, bool nonblocking) {
1477 #ifdef _WIN32
1478   auto flags = nonblocking ? 1UL : 0UL;
1479   ioctlsocket(sock, FIONBIO, &flags);
1480 #else
1481   auto flags = fcntl(sock, F_GETFL, 0);
1482   fcntl(sock, F_SETFL,
1483         nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK)));
1484 #endif
1485 }
1486 
// Returns true unless the last socket error is the platform's "operation in
// progress" code (WSAEWOULDBLOCK on Windows, EINPROGRESS elsewhere).
inline bool is_connection_error() {
#ifdef _WIN32
  const auto code = WSAGetLastError();
  return !(code == WSAEWOULDBLOCK);
#else
  const auto code = errno;
  return !(code == EINPROGRESS);
#endif
}
1494 
bind_ip_address(socket_t sock,const char * host)1495 inline bool bind_ip_address(socket_t sock, const char *host) {
1496   struct addrinfo hints;
1497   struct addrinfo *result;
1498 
1499   memset(&hints, 0, sizeof(struct addrinfo));
1500   hints.ai_family = AF_UNSPEC;
1501   hints.ai_socktype = SOCK_STREAM;
1502   hints.ai_protocol = 0;
1503 
1504   if (getaddrinfo(host, "0", &hints, &result)) { return false; }
1505 
1506   auto ret = false;
1507   for (auto rp = result; rp; rp = rp->ai_next) {
1508     const auto &ai = *rp;
1509     if (!::bind(sock, ai.ai_addr, static_cast<socklen_t>(ai.ai_addrlen))) {
1510       ret = true;
1511       break;
1512     }
1513   }
1514 
1515   freeaddrinfo(result);
1516   return ret;
1517 }
1518 
// Returns the textual IPv4 address of the network interface named `ifn`, or
// an empty string when no such interface with an IPv4 address exists, when
// the interface list cannot be obtained, or on Windows (where this lookup is
// not implemented).
inline std::string if2ip(const std::string &ifn) {
#ifndef _WIN32
  struct ifaddrs *ifap = nullptr;
  // On failure `ifap` is not valid and must be neither walked nor freed.
  if (getifaddrs(&ifap) != 0) { return std::string(); }

  for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) {
    if (ifa->ifa_addr && ifn == ifa->ifa_name) {
      if (ifa->ifa_addr->sa_family == AF_INET) {
        auto sa = reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr);
        char buf[INET_ADDRSTRLEN];
        if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) {
          // Copy up to the terminating NUL only; the rest of `buf` is
          // uninitialized.
          std::string ip(buf);
          freeifaddrs(ifap);
          return ip;
        }
      }
    }
  }
  freeifaddrs(ifap);
#endif
  return std::string();
}
1539 
create_client_socket(const char * host,int port,time_t timeout_sec,const std::string & intf)1540 inline socket_t create_client_socket(const char *host, int port,
1541                                      time_t timeout_sec,
1542                                      const std::string &intf) {
1543   return create_socket(
1544       host, port, [&](socket_t sock, struct addrinfo &ai) -> bool {
1545         if (!intf.empty()) {
1546           auto ip = if2ip(intf);
1547           if (ip.empty()) { ip = intf; }
1548           if (!bind_ip_address(sock, ip.c_str())) { return false; }
1549         }
1550 
1551         set_nonblocking(sock, true);
1552 
1553         auto ret =
1554             ::connect(sock, ai.ai_addr, static_cast<socklen_t>(ai.ai_addrlen));
1555         if (ret < 0) {
1556           if (is_connection_error() ||
1557               !wait_until_socket_is_ready(sock, timeout_sec, 0)) {
1558             close_socket(sock);
1559             return false;
1560           }
1561         }
1562 
1563         set_nonblocking(sock, false);
1564         return true;
1565       });
1566 }
1567 
// Extracts the port (host byte order) and the numeric host string from the
// socket address `addr` of length `addr_len`. `port` is only written for
// AF_INET / AF_INET6 addresses; `ip` is only written when getnameinfo()
// succeeds.
inline void get_remote_ip_and_port(const struct sockaddr_storage &addr,
                                   socklen_t addr_len, std::string &ip,
                                   int &port) {
  switch (addr.ss_family) {
  case AF_INET: {
    const auto v4 = reinterpret_cast<const struct sockaddr_in *>(&addr);
    port = ntohs(v4->sin_port);
    break;
  }
  case AF_INET6: {
    const auto v6 = reinterpret_cast<const struct sockaddr_in6 *>(&addr);
    port = ntohs(v6->sin6_port);
    break;
  }
  default: break;
  }

  std::array<char, NI_MAXHOST> host{};
  const auto rc = getnameinfo(
      reinterpret_cast<const struct sockaddr *>(&addr), addr_len, host.data(),
      static_cast<socklen_t>(host.size()), nullptr, 0, NI_NUMERICHOST);
  if (rc == 0) { ip = host.data(); }
}
1585 
get_remote_ip_and_port(socket_t sock,std::string & ip,int & port)1586 inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) {
1587   struct sockaddr_storage addr;
1588   socklen_t addr_len = sizeof(addr);
1589 
1590   if (!getpeername(sock, reinterpret_cast<struct sockaddr *>(&addr),
1591                    &addr_len)) {
1592     get_remote_ip_and_port(addr, addr_len, ip, port);
1593   }
1594 }
1595 
1596 inline const char *
find_content_type(const std::string & path,const std::map<std::string,std::string> & user_data)1597 find_content_type(const std::string &path,
1598                   const std::map<std::string, std::string> &user_data) {
1599   auto ext = file_extension(path);
1600 
1601   auto it = user_data.find(ext);
1602   if (it != user_data.end()) { return it->second.c_str(); }
1603 
1604   if (ext == "txt") {
1605     return "text/plain";
1606   } else if (ext == "html" || ext == "htm") {
1607     return "text/html";
1608   } else if (ext == "css") {
1609     return "text/css";
1610   } else if (ext == "jpeg" || ext == "jpg") {
1611     return "image/jpg";
1612   } else if (ext == "png") {
1613     return "image/png";
1614   } else if (ext == "gif") {
1615     return "image/gif";
1616   } else if (ext == "svg") {
1617     return "image/svg+xml";
1618   } else if (ext == "ico") {
1619     return "image/x-icon";
1620   } else if (ext == "json") {
1621     return "application/json";
1622   } else if (ext == "pdf") {
1623     return "application/pdf";
1624   } else if (ext == "js") {
1625     return "application/javascript";
1626   } else if (ext == "wasm") {
1627     return "application/wasm";
1628   } else if (ext == "xml") {
1629     return "application/xml";
1630   } else if (ext == "xhtml") {
1631     return "application/xhtml+xml";
1632   }
1633   return nullptr;
1634 }
1635 
// Returns the reason phrase for the HTTP status code `status`.
// Unknown codes map to "Internal Server Error".
inline const char *status_message(int status) {
  switch (status) {
  case 100: return "Continue";
  case 101: return "Switching Protocols";
  case 102: return "Processing";
  case 103: return "Early Hints";
  case 200: return "OK";
  case 201: return "Created";
  case 202: return "Accepted";
  case 203: return "Non-Authoritative Information";
  case 204: return "No Content";
  case 205: return "Reset Content";
  case 206: return "Partial Content";
  case 207: return "Multi-Status";
  case 208: return "Already Reported";
  case 226: return "IM Used";
  case 300: return "Multiple Choices";
  case 301: return "Moved Permanently";
  case 302: return "Found";
  case 303: return "See Other";
  case 304: return "Not Modified";
  case 305: return "Use Proxy";
  case 306: return "unused";
  case 307: return "Temporary Redirect";
  case 308: return "Permanent Redirect";
  case 400: return "Bad Request";
  case 401: return "Unauthorized";
  case 402: return "Payment Required";
  case 403: return "Forbidden";
  case 404: return "Not Found";
  case 405: return "Method Not Allowed";
  case 406: return "Not Acceptable";
  case 407: return "Proxy Authentication Required";
  case 408: return "Request Timeout";
  case 409: return "Conflict";
  case 410: return "Gone";
  case 411: return "Length Required";
  case 412: return "Precondition Failed";
  case 413: return "Payload Too Large";
  case 414: return "URI Too Long";
  case 415: return "Unsupported Media Type";
  case 416: return "Range Not Satisfiable";
  case 417: return "Expectation Failed";
  case 418: return "I'm a teapot";
  case 421: return "Misdirected Request";
  case 422: return "Unprocessable Entity";
  case 423: return "Locked";
  case 424: return "Failed Dependency";
  case 425: return "Too Early";
  case 426: return "Upgrade Required";
  case 428: return "Precondition Required";
  case 429: return "Too Many Requests";
  case 431: return "Request Header Fields Too Large";
  case 451: return "Unavailable For Legal Reasons";
  case 501: return "Not Implemented";
  case 502: return "Bad Gateway";
  case 503: return "Service Unavailable";
  case 504: return "Gateway Timeout";
  case 505: return "HTTP Version Not Supported";
  case 506: return "Variant Also Negotiates";
  case 507: return "Insufficient Storage";
  case 508: return "Loop Detected";
  case 510: return "Not Extended";
  case 511: return "Network Authentication Required";

  // Anything not listed above is reported like a 500.
  default:
  case 500: return "Internal Server Error";
  }
}
1705 
1706 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
// Returns true for content types that are compressed: anything starting with
// "text/" plus a fixed list of textual image/application types.
inline bool can_compress(const std::string &content_type) {
  // rfind(..., 0) can only match at offset 0, i.e. it is a prefix test.
  if (content_type.rfind("text/", 0) == 0) { return true; }

  return content_type == "image/svg+xml" ||
         content_type == "application/javascript" ||
         content_type == "application/json" ||
         content_type == "application/xml" ||
         content_type == "application/xhtml+xml";
}
1714 
compress(std::string & content)1715 inline bool compress(std::string &content) {
1716   z_stream strm;
1717   strm.zalloc = Z_NULL;
1718   strm.zfree = Z_NULL;
1719   strm.opaque = Z_NULL;
1720 
1721   auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8,
1722                           Z_DEFAULT_STRATEGY);
1723   if (ret != Z_OK) { return false; }
1724 
1725   strm.avail_in = static_cast<decltype(strm.avail_in)>(content.size());
1726   strm.next_in =
1727       const_cast<Bytef *>(reinterpret_cast<const Bytef *>(content.data()));
1728 
1729   std::string compressed;
1730 
1731   std::array<char, 16384> buff{};
1732   do {
1733     strm.avail_out = buff.size();
1734     strm.next_out = reinterpret_cast<Bytef *>(buff.data());
1735     ret = deflate(&strm, Z_FINISH);
1736     assert(ret != Z_STREAM_ERROR);
1737     compressed.append(buff.data(), buff.size() - strm.avail_out);
1738   } while (strm.avail_out == 0);
1739 
1740   assert(ret == Z_STREAM_END);
1741   assert(strm.avail_in == 0);
1742 
1743   content.swap(compressed);
1744 
1745   deflateEnd(&strm);
1746   return true;
1747 }
1748 
// Incremental zlib inflater. The constructor initializes the inflate stream,
// decompress() can be fed input chunk by chunk, and the destructor releases
// the stream.
class decompressor {
public:
  decompressor() {
    std::memset(&strm, 0, sizeof(strm));
    strm.zalloc = Z_NULL;
    strm.zfree = Z_NULL;
    strm.opaque = Z_NULL;

    // 15 is the value of wbits, which should be at the maximum possible value
    // to ensure that any gzip stream can be decoded. The offset of 32 specifies
    // that the stream type should be automatically detected either gzip or
    // deflate.
    is_valid_ = inflateInit2(&strm, 32 + 15) == Z_OK;
  }

  ~decompressor() { inflateEnd(&strm); }

  // True when the inflate stream was initialized successfully.
  bool is_valid() const { return is_valid_; }

  // Inflates `data_length` bytes at `data` and passes every produced block to
  // `callback(buffer, length)`. Returns false when zlib reports
  // Z_NEED_DICT / Z_DATA_ERROR / Z_MEM_ERROR or when `callback` returns false;
  // otherwise returns true if the last inflate() result was Z_OK or
  // Z_STREAM_END.
  template <typename T>
  bool decompress(const char *data, size_t data_length, T callback) {
    int ret = Z_OK;

    strm.avail_in = static_cast<decltype(strm.avail_in)>(data_length);
    strm.next_in = const_cast<Bytef *>(reinterpret_cast<const Bytef *>(data));

    std::array<char, 16384> buff{};
    do {
      // Hand zlib a fresh scratch buffer for each round.
      strm.avail_out = buff.size();
      strm.next_out = reinterpret_cast<Bytef *>(buff.data());

      ret = inflate(&strm, Z_NO_FLUSH);
      assert(ret != Z_STREAM_ERROR);
      switch (ret) {
      case Z_NEED_DICT:
      case Z_DATA_ERROR:
      // NOTE(review): inflateEnd() is called here and again by the
      // destructor -- confirm the second call is harmless.
      case Z_MEM_ERROR: inflateEnd(&strm); return false;
      }

      if (!callback(buff.data(), buff.size() - strm.avail_out)) {
        return false;
      }
      // A completely filled buffer means there may be more output pending.
    } while (strm.avail_out == 0);

    return ret == Z_OK || ret == Z_STREAM_END;
  }

private:
  bool is_valid_;
  z_stream strm;
};
1800 #endif
1801 
has_header(const Headers & headers,const char * key)1802 inline bool has_header(const Headers &headers, const char *key) {
1803   return headers.find(key) != headers.end();
1804 }
1805 
1806 inline const char *get_header_value(const Headers &headers, const char *key,
1807                                     size_t id = 0, const char *def = nullptr) {
1808   auto it = headers.find(key);
1809   std::advance(it, static_cast<int>(id));
1810   if (it != headers.end()) { return it->second.c_str(); }
1811   return def;
1812 }
1813 
1814 inline uint64_t get_header_value_uint64(const Headers &headers, const char *key,
1815                                         uint64_t def = 0) {
1816   auto it = headers.find(key);
1817   if (it != headers.end()) {
1818     return std::strtoull(it->second.data(), nullptr, 10);
1819   }
1820   return def;
1821 }
1822 
read_headers(Stream & strm,Headers & headers)1823 inline bool read_headers(Stream &strm, Headers &headers) {
1824   const auto bufsiz = 2048;
1825   char buf[bufsiz];
1826   stream_line_reader line_reader(strm, buf, bufsiz);
1827 
1828   for (;;) {
1829     if (!line_reader.getline()) { return false; }
1830 
1831     // Check if the line ends with CRLF.
1832     if (line_reader.end_with_crlf()) {
1833       // Blank line indicates end of headers.
1834       if (line_reader.size() == 2) { break; }
1835     } else {
1836       continue; // Skip invalid line.
1837     }
1838 
1839     // Skip trailing spaces and tabs.
1840     auto end = line_reader.ptr() + line_reader.size() - 2;
1841     while (line_reader.ptr() < end && (end[-1] == ' ' || end[-1] == '\t')) {
1842       end--;
1843     }
1844 
1845     // Horizontal tab and ' ' are considered whitespace and are ignored when on
1846     // the left or right side of the header value:
1847     //  - https://stackoverflow.com/questions/50179659/
1848     //  - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
1849     static const std::regex re(R"(([^:]+):[\t ]*(.+))");
1850 
1851     std::cmatch m;
1852     if (std::regex_match(line_reader.ptr(), end, m, re)) {
1853       auto key = std::string(m[1]);
1854       auto val = std::string(m[2]);
1855       headers.emplace(key, val);
1856     }
1857   }
1858 
1859   return true;
1860 }
1861 
read_content_with_length(Stream & strm,uint64_t len,Progress progress,ContentReceiver out)1862 inline bool read_content_with_length(Stream &strm, uint64_t len,
1863                                      Progress progress, ContentReceiver out) {
1864   char buf[CPPHTTPLIB_RECV_BUFSIZ];
1865 
1866   uint64_t r = 0;
1867   while (r < len) {
1868     auto read_len = static_cast<size_t>(len - r);
1869     auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ));
1870     if (n <= 0) { return false; }
1871 
1872     if (!out(buf, static_cast<size_t>(n))) { return false; }
1873 
1874     r += static_cast<uint64_t>(n);
1875 
1876     if (progress) {
1877       if (!progress(r, len)) { return false; }
1878     }
1879   }
1880 
1881   return true;
1882 }
1883 
skip_content_with_length(Stream & strm,uint64_t len)1884 inline void skip_content_with_length(Stream &strm, uint64_t len) {
1885   char buf[CPPHTTPLIB_RECV_BUFSIZ];
1886   uint64_t r = 0;
1887   while (r < len) {
1888     auto read_len = static_cast<size_t>(len - r);
1889     auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ));
1890     if (n <= 0) { return; }
1891     r += static_cast<uint64_t>(n);
1892   }
1893 }
1894 
read_content_without_length(Stream & strm,ContentReceiver out)1895 inline bool read_content_without_length(Stream &strm, ContentReceiver out) {
1896   char buf[CPPHTTPLIB_RECV_BUFSIZ];
1897   for (;;) {
1898     auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ);
1899     if (n < 0) {
1900       return false;
1901     } else if (n == 0) {
1902       return true;
1903     }
1904     if (!out(buf, static_cast<size_t>(n))) { return false; }
1905   }
1906 
1907   return true;
1908 }
1909 
read_content_chunked(Stream & strm,ContentReceiver out)1910 inline bool read_content_chunked(Stream &strm, ContentReceiver out) {
1911   const auto bufsiz = 16;
1912   char buf[bufsiz];
1913 
1914   stream_line_reader line_reader(strm, buf, bufsiz);
1915 
1916   if (!line_reader.getline()) { return false; }
1917 
1918   unsigned long chunk_len;
1919   while (true) {
1920     char *end_ptr;
1921 
1922     chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16);
1923 
1924     if (end_ptr == line_reader.ptr()) { return false; }
1925     if (chunk_len == ULONG_MAX) { return false; }
1926 
1927     if (chunk_len == 0) { break; }
1928 
1929     if (!read_content_with_length(strm, chunk_len, nullptr, out)) {
1930       return false;
1931     }
1932 
1933     if (!line_reader.getline()) { return false; }
1934 
1935     if (strcmp(line_reader.ptr(), "\r\n")) { break; }
1936 
1937     if (!line_reader.getline()) { return false; }
1938   }
1939 
1940   if (chunk_len == 0) {
1941     // Reader terminator after chunks
1942     if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n"))
1943       return false;
1944   }
1945 
1946   return true;
1947 }
1948 
is_chunked_transfer_encoding(const Headers & headers)1949 inline bool is_chunked_transfer_encoding(const Headers &headers) {
1950   return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""),
1951                      "chunked");
1952 }
1953 
1954 template <typename T>
read_content(Stream & strm,T & x,size_t payload_max_length,int & status,Progress progress,ContentReceiver receiver)1955 bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
1956                   Progress progress, ContentReceiver receiver) {
1957 
1958   ContentReceiver out = [&](const char *buf, size_t n) {
1959     return receiver(buf, n);
1960   };
1961 
1962 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
1963   decompressor decompressor;
1964 
1965   std::string content_encoding = x.get_header_value("Content-Encoding");
1966   if (content_encoding.find("gzip") != std::string::npos ||
1967       content_encoding.find("deflate") != std::string::npos) {
1968     if (!decompressor.is_valid()) {
1969       status = 500;
1970       return false;
1971     }
1972 
1973     out = [&](const char *buf, size_t n) {
1974       return decompressor.decompress(
1975           buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); });
1976     };
1977   }
1978 #else
1979   if (x.get_header_value("Content-Encoding") == "gzip") {
1980     status = 415;
1981     return false;
1982   }
1983 #endif
1984 
1985   auto ret = true;
1986   auto exceed_payload_max_length = false;
1987 
1988   if (is_chunked_transfer_encoding(x.headers)) {
1989     ret = read_content_chunked(strm, out);
1990   } else if (!has_header(x.headers, "Content-Length")) {
1991     ret = read_content_without_length(strm, out);
1992   } else {
1993     auto len = get_header_value_uint64(x.headers, "Content-Length", 0);
1994     if (len > payload_max_length) {
1995       exceed_payload_max_length = true;
1996       skip_content_with_length(strm, len);
1997       ret = false;
1998     } else if (len > 0) {
1999       ret = read_content_with_length(strm, len, progress, out);
2000     }
2001   }
2002 
2003   if (!ret) { status = exceed_payload_max_length ? 413 : 400; }
2004 
2005   return ret;
2006 }
2007 
2008 template <typename T>
write_headers(Stream & strm,const T & info,const Headers & headers)2009 inline ssize_t write_headers(Stream &strm, const T &info,
2010                              const Headers &headers) {
2011   ssize_t write_len = 0;
2012   for (const auto &x : info.headers) {
2013     if (x.first == "EXCEPTION_WHAT") { continue; }
2014     auto len =
2015         strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str());
2016     if (len < 0) { return len; }
2017     write_len += len;
2018   }
2019   for (const auto &x : headers) {
2020     auto len =
2021         strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str());
2022     if (len < 0) { return len; }
2023     write_len += len;
2024   }
2025   auto len = strm.write("\r\n");
2026   if (len < 0) { return len; }
2027   write_len += len;
2028   return write_len;
2029 }
2030 
write_content(Stream & strm,ContentProvider content_provider,size_t offset,size_t length)2031 inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
2032                              size_t offset, size_t length) {
2033   size_t begin_offset = offset;
2034   size_t end_offset = offset + length;
2035   while (offset < end_offset) {
2036     ssize_t written_length = 0;
2037 
2038     DataSink data_sink;
2039     data_sink.write = [&](const char *d, size_t l) {
2040       offset += l;
2041       written_length = strm.write(d, l);
2042     };
2043     data_sink.done = [&](void) { written_length = -1; };
2044     data_sink.is_writable = [&](void) { return strm.is_writable(); };
2045 
2046     content_provider(offset, end_offset - offset, data_sink);
2047     if (written_length < 0) { return written_length; }
2048   }
2049   return static_cast<ssize_t>(offset - begin_offset);
2050 }
2051 
2052 template <typename T>
write_content_chunked(Stream & strm,ContentProvider content_provider,T is_shutting_down)2053 inline ssize_t write_content_chunked(Stream &strm,
2054                                      ContentProvider content_provider,
2055                                      T is_shutting_down) {
2056   size_t offset = 0;
2057   auto data_available = true;
2058   ssize_t total_written_length = 0;
2059   while (data_available && !is_shutting_down()) {
2060     ssize_t written_length = 0;
2061 
2062     DataSink data_sink;
2063     data_sink.write = [&](const char *d, size_t l) {
2064       data_available = l > 0;
2065       offset += l;
2066 
2067       // Emit chunked response header and footer for each chunk
2068       auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n";
2069       written_length = strm.write(chunk);
2070     };
2071     data_sink.done = [&](void) {
2072       data_available = false;
2073       written_length = strm.write("0\r\n\r\n");
2074     };
2075     data_sink.is_writable = [&](void) { return strm.is_writable(); };
2076 
2077     content_provider(offset, 0, data_sink);
2078 
2079     if (written_length < 0) { return written_length; }
2080     total_written_length += written_length;
2081   }
2082   return total_written_length;
2083 }
2084 
2085 template <typename T>
redirect(T & cli,const Request & req,Response & res,const std::string & path)2086 inline bool redirect(T &cli, const Request &req, Response &res,
2087                      const std::string &path) {
2088   Request new_req = req;
2089   new_req.path = path;
2090   new_req.redirect_count -= 1;
2091 
2092   if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) {
2093     new_req.method = "GET";
2094     new_req.body.clear();
2095     new_req.headers.clear();
2096   }
2097 
2098   Response new_res;
2099 
2100   auto ret = cli.send(new_req, new_res);
2101   if (ret) { res = new_res; }
2102   return ret;
2103 }
2104 
// Percent-encodes the characters of `s` that this library escapes in URLs:
// space, '+', CR, LF, '\'', ',', ';', NUL and every byte >= 0x80 (so UTF-8
// sequences are escaped byte by byte, using upper-case hex digits). All other
// bytes are copied through unchanged.
//
// The whole string is processed, including any embedded NUL bytes; earlier
// revisions stopped at the first NUL and silently dropped the rest.
inline std::string encode_url(const std::string &s) {
  static const char hex_digits[] = "0123456789ABCDEF";

  std::string result;
  result.reserve(s.size());

  for (const char ch : s) {
    switch (ch) {
    case ' ': result += "%20"; break;
    case '+': result += "%2B"; break;
    case '\r': result += "%0D"; break;
    case '\n': result += "%0A"; break;
    case '\'': result += "%27"; break;
    case ',': result += "%2C"; break;
    // case ':': result += "%3A"; break; // ok? probably...
    case ';': result += "%3B"; break;
    case '\0': result += "%00"; break;
    default: {
      const auto c = static_cast<uint8_t>(ch);
      if (c >= 0x80) {
        // Two hex digits per byte: high nibble, then low nibble.
        result += '%';
        result += hex_digits[c >> 4];
        result += hex_digits[c & 0x0F];
      } else {
        result += ch;
      }
      break;
    }
    }
  }

  return result;
}
2135 
decode_url(const std::string & s,bool convert_plus_to_space)2136 inline std::string decode_url(const std::string &s,
2137                               bool convert_plus_to_space) {
2138   std::string result;
2139 
2140   for (size_t i = 0; i < s.size(); i++) {
2141     if (s[i] == '%' && i + 1 < s.size()) {
2142       if (s[i + 1] == 'u') {
2143         int val = 0;
2144         if (from_hex_to_i(s, i + 2, 4, val)) {
2145           // 4 digits Unicode codes
2146           char buff[4];
2147           size_t len = to_utf8(val, buff);
2148           if (len > 0) { result.append(buff, len); }
2149           i += 5; // 'u0000'
2150         } else {
2151           result += s[i];
2152         }
2153       } else {
2154         int val = 0;
2155         if (from_hex_to_i(s, i + 1, 2, val)) {
2156           // 2 digits hex codes
2157           result += static_cast<char>(val);
2158           i += 2; // '00'
2159         } else {
2160           result += s[i];
2161         }
2162       }
2163     } else if (convert_plus_to_space && s[i] == '+') {
2164       result += ' ';
2165     } else {
2166       result += s[i];
2167     }
2168   }
2169 
2170   return result;
2171 }
2172 
params_to_query_str(const Params & params)2173 inline std::string params_to_query_str(const Params &params) {
2174   std::string query;
2175 
2176   for (auto it = params.begin(); it != params.end(); ++it) {
2177     if (it != params.begin()) { query += "&"; }
2178     query += it->first;
2179     query += "=";
2180     query += detail::encode_url(it->second);
2181   }
2182 
2183   return query;
2184 }
2185 
parse_query_text(const std::string & s,Params & params)2186 inline void parse_query_text(const std::string &s, Params &params) {
2187   split(&s[0], &s[s.size()], '&', [&](const char *b, const char *e) {
2188     std::string key;
2189     std::string val;
2190     split(b, e, '=', [&](const char *b2, const char *e2) {
2191       if (key.empty()) {
2192         key.assign(b2, e2);
2193       } else {
2194         val.assign(b2, e2);
2195       }
2196     });
2197     params.emplace(decode_url(key, true), decode_url(val, true));
2198   });
2199 }
2200 
// Extracts the `boundary` parameter from a multipart Content-Type value.
//
// Handles both forms of the parameter:
//   boundary=token            (ends at the next ';' or the end of the value;
//                              trailing spaces/tabs are dropped)
//   boundary="quoted string"  (the surrounding quotes are removed)
// A quoted value without a closing quote is treated like the unquoted form.
//
// Returns true and stores the result in `boundary` when a non-empty boundary
// was found; returns false when the parameter is missing or empty.
inline bool parse_multipart_boundary(const std::string &content_type,
                                     std::string &boundary) {
  static const char param[] = "boundary=";

  const auto pos = content_type.find(param);
  if (pos == std::string::npos) { return false; }

  const auto value_pos = pos + (sizeof(param) - 1);

  // Quoted form: take everything between the quotes.
  if (value_pos < content_type.size() && content_type[value_pos] == '"') {
    const auto closing = content_type.find('"', value_pos + 1);
    if (closing != std::string::npos) {
      boundary = content_type.substr(value_pos + 1, closing - value_pos - 1);
      return !boundary.empty();
    }
  }

  // Unquoted form: an unquoted token cannot contain ';', so anything from the
  // next ';' on belongs to another parameter.
  const auto stop = content_type.find(';', value_pos);
  boundary = content_type.substr(
      value_pos, stop == std::string::npos ? stop : stop - value_pos);

  while (!boundary.empty() &&
         (boundary.back() == ' ' || boundary.back() == '\t')) {
    boundary.pop_back();
  }

  return !boundary.empty();
}
2209 
parse_range_header(const std::string & s,Ranges & ranges)2210 inline bool parse_range_header(const std::string &s, Ranges &ranges) {
2211   static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))");
2212   std::smatch m;
2213   if (std::regex_match(s, m, re_first_range)) {
2214     auto pos = static_cast<size_t>(m.position(1));
2215     auto len = static_cast<size_t>(m.length(1));
2216     bool all_valid_ranges = true;
2217     split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) {
2218       if (!all_valid_ranges) return;
2219       static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))");
2220       std::cmatch cm;
2221       if (std::regex_match(b, e, cm, re_another_range)) {
2222         ssize_t first = -1;
2223         if (!cm.str(1).empty()) {
2224           first = static_cast<ssize_t>(std::stoll(cm.str(1)));
2225         }
2226 
2227         ssize_t last = -1;
2228         if (!cm.str(2).empty()) {
2229           last = static_cast<ssize_t>(std::stoll(cm.str(2)));
2230         }
2231 
2232         if (first != -1 && last != -1 && first > last) {
2233           all_valid_ranges = false;
2234           return;
2235         }
2236         ranges.emplace_back(std::make_pair(first, last));
2237       }
2238     });
2239     return all_valid_ranges;
2240   }
2241   return false;
2242 }
2243 
2244 class MultipartFormDataParser {
2245 public:
2246   MultipartFormDataParser() = default;
2247 
set_boundary(std::string boundary)2248   void set_boundary(std::string boundary) { boundary_ = std::move(boundary); }
2249 
is_valid()2250   bool is_valid() const { return is_valid_; }
2251 
2252   template <typename T, typename U>
parse(const char * buf,size_t n,T content_callback,U header_callback)2253   bool parse(const char *buf, size_t n, T content_callback, U header_callback) {
2254     static const std::regex re_content_type(R"(^Content-Type:\s*(.*?)\s*$)",
2255                                             std::regex_constants::icase);
2256 
2257     static const std::regex re_content_disposition(
2258         "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename="
2259         "\"(.*?)\")?\\s*$",
2260         std::regex_constants::icase);
2261     static const std::string dash_ = "--";
2262     static const std::string crlf_ = "\r\n";
2263 
2264     buf_.append(buf, n); // TODO: performance improvement
2265 
2266     while (!buf_.empty()) {
2267       switch (state_) {
2268       case 0: { // Initial boundary
2269         auto pattern = dash_ + boundary_ + crlf_;
2270         if (pattern.size() > buf_.size()) { return true; }
2271         auto pos = buf_.find(pattern);
2272         if (pos != 0) {
2273           is_done_ = true;
2274           return false;
2275         }
2276         buf_.erase(0, pattern.size());
2277         off_ += pattern.size();
2278         state_ = 1;
2279         break;
2280       }
2281       case 1: { // New entry
2282         clear_file_info();
2283         state_ = 2;
2284         break;
2285       }
2286       case 2: { // Headers
2287         auto pos = buf_.find(crlf_);
2288         while (pos != std::string::npos) {
2289           // Empty line
2290           if (pos == 0) {
2291             if (!header_callback(file_)) {
2292               is_valid_ = false;
2293               is_done_ = false;
2294               return false;
2295             }
2296             buf_.erase(0, crlf_.size());
2297             off_ += crlf_.size();
2298             state_ = 3;
2299             break;
2300           }
2301 
2302           auto header = buf_.substr(0, pos);
2303           {
2304             std::smatch m;
2305             if (std::regex_match(header, m, re_content_type)) {
2306               file_.content_type = m[1];
2307             } else if (std::regex_match(header, m, re_content_disposition)) {
2308               file_.name = m[1];
2309               file_.filename = m[2];
2310             }
2311           }
2312 
2313           buf_.erase(0, pos + crlf_.size());
2314           off_ += pos + crlf_.size();
2315           pos = buf_.find(crlf_);
2316         }
2317         break;
2318       }
2319       case 3: { // Body
2320         {
2321           auto pattern = crlf_ + dash_;
2322           if (pattern.size() > buf_.size()) { return true; }
2323 
2324           auto pos = buf_.find(pattern);
2325           if (pos == std::string::npos) { pos = buf_.size(); }
2326           if (!content_callback(buf_.data(), pos)) {
2327             is_valid_ = false;
2328             is_done_ = false;
2329             return false;
2330           }
2331 
2332           off_ += pos;
2333           buf_.erase(0, pos);
2334         }
2335 
2336         {
2337           auto pattern = crlf_ + dash_ + boundary_;
2338           if (pattern.size() > buf_.size()) { return true; }
2339 
2340           auto pos = buf_.find(pattern);
2341           if (pos != std::string::npos) {
2342             if (!content_callback(buf_.data(), pos)) {
2343               is_valid_ = false;
2344               is_done_ = false;
2345               return false;
2346             }
2347 
2348             off_ += pos + pattern.size();
2349             buf_.erase(0, pos + pattern.size());
2350             state_ = 4;
2351           } else {
2352             if (!content_callback(buf_.data(), pattern.size())) {
2353               is_valid_ = false;
2354               is_done_ = false;
2355               return false;
2356             }
2357 
2358             off_ += pattern.size();
2359             buf_.erase(0, pattern.size());
2360           }
2361         }
2362         break;
2363       }
2364       case 4: { // Boundary
2365         if (crlf_.size() > buf_.size()) { return true; }
2366         if (buf_.find(crlf_) == 0) {
2367           buf_.erase(0, crlf_.size());
2368           off_ += crlf_.size();
2369           state_ = 1;
2370         } else {
2371           auto pattern = dash_ + crlf_;
2372           if (pattern.size() > buf_.size()) { return true; }
2373           if (buf_.find(pattern) == 0) {
2374             buf_.erase(0, pattern.size());
2375             off_ += pattern.size();
2376             is_valid_ = true;
2377             state_ = 5;
2378           } else {
2379             is_done_ = true;
2380             return true;
2381           }
2382         }
2383         break;
2384       }
2385       case 5: { // Done
2386         is_valid_ = false;
2387         return false;
2388       }
2389       }
2390     }
2391 
2392     return true;
2393   }
2394 
2395 private:
clear_file_info()2396   void clear_file_info() {
2397     file_.name.clear();
2398     file_.filename.clear();
2399     file_.content_type.clear();
2400   }
2401 
2402   std::string boundary_;
2403 
2404   std::string buf_;
2405   size_t state_ = 0;
2406   size_t is_valid_ = false;
2407   size_t is_done_ = false;
2408   size_t off_ = 0;
2409   MultipartFormData file_;
2410 };
2411 
// Returns a copy of the characters in [beg, end) with each one converted by
// ::tolower.
inline std::string to_lower(const char *beg, const char *end) {
  std::string out;
  for (auto it = beg; it != end; ++it) {
    // ::tolower requires a value representable as unsigned char (or EOF);
    // passing a negative plain char (any byte >= 0x80 where char is signed)
    // is undefined behavior, so convert first.
    out += static_cast<char>(::tolower(static_cast<unsigned char>(*it)));
  }
  return out;
}
2421 
// Produces a multipart boundary: the fixed prefix
// "--cpp-httplib-multipart-data-" followed by 16 characters drawn from
// [0-9A-Za-z] using a std::mt19937 seeded from std::random_device.
inline std::string make_multipart_data_boundary() {
  static const char alphabet[] =
      "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
  // Exclude the terminating NUL of the literal.
  const size_t alphabet_size = sizeof(alphabet) - 1;

  std::random_device seed_source;
  std::mt19937 rng(seed_source());

  std::string boundary("--cpp-httplib-multipart-data-");

  auto remaining = 16;
  while (remaining-- > 0) {
    boundary.push_back(alphabet[rng() % alphabet_size]);
  }

  return boundary;
}
2437 
2438 inline std::pair<size_t, size_t>
get_range_offset_and_length(const Request & req,size_t content_length,size_t index)2439 get_range_offset_and_length(const Request &req, size_t content_length,
2440                             size_t index) {
2441   auto r = req.ranges[index];
2442 
2443   if (r.first == -1 && r.second == -1) {
2444     return std::make_pair(0, content_length);
2445   }
2446 
2447   auto slen = static_cast<ssize_t>(content_length);
2448 
2449   if (r.first == -1) {
2450     r.first = slen - r.second;
2451     r.second = slen - 1;
2452   }
2453 
2454   if (r.second == -1) { r.second = slen - 1; }
2455 
2456   return std::make_pair(r.first, r.second - r.first + 1);
2457 }
2458 
// Formats a Content-Range value "bytes <first>-<last>/<total>" where
// <first> is `offset`, <last> is `offset + length - 1` and <total> is
// `content_length`.
inline std::string make_content_range_header_field(size_t offset, size_t length,
                                                   size_t content_length) {
  const size_t last = offset + length - 1;
  return "bytes " + std::to_string(offset) + "-" + std::to_string(last) + "/" +
         std::to_string(content_length);
}
2469 
2470 template <typename SToken, typename CToken, typename Content>
process_multipart_ranges_data(const Request & req,Response & res,const std::string & boundary,const std::string & content_type,SToken stoken,CToken ctoken,Content content)2471 bool process_multipart_ranges_data(const Request &req, Response &res,
2472                                    const std::string &boundary,
2473                                    const std::string &content_type,
2474                                    SToken stoken, CToken ctoken,
2475                                    Content content) {
2476   for (size_t i = 0; i < req.ranges.size(); i++) {
2477     ctoken("--");
2478     stoken(boundary);
2479     ctoken("\r\n");
2480     if (!content_type.empty()) {
2481       ctoken("Content-Type: ");
2482       stoken(content_type);
2483       ctoken("\r\n");
2484     }
2485 
2486     auto offsets = get_range_offset_and_length(req, res.body.size(), i);
2487     auto offset = offsets.first;
2488     auto length = offsets.second;
2489 
2490     ctoken("Content-Range: ");
2491     stoken(make_content_range_header_field(offset, length, res.body.size()));
2492     ctoken("\r\n");
2493     ctoken("\r\n");
2494     if (!content(offset, length)) { return false; }
2495     ctoken("\r\n");
2496   }
2497 
2498   ctoken("--");
2499   stoken(boundary);
2500   ctoken("--\r\n");
2501 
2502   return true;
2503 }
2504 
make_multipart_ranges_data(const Request & req,Response & res,const std::string & boundary,const std::string & content_type)2505 inline std::string make_multipart_ranges_data(const Request &req, Response &res,
2506                                               const std::string &boundary,
2507                                               const std::string &content_type) {
2508   std::string data;
2509 
2510   process_multipart_ranges_data(
2511       req, res, boundary, content_type,
2512       [&](const std::string &token) { data += token; },
2513       [&](const char *token) { data += token; },
2514       [&](size_t offset, size_t length) {
2515         data += res.body.substr(offset, length);
2516         return true;
2517       });
2518 
2519   return data;
2520 }
2521 
2522 inline size_t
get_multipart_ranges_data_length(const Request & req,Response & res,const std::string & boundary,const std::string & content_type)2523 get_multipart_ranges_data_length(const Request &req, Response &res,
2524                                  const std::string &boundary,
2525                                  const std::string &content_type) {
2526   size_t data_length = 0;
2527 
2528   process_multipart_ranges_data(
2529       req, res, boundary, content_type,
2530       [&](const std::string &token) { data_length += token.size(); },
2531       [&](const char *token) { data_length += strlen(token); },
2532       [&](size_t /*offset*/, size_t length) {
2533         data_length += length;
2534         return true;
2535       });
2536 
2537   return data_length;
2538 }
2539 
write_multipart_ranges_data(Stream & strm,const Request & req,Response & res,const std::string & boundary,const std::string & content_type)2540 inline bool write_multipart_ranges_data(Stream &strm, const Request &req,
2541                                         Response &res,
2542                                         const std::string &boundary,
2543                                         const std::string &content_type) {
2544   return process_multipart_ranges_data(
2545       req, res, boundary, content_type,
2546       [&](const std::string &token) { strm.write(token); },
2547       [&](const char *token) { strm.write(token); },
2548       [&](size_t offset, size_t length) {
2549         return write_content(strm, res.content_provider, offset, length) >= 0;
2550       });
2551 }
2552 
2553 inline std::pair<size_t, size_t>
get_range_offset_and_length(const Request & req,const Response & res,size_t index)2554 get_range_offset_and_length(const Request &req, const Response &res,
2555                             size_t index) {
2556   auto r = req.ranges[index];
2557 
2558   if (r.second == -1) {
2559     r.second = static_cast<ssize_t>(res.content_length) - 1;
2560   }
2561 
2562   return std::make_pair(r.first, r.second - r.first + 1);
2563 }
2564 
expect_content(const Request & req)2565 inline bool expect_content(const Request &req) {
2566   if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" ||
2567       req.method == "PRI" || req.method == "DELETE") {
2568     return true;
2569   }
2570   // TODO: check if Content-Length is set
2571   return false;
2572 }
2573 
// True when the NUL-terminated string `s` contains a carriage return or a
// line feed character.
inline bool has_crlf(const char *s) {
  for (; *s != '\0'; ++s) {
    const char c = *s;
    if (c == '\r' || c == '\n') { return true; }
  }
  return false;
}
2582 
2583 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
// Runs a three-step digest (init / update / final) over `s` and returns the
// resulting `digest_length` bytes as a lower-case hex string, two digits per
// byte.
//
// `CTX` is the digest context type; the callables are invoked as
// init(&ctx), update(&ctx, s.data(), s.size()) and final(md, &ctx), where
// `md` points to a zero-initialized buffer of `digest_length` bytes.
template <typename CTX, typename Init, typename Update, typename Final>
inline std::string message_digest(const std::string &s, Init init,
                                  Update update, Final final,
                                  size_t digest_length) {
  std::vector<unsigned char> md(digest_length, 0);

  CTX ctx;
  init(&ctx);
  update(&ctx, s.data(), s.size());
  final(md.data(), &ctx);

  static const char hex_digits[] = "0123456789abcdef";

  std::string hex_text;
  hex_text.reserve(md.size() * 2);
  for (const auto byte : md) {
    hex_text += hex_digits[byte >> 4];
    hex_text += hex_digits[byte & 0x0F];
  }
  return hex_text;
}
2602 
// Returns the MD5 digest of `s` as the hex string produced by message_digest,
// driven by MD5_Init / MD5_Update / MD5_Final with MD5_DIGEST_LENGTH bytes.
inline std::string MD5(const std::string &s) {
  return message_digest<MD5_CTX>(s, MD5_Init, MD5_Update, MD5_Final,
                                 MD5_DIGEST_LENGTH);
}
2607 
// Returns the SHA-256 digest of `s` as the hex string produced by
// message_digest, driven by SHA256_Init / SHA256_Update / SHA256_Final with
// SHA256_DIGEST_LENGTH bytes.
inline std::string SHA_256(const std::string &s) {
  return message_digest<SHA256_CTX>(s, SHA256_Init, SHA256_Update, SHA256_Final,
                                    SHA256_DIGEST_LENGTH);
}
2612 
// Returns the SHA-512 digest of `s` as the hex string produced by
// message_digest, driven by SHA512_Init / SHA512_Update / SHA512_Final with
// SHA512_DIGEST_LENGTH bytes.
inline std::string SHA_512(const std::string &s) {
  return message_digest<SHA512_CTX>(s, SHA512_Init, SHA512_Update, SHA512_Final,
                                    SHA512_DIGEST_LENGTH);
}
2617 #endif
2618 
2619 #ifdef _WIN32
// Windows only: ties Winsock start-up and clean-up to the lifetime of an
// object. The constructor calls WSAStartup, the destructor calls WSACleanup.
class WSInit {
public:
  // Starts Winsock, requesting version word 0x0002 (per the WSAStartup
  // documentation the low byte is the major and the high byte the minor
  // version, i.e. 2.0).
  // NOTE(review): the result of WSAStartup is not checked.
  WSInit() {
    WSADATA wsaData;
    WSAStartup(0x0002, &wsaData);
  }

  // Releases Winsock again.
  ~WSInit() { WSACleanup(); }
};

// File-scope instance, so Winsock is started during static initialization
// and cleaned up during static destruction.
static WSInit wsinit_;
2631 #endif
2632 
2633 } // namespace detail
2634 
2635 // Header utilities
make_range_header(Ranges ranges)2636 inline std::pair<std::string, std::string> make_range_header(Ranges ranges) {
2637   std::string field = "bytes=";
2638   auto i = 0;
2639   for (auto r : ranges) {
2640     if (i != 0) { field += ", "; }
2641     if (r.first != -1) { field += std::to_string(r.first); }
2642     field += '-';
2643     if (r.second != -1) { field += std::to_string(r.second); }
2644     i++;
2645   }
2646   return std::make_pair("Range", field);
2647 }
2648 
2649 inline std::pair<std::string, std::string>
2650 make_basic_authentication_header(const std::string &username,
2651                                  const std::string &password,
2652                                  bool is_proxy = false) {
2653   auto field = "Basic " + detail::base64_encode(username + ":" + password);
2654   auto key = is_proxy ? "Proxy-Authorization" : "Authorization";
2655   return std::make_pair(key, field);
2656 }
2657 
2658 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2659 inline std::pair<std::string, std::string> make_digest_authentication_header(
2660     const Request &req, const std::map<std::string, std::string> &auth,
2661     size_t cnonce_count, const std::string &cnonce, const std::string &username,
2662     const std::string &password, bool is_proxy = false) {
2663   using namespace std;
2664 
2665   string nc;
2666   {
2667     stringstream ss;
2668     ss << setfill('0') << setw(8) << hex << cnonce_count;
2669     nc = ss.str();
2670   }
2671 
2672   auto qop = auth.at("qop");
2673   if (qop.find("auth-int") != std::string::npos) {
2674     qop = "auth-int";
2675   } else {
2676     qop = "auth";
2677   }
2678 
2679   std::string algo = "MD5";
2680   if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); }
2681 
2682   string response;
2683   {
2684     auto H = algo == "SHA-256"
2685                  ? detail::SHA_256
2686                  : algo == "SHA-512" ? detail::SHA_512 : detail::MD5;
2687 
2688     auto A1 = username + ":" + auth.at("realm") + ":" + password;
2689 
2690     auto A2 = req.method + ":" + req.path;
2691     if (qop == "auth-int") { A2 += ":" + H(req.body); }
2692 
2693     response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce +
2694                  ":" + qop + ":" + H(A2));
2695   }
2696 
2697   auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") +
2698                "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path +
2699                "\", algorithm=" + algo + ", qop=" + qop + ", nc=\"" + nc +
2700                "\", cnonce=\"" + cnonce + "\", response=\"" + response + "\"";
2701 
2702   auto key = is_proxy ? "Proxy-Authorization" : "Authorization";
2703   return std::make_pair(key, field);
2704 }
2705 #endif
2706 
parse_www_authenticate(const httplib::Response & res,std::map<std::string,std::string> & auth,bool is_proxy)2707 inline bool parse_www_authenticate(const httplib::Response &res,
2708                                    std::map<std::string, std::string> &auth,
2709                                    bool is_proxy) {
2710   auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate";
2711   if (res.has_header(auth_key)) {
2712     static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~");
2713     auto s = res.get_header_value(auth_key);
2714     auto pos = s.find(' ');
2715     if (pos != std::string::npos) {
2716       auto type = s.substr(0, pos);
2717       if (type == "Basic") {
2718         return false;
2719       } else if (type == "Digest") {
2720         s = s.substr(pos + 1);
2721         auto beg = std::sregex_iterator(s.begin(), s.end(), re);
2722         for (auto i = beg; i != std::sregex_iterator(); ++i) {
2723           auto m = *i;
2724           auto key = s.substr(static_cast<size_t>(m.position(1)),
2725                               static_cast<size_t>(m.length(1)));
2726           auto val = m.length(2) > 0
2727                          ? s.substr(static_cast<size_t>(m.position(2)),
2728                                     static_cast<size_t>(m.length(2)))
2729                          : s.substr(static_cast<size_t>(m.position(3)),
2730                                     static_cast<size_t>(m.length(3)));
2731           auth[key] = val;
2732         }
2733         return true;
2734       }
2735     }
2736   }
2737   return false;
2738 }
2739 
2740 // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240
// Returns a string of `length` characters drawn from [0-9A-Za-z].
// Every character is chosen with rand() modulo the alphabet size, so the
// output follows the C library's current rand() sequence.
inline std::string random_string(size_t length) {
  static const char alphabet[] = "0123456789"
                                 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                                 "abcdefghijklmnopqrstuvwxyz";
  const size_t alphabet_size = sizeof(alphabet) - 1; // exclude trailing NUL

  std::string result(length, 0);
  for (auto &ch : result) {
    ch = alphabet[static_cast<size_t>(rand()) % alphabet_size];
  }
  return result;
}
2753 
2754 // Request implementation
has_header(const char * key)2755 inline bool Request::has_header(const char *key) const {
2756   return detail::has_header(headers, key);
2757 }
2758 
get_header_value(const char * key,size_t id)2759 inline std::string Request::get_header_value(const char *key, size_t id) const {
2760   return detail::get_header_value(headers, key, id, "");
2761 }
2762 
get_header_value_count(const char * key)2763 inline size_t Request::get_header_value_count(const char *key) const {
2764   auto r = headers.equal_range(key);
2765   return static_cast<size_t>(std::distance(r.first, r.second));
2766 }
2767 
set_header(const char * key,const char * val)2768 inline void Request::set_header(const char *key, const char *val) {
2769   if (!detail::has_crlf(key) && !detail::has_crlf(val)) {
2770     headers.emplace(key, val);
2771   }
2772 }
2773 
set_header(const char * key,const std::string & val)2774 inline void Request::set_header(const char *key, const std::string &val) {
2775   if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) {
2776     headers.emplace(key, val);
2777   }
2778 }
2779 
has_param(const char * key)2780 inline bool Request::has_param(const char *key) const {
2781   return params.find(key) != params.end();
2782 }
2783 
get_param_value(const char * key,size_t id)2784 inline std::string Request::get_param_value(const char *key, size_t id) const {
2785   auto it = params.find(key);
2786   std::advance(it, static_cast<ssize_t>(id));
2787   if (it != params.end()) { return it->second; }
2788   return std::string();
2789 }
2790 
get_param_value_count(const char * key)2791 inline size_t Request::get_param_value_count(const char *key) const {
2792   auto r = params.equal_range(key);
2793   return static_cast<size_t>(std::distance(r.first, r.second));
2794 }
2795 
is_multipart_form_data()2796 inline bool Request::is_multipart_form_data() const {
2797   const auto &content_type = get_header_value("Content-Type");
2798   return !content_type.find("multipart/form-data");
2799 }
2800 
has_file(const char * key)2801 inline bool Request::has_file(const char *key) const {
2802   return files.find(key) != files.end();
2803 }
2804 
get_file_value(const char * key)2805 inline MultipartFormData Request::get_file_value(const char *key) const {
2806   auto it = files.find(key);
2807   if (it != files.end()) { return it->second; }
2808   return MultipartFormData();
2809 }
2810 
2811 // Response implementation
has_header(const char * key)2812 inline bool Response::has_header(const char *key) const {
2813   return headers.find(key) != headers.end();
2814 }
2815 
get_header_value(const char * key,size_t id)2816 inline std::string Response::get_header_value(const char *key,
2817                                               size_t id) const {
2818   return detail::get_header_value(headers, key, id, "");
2819 }
2820 
get_header_value_count(const char * key)2821 inline size_t Response::get_header_value_count(const char *key) const {
2822   auto r = headers.equal_range(key);
2823   return static_cast<size_t>(std::distance(r.first, r.second));
2824 }
2825 
set_header(const char * key,const char * val)2826 inline void Response::set_header(const char *key, const char *val) {
2827   if (!detail::has_crlf(key) && !detail::has_crlf(val)) {
2828     headers.emplace(key, val);
2829   }
2830 }
2831 
set_header(const char * key,const std::string & val)2832 inline void Response::set_header(const char *key, const std::string &val) {
2833   if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) {
2834     headers.emplace(key, val);
2835   }
2836 }
2837 
set_redirect(const char * url,int status)2838 inline void Response::set_redirect(const char *url, int status) {
2839   if (!detail::has_crlf(url)) {
2840     set_header("Location", url);
2841     if (300 <= status && status < 400) {
2842       this->status = status;
2843     } else {
2844       this->status = 302;
2845     }
2846   }
2847 }
2848 
set_content(const char * s,size_t n,const char * content_type)2849 inline void Response::set_content(const char *s, size_t n,
2850                                   const char *content_type) {
2851   body.assign(s, n);
2852   set_header("Content-Type", content_type);
2853 }
2854 
set_content(std::string s,const char * content_type)2855 inline void Response::set_content(std::string s, const char *content_type) {
2856   body = std::move(s);
2857   set_header("Content-Type", content_type);
2858 }
2859 
set_content_provider(size_t in_length,std::function<void (size_t offset,size_t length,DataSink & sink)> provider,std::function<void ()> resource_releaser)2860 inline void Response::set_content_provider(
2861     size_t in_length,
2862     std::function<void(size_t offset, size_t length, DataSink &sink)> provider,
2863     std::function<void()> resource_releaser) {
2864   assert(in_length > 0);
2865   content_length = in_length;
2866   content_provider = [provider](size_t offset, size_t length, DataSink &sink) {
2867     provider(offset, length, sink);
2868   };
2869   content_provider_resource_releaser = resource_releaser;
2870 }
2871 
set_chunked_content_provider(std::function<void (size_t offset,DataSink & sink)> provider,std::function<void ()> resource_releaser)2872 inline void Response::set_chunked_content_provider(
2873     std::function<void(size_t offset, DataSink &sink)> provider,
2874     std::function<void()> resource_releaser) {
2875   content_length = 0;
2876   content_provider = [provider](size_t offset, size_t, DataSink &sink) {
2877     provider(offset, sink);
2878   };
2879   content_provider_resource_releaser = resource_releaser;
2880 }
2881 
// Stream implementation
write(const char * ptr)2883 inline ssize_t Stream::write(const char *ptr) {
2884   return write(ptr, strlen(ptr));
2885 }
2886 
write(const std::string & s)2887 inline ssize_t Stream::write(const std::string &s) {
2888   return write(s.data(), s.size());
2889 }
2890 
2891 template <typename... Args>
write_format(const char * fmt,const Args &...args)2892 inline ssize_t Stream::write_format(const char *fmt, const Args &... args) {
2893   std::array<char, 2048> buf;
2894 
2895 #if defined(_MSC_VER) && _MSC_VER < 1900
2896   auto sn = _snprintf_s(buf, bufsiz, buf.size() - 1, fmt, args...);
2897 #else
2898   auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...);
2899 #endif
2900   if (sn <= 0) { return sn; }
2901 
2902   auto n = static_cast<size_t>(sn);
2903 
2904   if (n >= buf.size() - 1) {
2905     std::vector<char> glowable_buf(buf.size());
2906 
2907     while (n >= glowable_buf.size() - 1) {
2908       glowable_buf.resize(glowable_buf.size() * 2);
2909 #if defined(_MSC_VER) && _MSC_VER < 1900
2910       n = static_cast<size_t>(_snprintf_s(&glowable_buf[0], glowable_buf.size(),
2911                                           glowable_buf.size() - 1, fmt,
2912                                           args...));
2913 #else
2914       n = static_cast<size_t>(
2915           snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...));
2916 #endif
2917     }
2918     return write(&glowable_buf[0], n);
2919   } else {
2920     return write(buf.data(), n);
2921   }
2922 }
2923 
2924 namespace detail {
2925 
2926 // Socket stream implementation
SocketStream(socket_t sock,time_t read_timeout_sec,time_t read_timeout_usec)2927 inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec,
2928                                   time_t read_timeout_usec)
2929     : sock_(sock), read_timeout_sec_(read_timeout_sec),
2930       read_timeout_usec_(read_timeout_usec) {}
2931 
~SocketStream()2932 inline SocketStream::~SocketStream() {}
2933 
is_readable()2934 inline bool SocketStream::is_readable() const {
2935   return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
2936 }
2937 
is_writable()2938 inline bool SocketStream::is_writable() const {
2939   return select_write(sock_, 0, 0) > 0;
2940 }
2941 
read(char * ptr,size_t size)2942 inline ssize_t SocketStream::read(char *ptr, size_t size) {
2943   if (!is_readable()) { return -1; }
2944 
2945 #ifdef _WIN32
2946   if (size > static_cast<size_t>(std::numeric_limits<int>::max())) {
2947     return -1;
2948   }
2949   return recv(sock_, ptr, static_cast<int>(size), 0);
2950 #else
2951   return recv(sock_, ptr, size, 0);
2952 #endif
2953 }
2954 
write(const char * ptr,size_t size)2955 inline ssize_t SocketStream::write(const char *ptr, size_t size) {
2956   if (!is_writable()) { return -1; }
2957 
2958 #ifdef _WIN32
2959   if (size > static_cast<size_t>(std::numeric_limits<int>::max())) {
2960     return -1;
2961   }
2962   return send(sock_, ptr, static_cast<int>(size), 0);
2963 #else
2964   return send(sock_, ptr, size, 0);
2965 #endif
2966 }
2967 
get_remote_ip_and_port(std::string & ip,int & port)2968 inline void SocketStream::get_remote_ip_and_port(std::string &ip,
2969                                                  int &port) const {
2970   return detail::get_remote_ip_and_port(sock_, ip, port);
2971 }
2972 
2973 // Buffer stream implementation
is_readable()2974 inline bool BufferStream::is_readable() const { return true; }
2975 
is_writable()2976 inline bool BufferStream::is_writable() const { return true; }
2977 
read(char * ptr,size_t size)2978 inline ssize_t BufferStream::read(char *ptr, size_t size) {
2979 #if defined(_MSC_VER) && _MSC_VER < 1900
2980   auto len_read = buffer._Copy_s(ptr, size, size, position);
2981 #else
2982   auto len_read = buffer.copy(ptr, size, position);
2983 #endif
2984   position += static_cast<size_t>(len_read);
2985   return static_cast<ssize_t>(len_read);
2986 }
2987 
write(const char * ptr,size_t size)2988 inline ssize_t BufferStream::write(const char *ptr, size_t size) {
2989   buffer.append(ptr, size);
2990   return static_cast<ssize_t>(size);
2991 }
2992 
get_remote_ip_and_port(std::string &,int &)2993 inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/,
2994                                                  int & /*port*/) const {}
2995 
get_buffer()2996 inline const std::string &BufferStream::get_buffer() const { return buffer; }
2997 
2998 } // namespace detail
2999 
3000 // HTTP server implementation
Server()3001 inline Server::Server()
3002     : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT),
3003       read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND),
3004       read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND),
3005       payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false),
3006       svr_sock_(INVALID_SOCKET) {
3007 #ifndef _WIN32
3008   signal(SIGPIPE, SIG_IGN);
3009 #endif
3010   new_task_queue = [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); };
3011 }
3012 
~Server()3013 inline Server::~Server() {}
3014 
Get(const char * pattern,Handler handler)3015 inline Server &Server::Get(const char *pattern, Handler handler) {
3016   get_handlers_.push_back(std::make_pair(std::regex(pattern), handler));
3017   return *this;
3018 }
3019 
Post(const char * pattern,Handler handler)3020 inline Server &Server::Post(const char *pattern, Handler handler) {
3021   post_handlers_.push_back(std::make_pair(std::regex(pattern), handler));
3022   return *this;
3023 }
3024 
Post(const char * pattern,HandlerWithContentReader handler)3025 inline Server &Server::Post(const char *pattern,
3026                             HandlerWithContentReader handler) {
3027   post_handlers_for_content_reader_.push_back(
3028       std::make_pair(std::regex(pattern), handler));
3029   return *this;
3030 }
3031 
Put(const char * pattern,Handler handler)3032 inline Server &Server::Put(const char *pattern, Handler handler) {
3033   put_handlers_.push_back(std::make_pair(std::regex(pattern), handler));
3034   return *this;
3035 }
3036 
Put(const char * pattern,HandlerWithContentReader handler)3037 inline Server &Server::Put(const char *pattern,
3038                            HandlerWithContentReader handler) {
3039   put_handlers_for_content_reader_.push_back(
3040       std::make_pair(std::regex(pattern), handler));
3041   return *this;
3042 }
3043 
Patch(const char * pattern,Handler handler)3044 inline Server &Server::Patch(const char *pattern, Handler handler) {
3045   patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler));
3046   return *this;
3047 }
3048 
Patch(const char * pattern,HandlerWithContentReader handler)3049 inline Server &Server::Patch(const char *pattern,
3050                              HandlerWithContentReader handler) {
3051   patch_handlers_for_content_reader_.push_back(
3052       std::make_pair(std::regex(pattern), handler));
3053   return *this;
3054 }
3055 
Delete(const char * pattern,Handler handler)3056 inline Server &Server::Delete(const char *pattern, Handler handler) {
3057   delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler));
3058   return *this;
3059 }
3060 
Delete(const char * pattern,HandlerWithContentReader handler)3061 inline Server &Server::Delete(const char *pattern,
3062                               HandlerWithContentReader handler) {
3063   delete_handlers_for_content_reader_.push_back(
3064       std::make_pair(std::regex(pattern), handler));
3065   return *this;
3066 }
3067 
Options(const char * pattern,Handler handler)3068 inline Server &Server::Options(const char *pattern, Handler handler) {
3069   options_handlers_.push_back(std::make_pair(std::regex(pattern), handler));
3070   return *this;
3071 }
3072 
set_base_dir(const char * dir,const char * mount_point)3073 inline bool Server::set_base_dir(const char *dir, const char *mount_point) {
3074   return set_mount_point(mount_point, dir);
3075 }
3076 
set_mount_point(const char * mount_point,const char * dir)3077 inline bool Server::set_mount_point(const char *mount_point, const char *dir) {
3078   if (detail::is_dir(dir)) {
3079     std::string mnt = mount_point ? mount_point : "/";
3080     if (!mnt.empty() && mnt[0] == '/') {
3081       base_dirs_.emplace_back(mnt, dir);
3082       return true;
3083     }
3084   }
3085   return false;
3086 }
3087 
remove_mount_point(const char * mount_point)3088 inline bool Server::remove_mount_point(const char *mount_point) {
3089   for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) {
3090     if (it->first == mount_point) {
3091       base_dirs_.erase(it);
3092       return true;
3093     }
3094   }
3095   return false;
3096 }
3097 
set_file_extension_and_mimetype_mapping(const char * ext,const char * mime)3098 inline void Server::set_file_extension_and_mimetype_mapping(const char *ext,
3099                                                             const char *mime) {
3100   file_extension_and_mimetype_map_[ext] = mime;
3101 }
3102 
set_file_request_handler(Handler handler)3103 inline void Server::set_file_request_handler(Handler handler) {
3104   file_request_handler_ = std::move(handler);
3105 }
3106 
set_error_handler(Handler handler)3107 inline void Server::set_error_handler(Handler handler) {
3108   error_handler_ = std::move(handler);
3109 }
3110 
set_logger(Logger logger)3111 inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); }
3112 
3113 inline void
set_expect_100_continue_handler(Expect100ContinueHandler handler)3114 Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) {
3115   expect_100_continue_handler_ = std::move(handler);
3116 }
3117 
set_keep_alive_max_count(size_t count)3118 inline void Server::set_keep_alive_max_count(size_t count) {
3119   keep_alive_max_count_ = count;
3120 }
3121 
set_read_timeout(time_t sec,time_t usec)3122 inline void Server::set_read_timeout(time_t sec, time_t usec) {
3123   read_timeout_sec_ = sec;
3124   read_timeout_usec_ = usec;
3125 }
3126 
set_payload_max_length(size_t length)3127 inline void Server::set_payload_max_length(size_t length) {
3128   payload_max_length_ = length;
3129 }
3130 
bind_to_port(const char * host,int port,int socket_flags)3131 inline bool Server::bind_to_port(const char *host, int port, int socket_flags) {
3132   if (bind_internal(host, port, socket_flags) < 0) return false;
3133   return true;
3134 }
bind_to_any_port(const char * host,int socket_flags)3135 inline int Server::bind_to_any_port(const char *host, int socket_flags) {
3136   return bind_internal(host, 0, socket_flags);
3137 }
3138 
listen_after_bind()3139 inline bool Server::listen_after_bind() { return listen_internal(); }
3140 
listen(const char * host,int port,int socket_flags)3141 inline bool Server::listen(const char *host, int port, int socket_flags) {
3142   return bind_to_port(host, port, socket_flags) && listen_internal();
3143 }
3144 
is_running()3145 inline bool Server::is_running() const { return is_running_; }
3146 
stop()3147 inline void Server::stop() {
3148   if (is_running_) {
3149     assert(svr_sock_ != INVALID_SOCKET);
3150     std::atomic<socket_t> sock(svr_sock_.exchange(INVALID_SOCKET));
3151     detail::shutdown_socket(sock);
3152     detail::close_socket(sock);
3153   }
3154 }
3155 
parse_request_line(const char * s,Request & req)3156 inline bool Server::parse_request_line(const char *s, Request &req) {
3157   const static std::regex re(
3158       "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) "
3159       "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n");
3160 
3161   std::cmatch m;
3162   if (std::regex_match(s, m, re)) {
3163     req.version = std::string(m[5]);
3164     req.method = std::string(m[1]);
3165     req.target = std::string(m[2]);
3166     req.path = detail::decode_url(m[3], false);
3167 
3168     // Parse query text
3169     auto len = std::distance(m[4].first, m[4].second);
3170     if (len > 0) { detail::parse_query_text(m[4], req.params); }
3171 
3172     return true;
3173   }
3174 
3175   return false;
3176 }
3177 
write_response(Stream & strm,bool last_connection,const Request & req,Response & res)3178 inline bool Server::write_response(Stream &strm, bool last_connection,
3179                                    const Request &req, Response &res) {
3180   assert(res.status != -1);
3181 
3182   if (400 <= res.status && error_handler_) { error_handler_(req, res); }
3183 
3184   detail::BufferStream bstrm;
3185 
3186   // Response line
3187   if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status,
3188                           detail::status_message(res.status))) {
3189     return false;
3190   }
3191 
3192   // Headers
3193   if (last_connection || req.get_header_value("Connection") == "close") {
3194     res.set_header("Connection", "close");
3195   }
3196 
3197   if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") {
3198     res.set_header("Connection", "Keep-Alive");
3199   }
3200 
3201   if (!res.has_header("Content-Type") &&
3202       (!res.body.empty() || res.content_length > 0)) {
3203     res.set_header("Content-Type", "text/plain");
3204   }
3205 
3206   if (!res.has_header("Accept-Ranges") && req.method == "HEAD") {
3207     res.set_header("Accept-Ranges", "bytes");
3208   }
3209 
3210   std::string content_type;
3211   std::string boundary;
3212 
3213   if (req.ranges.size() > 1) {
3214     boundary = detail::make_multipart_data_boundary();
3215 
3216     auto it = res.headers.find("Content-Type");
3217     if (it != res.headers.end()) {
3218       content_type = it->second;
3219       res.headers.erase(it);
3220     }
3221 
3222     res.headers.emplace("Content-Type",
3223                         "multipart/byteranges; boundary=" + boundary);
3224   }
3225 
3226   if (res.body.empty()) {
3227     if (res.content_length > 0) {
3228       size_t length = 0;
3229       if (req.ranges.empty()) {
3230         length = res.content_length;
3231       } else if (req.ranges.size() == 1) {
3232         auto offsets =
3233             detail::get_range_offset_and_length(req, res.content_length, 0);
3234         auto offset = offsets.first;
3235         length = offsets.second;
3236         auto content_range = detail::make_content_range_header_field(
3237             offset, length, res.content_length);
3238         res.set_header("Content-Range", content_range);
3239       } else {
3240         length = detail::get_multipart_ranges_data_length(req, res, boundary,
3241                                                           content_type);
3242       }
3243       res.set_header("Content-Length", std::to_string(length));
3244     } else {
3245       if (res.content_provider) {
3246         res.set_header("Transfer-Encoding", "chunked");
3247       } else {
3248         res.set_header("Content-Length", "0");
3249       }
3250     }
3251   } else {
3252     if (req.ranges.empty()) {
3253       ;
3254     } else if (req.ranges.size() == 1) {
3255       auto offsets =
3256           detail::get_range_offset_and_length(req, res.body.size(), 0);
3257       auto offset = offsets.first;
3258       auto length = offsets.second;
3259       auto content_range = detail::make_content_range_header_field(
3260           offset, length, res.body.size());
3261       res.set_header("Content-Range", content_range);
3262       res.body = res.body.substr(offset, length);
3263     } else {
3264       res.body =
3265           detail::make_multipart_ranges_data(req, res, boundary, content_type);
3266     }
3267 
3268 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
3269     // TODO: 'Accept-Encoding' has gzip, not gzip;q=0
3270     const auto &encodings = req.get_header_value("Accept-Encoding");
3271     if (encodings.find("gzip") != std::string::npos &&
3272         detail::can_compress(res.get_header_value("Content-Type"))) {
3273       if (detail::compress(res.body)) {
3274         res.set_header("Content-Encoding", "gzip");
3275       }
3276     }
3277 #endif
3278 
3279     auto length = std::to_string(res.body.size());
3280     res.set_header("Content-Length", length);
3281   }
3282 
3283   if (!detail::write_headers(bstrm, res, Headers())) { return false; }
3284 
3285   // Flush buffer
3286   auto &data = bstrm.get_buffer();
3287   strm.write(data.data(), data.size());
3288 
3289   // Body
3290   if (req.method != "HEAD") {
3291     if (!res.body.empty()) {
3292       if (!strm.write(res.body)) { return false; }
3293     } else if (res.content_provider) {
3294       if (!write_content_with_provider(strm, req, res, boundary,
3295                                        content_type)) {
3296         return false;
3297       }
3298     }
3299   }
3300 
3301   // Log
3302   if (logger_) { logger_(req, res); }
3303 
3304   return true;
3305 }
3306 
3307 inline bool
write_content_with_provider(Stream & strm,const Request & req,Response & res,const std::string & boundary,const std::string & content_type)3308 Server::write_content_with_provider(Stream &strm, const Request &req,
3309                                     Response &res, const std::string &boundary,
3310                                     const std::string &content_type) {
3311   if (res.content_length) {
3312     if (req.ranges.empty()) {
3313       if (detail::write_content(strm, res.content_provider, 0,
3314                                 res.content_length) < 0) {
3315         return false;
3316       }
3317     } else if (req.ranges.size() == 1) {
3318       auto offsets =
3319           detail::get_range_offset_and_length(req, res.content_length, 0);
3320       auto offset = offsets.first;
3321       auto length = offsets.second;
3322       if (detail::write_content(strm, res.content_provider, offset, length) <
3323           0) {
3324         return false;
3325       }
3326     } else {
3327       if (!detail::write_multipart_ranges_data(strm, req, res, boundary,
3328                                                content_type)) {
3329         return false;
3330       }
3331     }
3332   } else {
3333     auto is_shutting_down = [this]() {
3334       return this->svr_sock_ == INVALID_SOCKET;
3335     };
3336     if (detail::write_content_chunked(strm, res.content_provider,
3337                                       is_shutting_down) < 0) {
3338       return false;
3339     }
3340   }
3341   return true;
3342 }
3343 
read_content(Stream & strm,Request & req,Response & res)3344 inline bool Server::read_content(Stream &strm, Request &req, Response &res) {
3345   MultipartFormDataMap::iterator cur;
3346   if (read_content_core(
3347           strm, req, res,
3348           // Regular
3349           [&](const char *buf, size_t n) {
3350             if (req.body.size() + n > req.body.max_size()) { return false; }
3351             req.body.append(buf, n);
3352             return true;
3353           },
3354           // Multipart
3355           [&](const MultipartFormData &file) {
3356             cur = req.files.emplace(file.name, file);
3357             return true;
3358           },
3359           [&](const char *buf, size_t n) {
3360             auto &content = cur->second.content;
3361             if (content.size() + n > content.max_size()) { return false; }
3362             content.append(buf, n);
3363             return true;
3364           })) {
3365     const auto &content_type = req.get_header_value("Content-Type");
3366     if (!content_type.find("application/x-www-form-urlencoded")) {
3367       detail::parse_query_text(req.body, req.params);
3368     }
3369     return true;
3370   }
3371   return false;
3372 }
3373 
read_content_with_content_receiver(Stream & strm,Request & req,Response & res,ContentReceiver receiver,MultipartContentHeader multipart_header,ContentReceiver multipart_receiver)3374 inline bool Server::read_content_with_content_receiver(
3375     Stream &strm, Request &req, Response &res, ContentReceiver receiver,
3376     MultipartContentHeader multipart_header,
3377     ContentReceiver multipart_receiver) {
3378   return read_content_core(strm, req, res, receiver, multipart_header,
3379                            multipart_receiver);
3380 }
3381 
read_content_core(Stream & strm,Request & req,Response & res,ContentReceiver receiver,MultipartContentHeader mulitpart_header,ContentReceiver multipart_receiver)3382 inline bool Server::read_content_core(Stream &strm, Request &req, Response &res,
3383                                       ContentReceiver receiver,
3384                                       MultipartContentHeader mulitpart_header,
3385                                       ContentReceiver multipart_receiver) {
3386   detail::MultipartFormDataParser multipart_form_data_parser;
3387   ContentReceiver out;
3388 
3389   if (req.is_multipart_form_data()) {
3390     const auto &content_type = req.get_header_value("Content-Type");
3391     std::string boundary;
3392     if (!detail::parse_multipart_boundary(content_type, boundary)) {
3393       res.status = 400;
3394       return false;
3395     }
3396 
3397     multipart_form_data_parser.set_boundary(std::move(boundary));
3398     out = [&](const char *buf, size_t n) {
3399       return multipart_form_data_parser.parse(buf, n, multipart_receiver,
3400                                               mulitpart_header);
3401     };
3402   } else {
3403     out = receiver;
3404   }
3405 
3406   if (!detail::read_content(strm, req, payload_max_length_, res.status,
3407                             Progress(), out)) {
3408     return false;
3409   }
3410 
3411   if (req.is_multipart_form_data()) {
3412     if (!multipart_form_data_parser.is_valid()) {
3413       res.status = 400;
3414       return false;
3415     }
3416   }
3417 
3418   return true;
3419 }
3420 
handle_file_request(Request & req,Response & res,bool head)3421 inline bool Server::handle_file_request(Request &req, Response &res,
3422                                         bool head) {
3423   for (const auto &kv : base_dirs_) {
3424     const auto &mount_point = kv.first;
3425     const auto &base_dir = kv.second;
3426 
3427     // Prefix match
3428     if (!req.path.find(mount_point)) {
3429       std::string sub_path = "/" + req.path.substr(mount_point.size());
3430       if (detail::is_valid_path(sub_path)) {
3431         auto path = base_dir + sub_path;
3432         if (path.back() == '/') { path += "index.html"; }
3433 
3434         if (detail::is_file(path)) {
3435           detail::read_file(path, res.body);
3436           auto type =
3437               detail::find_content_type(path, file_extension_and_mimetype_map_);
3438           if (type) { res.set_header("Content-Type", type); }
3439           res.status = 200;
3440           if (!head && file_request_handler_) {
3441             file_request_handler_(req, res);
3442           }
3443           return true;
3444         }
3445       }
3446     }
3447   }
3448   return false;
3449 }
3450 
create_server_socket(const char * host,int port,int socket_flags)3451 inline socket_t Server::create_server_socket(const char *host, int port,
3452                                              int socket_flags) const {
3453   return detail::create_socket(
3454       host, port,
3455       [](socket_t sock, struct addrinfo &ai) -> bool {
3456         if (::bind(sock, ai.ai_addr, static_cast<socklen_t>(ai.ai_addrlen))) {
3457           return false;
3458         }
3459         if (::listen(sock, 5)) { // Listen through 5 channels
3460           return false;
3461         }
3462         return true;
3463       },
3464       socket_flags);
3465 }
3466 
bind_internal(const char * host,int port,int socket_flags)3467 inline int Server::bind_internal(const char *host, int port, int socket_flags) {
3468   if (!is_valid()) { return -1; }
3469 
3470   svr_sock_ = create_server_socket(host, port, socket_flags);
3471   if (svr_sock_ == INVALID_SOCKET) { return -1; }
3472 
3473   if (port == 0) {
3474     struct sockaddr_storage addr;
3475     socklen_t addr_len = sizeof(addr);
3476     if (getsockname(svr_sock_, reinterpret_cast<struct sockaddr *>(&addr),
3477                     &addr_len) == -1) {
3478       return -1;
3479     }
3480     if (addr.ss_family == AF_INET) {
3481       return ntohs(reinterpret_cast<struct sockaddr_in *>(&addr)->sin_port);
3482     } else if (addr.ss_family == AF_INET6) {
3483       return ntohs(reinterpret_cast<struct sockaddr_in6 *>(&addr)->sin6_port);
3484     } else {
3485       return -1;
3486     }
3487   } else {
3488     return port;
3489   }
3490 }
3491 
// Runs the accept loop on svr_sock_ until the socket is invalidated (see
// 'stop') or accept fails.
//
// is_running_ is set to true for the duration of the loop and reset to
// false before returning. Each accepted connection is enqueued on a task
// queue obtained from new_task_queue(); the queue is shut down once the
// loop ends.
//
// Returns true when the loop ended because svr_sock_ became INVALID_SOCKET,
// false when accept failed while svr_sock_ was still set (the socket is
// closed in that case).
inline bool Server::listen_internal() {
  auto ret = true;
  is_running_ = true;

  {
    // Scoped so the task queue is destroyed before is_running_ is cleared.
    std::unique_ptr<TaskQueue> task_queue(new_task_queue());

    for (;;) {
      if (svr_sock_ == INVALID_SOCKET) {
        // The server socket was closed by 'stop' method.
        break;
      }

      // Wait up to 100000 microseconds so the INVALID_SOCKET check above
      // is re-evaluated regularly.
      auto val = detail::select_read(svr_sock_, 0, 100000);

      if (val == 0) { // Timeout
        task_queue->on_idle();
        continue;
      }

      socket_t sock = accept(svr_sock_, nullptr, nullptr);

      if (sock == INVALID_SOCKET) {
        if (errno == EMFILE) {
          // The per-process limit of open file descriptors has been reached.
          // Try to accept new connections after a short sleep.
          std::this_thread::sleep_for(std::chrono::milliseconds(1));
          continue;
        }
        if (svr_sock_ != INVALID_SOCKET) {
          // accept failed while the server socket was still set: close it
          // and report failure.
          detail::close_socket(svr_sock_);
          ret = false;
        } else {
          ; // The server socket was closed by user.
        }
        break;
      }

      // Hand the connection to the task queue. From C++20 on, `this` is
      // named explicitly in the capture list.
#if __cplusplus > 201703L
      task_queue->enqueue([=, this]() { process_and_close_socket(sock); });
#else
      task_queue->enqueue([=]() { process_and_close_socket(sock); });
#endif
    }

    task_queue->shutdown();
  }

  is_running_ = false;
  return ret;
}
3543 
routing(Request & req,Response & res,Stream & strm)3544 inline bool Server::routing(Request &req, Response &res, Stream &strm) {
3545   // File handler
3546   bool is_head_request = req.method == "HEAD";
3547   if ((req.method == "GET" || is_head_request) &&
3548       handle_file_request(req, res, is_head_request)) {
3549     return true;
3550   }
3551 
3552   if (detail::expect_content(req)) {
3553     // Content reader handler
3554     {
3555       ContentReader reader(
3556           [&](ContentReceiver receiver) {
3557             return read_content_with_content_receiver(strm, req, res, receiver,
3558                                                       nullptr, nullptr);
3559           },
3560           [&](MultipartContentHeader header, ContentReceiver receiver) {
3561             return read_content_with_content_receiver(strm, req, res, nullptr,
3562                                                       header, receiver);
3563           });
3564 
3565       if (req.method == "POST") {
3566         if (dispatch_request_for_content_reader(
3567                 req, res, reader, post_handlers_for_content_reader_)) {
3568           return true;
3569         }
3570       } else if (req.method == "PUT") {
3571         if (dispatch_request_for_content_reader(
3572                 req, res, reader, put_handlers_for_content_reader_)) {
3573           return true;
3574         }
3575       } else if (req.method == "PATCH") {
3576         if (dispatch_request_for_content_reader(
3577                 req, res, reader, patch_handlers_for_content_reader_)) {
3578           return true;
3579         }
3580       } else if (req.method == "DELETE") {
3581         if (dispatch_request_for_content_reader(
3582                 req, res, reader, delete_handlers_for_content_reader_)) {
3583           return true;
3584         }
3585       }
3586     }
3587 
3588     // Read content into `req.body`
3589     if (!read_content(strm, req, res)) { return false; }
3590   }
3591 
3592   // Regular handler
3593   if (req.method == "GET" || req.method == "HEAD") {
3594     return dispatch_request(req, res, get_handlers_);
3595   } else if (req.method == "POST") {
3596     return dispatch_request(req, res, post_handlers_);
3597   } else if (req.method == "PUT") {
3598     return dispatch_request(req, res, put_handlers_);
3599   } else if (req.method == "DELETE") {
3600     return dispatch_request(req, res, delete_handlers_);
3601   } else if (req.method == "OPTIONS") {
3602     return dispatch_request(req, res, options_handlers_);
3603   } else if (req.method == "PATCH") {
3604     return dispatch_request(req, res, patch_handlers_);
3605   }
3606 
3607   res.status = 400;
3608   return false;
3609 }
3610 
dispatch_request(Request & req,Response & res,Handlers & handlers)3611 inline bool Server::dispatch_request(Request &req, Response &res,
3612                                      Handlers &handlers) {
3613 
3614   try {
3615     for (const auto &x : handlers) {
3616       const auto &pattern = x.first;
3617       const auto &handler = x.second;
3618 
3619       if (std::regex_match(req.path, req.matches, pattern)) {
3620         handler(req, res);
3621         return true;
3622       }
3623     }
3624   } catch (const std::exception &ex) {
3625     res.status = 500;
3626     res.set_header("EXCEPTION_WHAT", ex.what());
3627   } catch (...) {
3628     res.status = 500;
3629     res.set_header("EXCEPTION_WHAT", "UNKNOWN");
3630   }
3631   return false;
3632 }
3633 
dispatch_request_for_content_reader(Request & req,Response & res,ContentReader content_reader,HandlersForContentReader & handlers)3634 inline bool Server::dispatch_request_for_content_reader(
3635     Request &req, Response &res, ContentReader content_reader,
3636     HandlersForContentReader &handlers) {
3637   for (const auto &x : handlers) {
3638     const auto &pattern = x.first;
3639     const auto &handler = x.second;
3640 
3641     if (std::regex_match(req.path, req.matches, pattern)) {
3642       handler(req, res, content_reader);
3643       return true;
3644     }
3645   }
3646   return false;
3647 }
3648 
3649 inline bool
process_request(Stream & strm,bool last_connection,bool & connection_close,const std::function<void (Request &)> & setup_request)3650 Server::process_request(Stream &strm, bool last_connection,
3651                         bool &connection_close,
3652                         const std::function<void(Request &)> &setup_request) {
3653   std::array<char, 2048> buf{};
3654 
3655   detail::stream_line_reader line_reader(strm, buf.data(), buf.size());
3656 
3657   // Connection has been closed on client
3658   if (!line_reader.getline()) { return false; }
3659 
3660   Request req;
3661   Response res;
3662 
3663   res.version = "HTTP/1.1";
3664 
3665   // Check if the request URI doesn't exceed the limit
3666   if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
3667     Headers dummy;
3668     detail::read_headers(strm, dummy);
3669     res.status = 414;
3670     return write_response(strm, last_connection, req, res);
3671   }
3672 
3673   // Request line and headers
3674   if (!parse_request_line(line_reader.ptr(), req) ||
3675       !detail::read_headers(strm, req.headers)) {
3676     res.status = 400;
3677     return write_response(strm, last_connection, req, res);
3678   }
3679 
3680   if (req.get_header_value("Connection") == "close") {
3681     connection_close = true;
3682   }
3683 
3684   if (req.version == "HTTP/1.0" &&
3685       req.get_header_value("Connection") != "Keep-Alive") {
3686     connection_close = true;
3687   }
3688 
3689   strm.get_remote_ip_and_port(req.remote_addr, req.remote_port);
3690   req.set_header("REMOTE_ADDR", req.remote_addr);
3691   req.set_header("REMOTE_PORT", std::to_string(req.remote_port));
3692 
3693   if (req.has_header("Range")) {
3694     const auto &range_header_value = req.get_header_value("Range");
3695     if (!detail::parse_range_header(range_header_value, req.ranges)) {
3696       // TODO: error
3697     }
3698   }
3699 
3700   if (setup_request) { setup_request(req); }
3701 
3702   if (req.get_header_value("Expect") == "100-continue") {
3703     auto status = 100;
3704     if (expect_100_continue_handler_) {
3705       status = expect_100_continue_handler_(req, res);
3706     }
3707     switch (status) {
3708     case 100:
3709     case 417:
3710       strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status,
3711                         detail::status_message(status));
3712       break;
3713     default: return write_response(strm, last_connection, req, res);
3714     }
3715   }
3716 
3717   // Rounting
3718   if (routing(req, res, strm)) {
3719     if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; }
3720   } else {
3721     if (res.status == -1) { res.status = 404; }
3722   }
3723 
3724   return write_response(strm, last_connection, req, res);
3725 }
3726 
is_valid()3727 inline bool Server::is_valid() const { return true; }
3728 
process_and_close_socket(socket_t sock)3729 inline bool Server::process_and_close_socket(socket_t sock) {
3730   return detail::process_and_close_socket(
3731       false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
3732       [this](Stream &strm, bool last_connection, bool &connection_close) {
3733         return process_request(strm, last_connection, connection_close,
3734                                nullptr);
3735       });
3736 }
3737 
3738 // HTTP client implementation
Client(const std::string & host,int port,const std::string & client_cert_path,const std::string & client_key_path)3739 inline Client::Client(const std::string &host, int port,
3740                       const std::string &client_cert_path,
3741                       const std::string &client_key_path)
3742     : sock_(INVALID_SOCKET), host_(host), port_(port),
3743       host_and_port_(host_ + ":" + std::to_string(port_)),
3744       client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
3745 
~Client()3746 inline Client::~Client() {}
3747 
is_valid()3748 inline bool Client::is_valid() const { return true; }
3749 
create_client_socket()3750 inline socket_t Client::create_client_socket() const {
3751   if (!proxy_host_.empty()) {
3752     return detail::create_client_socket(proxy_host_.c_str(), proxy_port_,
3753                                         timeout_sec_, interface_);
3754   }
3755   return detail::create_client_socket(host_.c_str(), port_, timeout_sec_,
3756                                       interface_);
3757 }
3758 
read_response_line(Stream & strm,Response & res)3759 inline bool Client::read_response_line(Stream &strm, Response &res) {
3760   std::array<char, 2048> buf;
3761 
3762   detail::stream_line_reader line_reader(strm, buf.data(), buf.size());
3763 
3764   if (!line_reader.getline()) { return false; }
3765 
3766   const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n");
3767 
3768   std::cmatch m;
3769   if (std::regex_match(line_reader.ptr(), m, re)) {
3770     res.version = std::string(m[1]);
3771     res.status = std::stoi(std::string(m[2]));
3772   }
3773 
3774   return true;
3775 }
3776 
send(const Request & req,Response & res)3777 inline bool Client::send(const Request &req, Response &res) {
3778   sock_ = create_client_socket();
3779   if (sock_ == INVALID_SOCKET) { return false; }
3780 
3781 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
3782   if (is_ssl() && !proxy_host_.empty()) {
3783     bool error;
3784     if (!connect(sock_, res, error)) { return error; }
3785   }
3786 #endif
3787 
3788   return process_and_close_socket(
3789       sock_, 1,
3790       [&](Stream &strm, bool last_connection, bool &connection_close) {
3791         return handle_request(strm, req, res, last_connection,
3792                               connection_close);
3793       });
3794 }
3795 
send(const std::vector<Request> & requests,std::vector<Response> & responses)3796 inline bool Client::send(const std::vector<Request> &requests,
3797                          std::vector<Response> &responses) {
3798   size_t i = 0;
3799   while (i < requests.size()) {
3800     sock_ = create_client_socket();
3801     if (sock_ == INVALID_SOCKET) { return false; }
3802 
3803 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
3804     if (is_ssl() && !proxy_host_.empty()) {
3805       Response res;
3806       bool error;
3807       if (!connect(sock_, res, error)) { return false; }
3808     }
3809 #endif
3810 
3811     if (!process_and_close_socket(sock_, requests.size() - i,
3812                                   [&](Stream &strm, bool last_connection,
3813                                       bool &connection_close) -> bool {
3814                                     auto &req = requests[i++];
3815                                     auto res = Response();
3816                                     auto ret = handle_request(strm, req, res,
3817                                                               last_connection,
3818                                                               connection_close);
3819                                     if (ret) {
3820                                       responses.emplace_back(std::move(res));
3821                                     }
3822                                     return ret;
3823                                   })) {
3824       return false;
3825     }
3826   }
3827 
3828   return true;
3829 }
3830 
handle_request(Stream & strm,const Request & req,Response & res,bool last_connection,bool & connection_close)3831 inline bool Client::handle_request(Stream &strm, const Request &req,
3832                                    Response &res, bool last_connection,
3833                                    bool &connection_close) {
3834   if (req.path.empty()) { return false; }
3835 
3836   bool ret;
3837 
3838   if (!is_ssl() && !proxy_host_.empty()) {
3839     auto req2 = req;
3840     req2.path = "http://" + host_and_port_ + req.path;
3841     ret = process_request(strm, req2, res, last_connection, connection_close);
3842   } else {
3843     ret = process_request(strm, req, res, last_connection, connection_close);
3844   }
3845 
3846   if (!ret) { return false; }
3847 
3848   if (300 < res.status && res.status < 400 && follow_location_) {
3849     ret = redirect(req, res);
3850   }
3851 
3852 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
3853   if (res.status == 401 || res.status == 407) {
3854     auto is_proxy = res.status == 407;
3855     const auto &username =
3856         is_proxy ? proxy_digest_auth_username_ : digest_auth_username_;
3857     const auto &password =
3858         is_proxy ? proxy_digest_auth_password_ : digest_auth_password_;
3859 
3860     if (!username.empty() && !password.empty()) {
3861       std::map<std::string, std::string> auth;
3862       if (parse_www_authenticate(res, auth, is_proxy)) {
3863         Request new_req = req;
3864         auto key = is_proxy ? "Proxy-Authorization" : "WWW-Authorization";
3865         new_req.headers.erase(key);
3866         new_req.headers.insert(make_digest_authentication_header(
3867             req, auth, 1, random_string(10), username, password, is_proxy));
3868 
3869         Response new_res;
3870 
3871         ret = send(new_req, new_res);
3872         if (ret) { res = new_res; }
3873       }
3874     }
3875   }
3876 #endif
3877 
3878   return ret;
3879 }
3880 
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
// Issues a CONNECT request for `host_and_port_` on `sock`, retrying once with
// a proxy digest authentication header when the answer is 407 and proxy
// digest credentials are configured.
//
// Returns true when the caller may go on using `sock`. On a false return the
// socket has been closed, and `error` holds the value the single-request
// `send` returns to its caller:
//   - false: the exchange on the socket failed;
//   - true:  the 407 response was copied into `res` because no proxy digest
//            credentials are configured.
inline bool Client::connect(socket_t sock, Response &res, bool &error) {
  error = true;
  Response res2;

  if (!detail::process_socket(
          true, sock, 1, read_timeout_sec_, read_timeout_usec_,
          [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
            Request req2;
            req2.method = "CONNECT";
            req2.path = host_and_port_;
            return process_request(strm, req2, res2, false, connection_close);
          })) {
    detail::close_socket(sock);
    error = false;
    return false;
  }

  if (res2.status == 407) {
    if (!proxy_digest_auth_username_.empty() &&
        !proxy_digest_auth_password_.empty()) {
      std::map<std::string, std::string> auth;
      if (parse_www_authenticate(res2, auth, true)) {
        Response res3;
        if (!detail::process_socket(
                true, sock, 1, read_timeout_sec_, read_timeout_usec_,
                [&](Stream &strm, bool /*last_connection*/,
                    bool &connection_close) {
                  Request req3;
                  req3.method = "CONNECT";
                  req3.path = host_and_port_;
                  req3.headers.insert(make_digest_authentication_header(
                      req3, auth, 1, random_string(10),
                      proxy_digest_auth_username_, proxy_digest_auth_password_,
                      true));
                  return process_request(strm, req3, res3, false,
                                         connection_close);
                })) {
          detail::close_socket(sock);
          error = false;
          return false;
        }
      }
    } else {
      // Both `send` overloads return right after a failed connect() without
      // touching the socket again, so it has to be closed here; otherwise it
      // is leaked on every unauthenticated 407.
      detail::close_socket(sock);
      res = res2;
      return false;
    }
  }

  return true;
}
#endif
3933 
redirect(const Request & req,Response & res)3934 inline bool Client::redirect(const Request &req, Response &res) {
3935   if (req.redirect_count == 0) { return false; }
3936 
3937   auto location = res.get_header_value("location");
3938   if (location.empty()) { return false; }
3939 
3940   const static std::regex re(
3941       R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)");
3942 
3943   std::smatch m;
3944   if (!std::regex_match(location, m, re)) { return false; }
3945 
3946   auto scheme = is_ssl() ? "https" : "http";
3947 
3948   auto next_scheme = m[1].str();
3949   auto next_host = m[2].str();
3950   auto port_str = m[3].str();
3951   auto next_path = m[4].str();
3952 
3953   auto next_port = port_;
3954   if (!port_str.empty()) {
3955     next_port = std::stoi(port_str);
3956   } else if (!next_scheme.empty()) {
3957     next_port = next_scheme == "https" ? 443 : 80;
3958   }
3959 
3960   if (next_scheme.empty()) { next_scheme = scheme; }
3961   if (next_host.empty()) { next_host = host_; }
3962   if (next_path.empty()) { next_path = "/"; }
3963 
3964   if (next_scheme == scheme && next_host == host_ && next_port == port_) {
3965     return detail::redirect(*this, req, res, next_path);
3966   } else {
3967     if (next_scheme == "https") {
3968 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
3969       SSLClient cli(next_host.c_str(), next_port);
3970       cli.copy_settings(*this);
3971       return detail::redirect(cli, req, res, next_path);
3972 #else
3973       return false;
3974 #endif
3975     } else {
3976       Client cli(next_host.c_str(), next_port);
3977       cli.copy_settings(*this);
3978       return detail::redirect(cli, req, res, next_path);
3979     }
3980   }
3981 }
3982 
write_request(Stream & strm,const Request & req,bool last_connection)3983 inline bool Client::write_request(Stream &strm, const Request &req,
3984                                   bool last_connection) {
3985   detail::BufferStream bstrm;
3986 
3987   // Request line
3988   const auto &path = detail::encode_url(req.path);
3989 
3990   bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str());
3991 
3992   // Additonal headers
3993   Headers headers;
3994   if (last_connection) { headers.emplace("Connection", "close"); }
3995 
3996   if (!req.has_header("Host")) {
3997     if (is_ssl()) {
3998       if (port_ == 443) {
3999         headers.emplace("Host", host_);
4000       } else {
4001         headers.emplace("Host", host_and_port_);
4002       }
4003     } else {
4004       if (port_ == 80) {
4005         headers.emplace("Host", host_);
4006       } else {
4007         headers.emplace("Host", host_and_port_);
4008       }
4009     }
4010   }
4011 
4012   if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); }
4013 
4014   if (!req.has_header("User-Agent")) {
4015     headers.emplace("User-Agent", "cpp-httplib/0.5");
4016   }
4017 
4018   if (req.body.empty()) {
4019     if (req.content_provider) {
4020       auto length = std::to_string(req.content_length);
4021       headers.emplace("Content-Length", length);
4022     } else {
4023       headers.emplace("Content-Length", "0");
4024     }
4025   } else {
4026     if (!req.has_header("Content-Type")) {
4027       headers.emplace("Content-Type", "text/plain");
4028     }
4029 
4030     if (!req.has_header("Content-Length")) {
4031       auto length = std::to_string(req.body.size());
4032       headers.emplace("Content-Length", length);
4033     }
4034   }
4035 
4036   if (!basic_auth_username_.empty() && !basic_auth_password_.empty()) {
4037     headers.insert(make_basic_authentication_header(
4038         basic_auth_username_, basic_auth_password_, false));
4039   }
4040 
4041   if (!proxy_basic_auth_username_.empty() &&
4042       !proxy_basic_auth_password_.empty()) {
4043     headers.insert(make_basic_authentication_header(
4044         proxy_basic_auth_username_, proxy_basic_auth_password_, true));
4045   }
4046 
4047   detail::write_headers(bstrm, req, headers);
4048 
4049   // Flush buffer
4050   auto &data = bstrm.get_buffer();
4051   strm.write(data.data(), data.size());
4052 
4053   // Body
4054   if (req.body.empty()) {
4055     if (req.content_provider) {
4056       size_t offset = 0;
4057       size_t end_offset = req.content_length;
4058 
4059       DataSink data_sink;
4060       data_sink.write = [&](const char *d, size_t l) {
4061         auto written_length = strm.write(d, l);
4062         offset += static_cast<size_t>(written_length);
4063       };
4064       data_sink.is_writable = [&](void) { return strm.is_writable(); };
4065 
4066       while (offset < end_offset) {
4067         req.content_provider(offset, end_offset - offset, data_sink);
4068       }
4069     }
4070   } else {
4071     strm.write(req.body);
4072   }
4073 
4074   return true;
4075 }
4076 
send_with_content_provider(const char * method,const char * path,const Headers & headers,const std::string & body,size_t content_length,ContentProvider content_provider,const char * content_type)4077 inline std::shared_ptr<Response> Client::send_with_content_provider(
4078     const char *method, const char *path, const Headers &headers,
4079     const std::string &body, size_t content_length,
4080     ContentProvider content_provider, const char *content_type) {
4081   Request req;
4082   req.method = method;
4083   req.headers = headers;
4084   req.path = path;
4085 
4086   if (content_type) { req.headers.emplace("Content-Type", content_type); }
4087 
4088 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
4089   if (compress_) {
4090     if (content_provider) {
4091       size_t offset = 0;
4092 
4093       DataSink data_sink;
4094       data_sink.write = [&](const char *data, size_t data_len) {
4095         req.body.append(data, data_len);
4096         offset += data_len;
4097       };
4098       data_sink.is_writable = [&](void) { return true; };
4099 
4100       while (offset < content_length) {
4101         content_provider(offset, content_length - offset, data_sink);
4102       }
4103     } else {
4104       req.body = body;
4105     }
4106 
4107     if (!detail::compress(req.body)) { return nullptr; }
4108     req.headers.emplace("Content-Encoding", "gzip");
4109   } else
4110 #endif
4111   {
4112     if (content_provider) {
4113       req.content_length = content_length;
4114       req.content_provider = content_provider;
4115     } else {
4116       req.body = body;
4117     }
4118   }
4119 
4120   auto res = std::make_shared<Response>();
4121 
4122   return send(req, *res) ? res : nullptr;
4123 }
4124 
process_request(Stream & strm,const Request & req,Response & res,bool last_connection,bool & connection_close)4125 inline bool Client::process_request(Stream &strm, const Request &req,
4126                                     Response &res, bool last_connection,
4127                                     bool &connection_close) {
4128   // Send request
4129   if (!write_request(strm, req, last_connection)) { return false; }
4130 
4131   // Receive response and headers
4132   if (!read_response_line(strm, res) ||
4133       !detail::read_headers(strm, res.headers)) {
4134     return false;
4135   }
4136 
4137   if (res.get_header_value("Connection") == "close" ||
4138       res.version == "HTTP/1.0") {
4139     connection_close = true;
4140   }
4141 
4142   if (req.response_handler) {
4143     if (!req.response_handler(res)) { return false; }
4144   }
4145 
4146   // Body
4147   if (req.method != "HEAD" && req.method != "CONNECT") {
4148     auto out =
4149         req.content_receiver
4150             ? static_cast<ContentReceiver>([&](const char *buf, size_t n) {
4151                 return req.content_receiver(buf, n);
4152               })
4153             : static_cast<ContentReceiver>([&](const char *buf, size_t n) {
4154                 if (res.body.size() + n > res.body.max_size()) { return false; }
4155                 res.body.append(buf, n);
4156                 return true;
4157               });
4158 
4159     int dummy_status;
4160     if (!detail::read_content(strm, res, (std::numeric_limits<size_t>::max)(),
4161                               dummy_status, req.progress, out)) {
4162       return false;
4163     }
4164   }
4165 
4166   // Log
4167   if (logger_) { logger_(req, res); }
4168 
4169   return true;
4170 }
4171 
process_and_close_socket(socket_t sock,size_t request_count,std::function<bool (Stream & strm,bool last_connection,bool & connection_close)> callback)4172 inline bool Client::process_and_close_socket(
4173     socket_t sock, size_t request_count,
4174     std::function<bool(Stream &strm, bool last_connection,
4175                        bool &connection_close)>
4176         callback) {
4177   request_count = (std::min)(request_count, keep_alive_max_count_);
4178   return detail::process_and_close_socket(true, sock, request_count,
4179                                           read_timeout_sec_, read_timeout_usec_,
4180                                           callback);
4181 }
4182 
is_ssl()4183 inline bool Client::is_ssl() const { return false; }
4184 
Get(const char * path)4185 inline std::shared_ptr<Response> Client::Get(const char *path) {
4186   return Get(path, Headers(), Progress());
4187 }
4188 
Get(const char * path,Progress progress)4189 inline std::shared_ptr<Response> Client::Get(const char *path,
4190                                              Progress progress) {
4191   return Get(path, Headers(), std::move(progress));
4192 }
4193 
Get(const char * path,const Headers & headers)4194 inline std::shared_ptr<Response> Client::Get(const char *path,
4195                                              const Headers &headers) {
4196   return Get(path, headers, Progress());
4197 }
4198 
4199 inline std::shared_ptr<Response>
Get(const char * path,const Headers & headers,Progress progress)4200 Client::Get(const char *path, const Headers &headers, Progress progress) {
4201   Request req;
4202   req.method = "GET";
4203   req.path = path;
4204   req.headers = headers;
4205   req.progress = std::move(progress);
4206 
4207   auto res = std::make_shared<Response>();
4208   return send(req, *res) ? res : nullptr;
4209 }
4210 
Get(const char * path,ContentReceiver content_receiver)4211 inline std::shared_ptr<Response> Client::Get(const char *path,
4212                                              ContentReceiver content_receiver) {
4213   return Get(path, Headers(), nullptr, std::move(content_receiver), Progress());
4214 }
4215 
Get(const char * path,ContentReceiver content_receiver,Progress progress)4216 inline std::shared_ptr<Response> Client::Get(const char *path,
4217                                              ContentReceiver content_receiver,
4218                                              Progress progress) {
4219   return Get(path, Headers(), nullptr, std::move(content_receiver),
4220              std::move(progress));
4221 }
4222 
Get(const char * path,const Headers & headers,ContentReceiver content_receiver)4223 inline std::shared_ptr<Response> Client::Get(const char *path,
4224                                              const Headers &headers,
4225                                              ContentReceiver content_receiver) {
4226   return Get(path, headers, nullptr, std::move(content_receiver), Progress());
4227 }
4228 
Get(const char * path,const Headers & headers,ContentReceiver content_receiver,Progress progress)4229 inline std::shared_ptr<Response> Client::Get(const char *path,
4230                                              const Headers &headers,
4231                                              ContentReceiver content_receiver,
4232                                              Progress progress) {
4233   return Get(path, headers, nullptr, std::move(content_receiver),
4234              std::move(progress));
4235 }
4236 
Get(const char * path,const Headers & headers,ResponseHandler response_handler,ContentReceiver content_receiver)4237 inline std::shared_ptr<Response> Client::Get(const char *path,
4238                                              const Headers &headers,
4239                                              ResponseHandler response_handler,
4240                                              ContentReceiver content_receiver) {
4241   return Get(path, headers, std::move(response_handler), content_receiver,
4242              Progress());
4243 }
4244 
Get(const char * path,const Headers & headers,ResponseHandler response_handler,ContentReceiver content_receiver,Progress progress)4245 inline std::shared_ptr<Response> Client::Get(const char *path,
4246                                              const Headers &headers,
4247                                              ResponseHandler response_handler,
4248                                              ContentReceiver content_receiver,
4249                                              Progress progress) {
4250   Request req;
4251   req.method = "GET";
4252   req.path = path;
4253   req.headers = headers;
4254   req.response_handler = std::move(response_handler);
4255   req.content_receiver = std::move(content_receiver);
4256   req.progress = std::move(progress);
4257 
4258   auto res = std::make_shared<Response>();
4259   return send(req, *res) ? res : nullptr;
4260 }
4261 
Head(const char * path)4262 inline std::shared_ptr<Response> Client::Head(const char *path) {
4263   return Head(path, Headers());
4264 }
4265 
Head(const char * path,const Headers & headers)4266 inline std::shared_ptr<Response> Client::Head(const char *path,
4267                                               const Headers &headers) {
4268   Request req;
4269   req.method = "HEAD";
4270   req.headers = headers;
4271   req.path = path;
4272 
4273   auto res = std::make_shared<Response>();
4274 
4275   return send(req, *res) ? res : nullptr;
4276 }
4277 
Post(const char * path)4278 inline std::shared_ptr<Response> Client::Post(const char *path) {
4279   return Post(path, std::string(), nullptr);
4280 }
4281 
Post(const char * path,const std::string & body,const char * content_type)4282 inline std::shared_ptr<Response> Client::Post(const char *path,
4283                                               const std::string &body,
4284                                               const char *content_type) {
4285   return Post(path, Headers(), body, content_type);
4286 }
4287 
Post(const char * path,const Headers & headers,const std::string & body,const char * content_type)4288 inline std::shared_ptr<Response> Client::Post(const char *path,
4289                                               const Headers &headers,
4290                                               const std::string &body,
4291                                               const char *content_type) {
4292   return send_with_content_provider("POST", path, headers, body, 0, nullptr,
4293                                     content_type);
4294 }
4295 
Post(const char * path,const Params & params)4296 inline std::shared_ptr<Response> Client::Post(const char *path,
4297                                               const Params &params) {
4298   return Post(path, Headers(), params);
4299 }
4300 
Post(const char * path,size_t content_length,ContentProvider content_provider,const char * content_type)4301 inline std::shared_ptr<Response> Client::Post(const char *path,
4302                                               size_t content_length,
4303                                               ContentProvider content_provider,
4304                                               const char *content_type) {
4305   return Post(path, Headers(), content_length, content_provider, content_type);
4306 }
4307 
4308 inline std::shared_ptr<Response>
Post(const char * path,const Headers & headers,size_t content_length,ContentProvider content_provider,const char * content_type)4309 Client::Post(const char *path, const Headers &headers, size_t content_length,
4310              ContentProvider content_provider, const char *content_type) {
4311   return send_with_content_provider("POST", path, headers, std::string(),
4312                                     content_length, content_provider,
4313                                     content_type);
4314 }
4315 
4316 inline std::shared_ptr<Response>
Post(const char * path,const Headers & headers,const Params & params)4317 Client::Post(const char *path, const Headers &headers, const Params &params) {
4318   auto query = detail::params_to_query_str(params);
4319   return Post(path, headers, query, "application/x-www-form-urlencoded");
4320 }
4321 
4322 inline std::shared_ptr<Response>
Post(const char * path,const MultipartFormDataItems & items)4323 Client::Post(const char *path, const MultipartFormDataItems &items) {
4324   return Post(path, Headers(), items);
4325 }
4326 
4327 inline std::shared_ptr<Response>
Post(const char * path,const Headers & headers,const MultipartFormDataItems & items)4328 Client::Post(const char *path, const Headers &headers,
4329              const MultipartFormDataItems &items) {
4330   auto boundary = detail::make_multipart_data_boundary();
4331 
4332   std::string body;
4333 
4334   for (const auto &item : items) {
4335     body += "--" + boundary + "\r\n";
4336     body += "Content-Disposition: form-data; name=\"" + item.name + "\"";
4337     if (!item.filename.empty()) {
4338       body += "; filename=\"" + item.filename + "\"";
4339     }
4340     body += "\r\n";
4341     if (!item.content_type.empty()) {
4342       body += "Content-Type: " + item.content_type + "\r\n";
4343     }
4344     body += "\r\n";
4345     body += item.content + "\r\n";
4346   }
4347 
4348   body += "--" + boundary + "--\r\n";
4349 
4350   std::string content_type = "multipart/form-data; boundary=" + boundary;
4351   return Post(path, headers, body, content_type.c_str());
4352 }
4353 
Put(const char * path)4354 inline std::shared_ptr<Response> Client::Put(const char *path) {
4355   return Put(path, std::string(), nullptr);
4356 }
4357 
Put(const char * path,const std::string & body,const char * content_type)4358 inline std::shared_ptr<Response> Client::Put(const char *path,
4359                                              const std::string &body,
4360                                              const char *content_type) {
4361   return Put(path, Headers(), body, content_type);
4362 }
4363 
Put(const char * path,const Headers & headers,const std::string & body,const char * content_type)4364 inline std::shared_ptr<Response> Client::Put(const char *path,
4365                                              const Headers &headers,
4366                                              const std::string &body,
4367                                              const char *content_type) {
4368   return send_with_content_provider("PUT", path, headers, body, 0, nullptr,
4369                                     content_type);
4370 }
4371 
Put(const char * path,size_t content_length,ContentProvider content_provider,const char * content_type)4372 inline std::shared_ptr<Response> Client::Put(const char *path,
4373                                              size_t content_length,
4374                                              ContentProvider content_provider,
4375                                              const char *content_type) {
4376   return Put(path, Headers(), content_length, content_provider, content_type);
4377 }
4378 
4379 inline std::shared_ptr<Response>
Put(const char * path,const Headers & headers,size_t content_length,ContentProvider content_provider,const char * content_type)4380 Client::Put(const char *path, const Headers &headers, size_t content_length,
4381             ContentProvider content_provider, const char *content_type) {
4382   return send_with_content_provider("PUT", path, headers, std::string(),
4383                                     content_length, content_provider,
4384                                     content_type);
4385 }
4386 
Put(const char * path,const Params & params)4387 inline std::shared_ptr<Response> Client::Put(const char *path,
4388                                              const Params &params) {
4389   return Put(path, Headers(), params);
4390 }
4391 
4392 inline std::shared_ptr<Response>
Put(const char * path,const Headers & headers,const Params & params)4393 Client::Put(const char *path, const Headers &headers, const Params &params) {
4394   auto query = detail::params_to_query_str(params);
4395   return Put(path, headers, query, "application/x-www-form-urlencoded");
4396 }
4397 
Patch(const char * path,const std::string & body,const char * content_type)4398 inline std::shared_ptr<Response> Client::Patch(const char *path,
4399                                                const std::string &body,
4400                                                const char *content_type) {
4401   return Patch(path, Headers(), body, content_type);
4402 }
4403 
Patch(const char * path,const Headers & headers,const std::string & body,const char * content_type)4404 inline std::shared_ptr<Response> Client::Patch(const char *path,
4405                                                const Headers &headers,
4406                                                const std::string &body,
4407                                                const char *content_type) {
4408   return send_with_content_provider("PATCH", path, headers, body, 0, nullptr,
4409                                     content_type);
4410 }
4411 
Patch(const char * path,size_t content_length,ContentProvider content_provider,const char * content_type)4412 inline std::shared_ptr<Response> Client::Patch(const char *path,
4413                                                size_t content_length,
4414                                                ContentProvider content_provider,
4415                                                const char *content_type) {
4416   return Patch(path, Headers(), content_length, content_provider, content_type);
4417 }
4418 
4419 inline std::shared_ptr<Response>
Patch(const char * path,const Headers & headers,size_t content_length,ContentProvider content_provider,const char * content_type)4420 Client::Patch(const char *path, const Headers &headers, size_t content_length,
4421               ContentProvider content_provider, const char *content_type) {
4422   return send_with_content_provider("PATCH", path, headers, std::string(),
4423                                     content_length, content_provider,
4424                                     content_type);
4425 }
4426 
Delete(const char * path)4427 inline std::shared_ptr<Response> Client::Delete(const char *path) {
4428   return Delete(path, Headers(), std::string(), nullptr);
4429 }
4430 
Delete(const char * path,const std::string & body,const char * content_type)4431 inline std::shared_ptr<Response> Client::Delete(const char *path,
4432                                                 const std::string &body,
4433                                                 const char *content_type) {
4434   return Delete(path, Headers(), body, content_type);
4435 }
4436 
Delete(const char * path,const Headers & headers)4437 inline std::shared_ptr<Response> Client::Delete(const char *path,
4438                                                 const Headers &headers) {
4439   return Delete(path, headers, std::string(), nullptr);
4440 }
4441 
Delete(const char * path,const Headers & headers,const std::string & body,const char * content_type)4442 inline std::shared_ptr<Response> Client::Delete(const char *path,
4443                                                 const Headers &headers,
4444                                                 const std::string &body,
4445                                                 const char *content_type) {
4446   Request req;
4447   req.method = "DELETE";
4448   req.headers = headers;
4449   req.path = path;
4450 
4451   if (content_type) { req.headers.emplace("Content-Type", content_type); }
4452   req.body = body;
4453 
4454   auto res = std::make_shared<Response>();
4455 
4456   return send(req, *res) ? res : nullptr;
4457 }
4458 
Options(const char * path)4459 inline std::shared_ptr<Response> Client::Options(const char *path) {
4460   return Options(path, Headers());
4461 }
4462 
Options(const char * path,const Headers & headers)4463 inline std::shared_ptr<Response> Client::Options(const char *path,
4464                                                  const Headers &headers) {
4465   Request req;
4466   req.method = "OPTIONS";
4467   req.path = path;
4468   req.headers = headers;
4469 
4470   auto res = std::make_shared<Response>();
4471 
4472   return send(req, *res) ? res : nullptr;
4473 }
4474 
stop()4475 inline void Client::stop() {
4476   if (sock_ != INVALID_SOCKET) {
4477     std::atomic<socket_t> sock(sock_.exchange(INVALID_SOCKET));
4478     detail::shutdown_socket(sock);
4479     detail::close_socket(sock);
4480   }
4481 }
4482 
set_timeout_sec(time_t timeout_sec)4483 inline void Client::set_timeout_sec(time_t timeout_sec) {
4484   timeout_sec_ = timeout_sec;
4485 }
4486 
set_read_timeout(time_t sec,time_t usec)4487 inline void Client::set_read_timeout(time_t sec, time_t usec) {
4488   read_timeout_sec_ = sec;
4489   read_timeout_usec_ = usec;
4490 }
4491 
set_keep_alive_max_count(size_t count)4492 inline void Client::set_keep_alive_max_count(size_t count) {
4493   keep_alive_max_count_ = count;
4494 }
4495 
set_basic_auth(const char * username,const char * password)4496 inline void Client::set_basic_auth(const char *username, const char *password) {
4497   basic_auth_username_ = username;
4498   basic_auth_password_ = password;
4499 }
4500 
4501 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
set_digest_auth(const char * username,const char * password)4502 inline void Client::set_digest_auth(const char *username,
4503                                     const char *password) {
4504   digest_auth_username_ = username;
4505   digest_auth_password_ = password;
4506 }
4507 #endif
4508 
set_follow_location(bool on)4509 inline void Client::set_follow_location(bool on) { follow_location_ = on; }
4510 
set_compress(bool on)4511 inline void Client::set_compress(bool on) { compress_ = on; }
4512 
set_interface(const char * intf)4513 inline void Client::set_interface(const char *intf) { interface_ = intf; }
4514 
set_proxy(const char * host,int port)4515 inline void Client::set_proxy(const char *host, int port) {
4516   proxy_host_ = host;
4517   proxy_port_ = port;
4518 }
4519 
set_proxy_basic_auth(const char * username,const char * password)4520 inline void Client::set_proxy_basic_auth(const char *username,
4521                                          const char *password) {
4522   proxy_basic_auth_username_ = username;
4523   proxy_basic_auth_password_ = password;
4524 }
4525 
4526 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
set_proxy_digest_auth(const char * username,const char * password)4527 inline void Client::set_proxy_digest_auth(const char *username,
4528                                           const char *password) {
4529   proxy_digest_auth_username_ = username;
4530   proxy_digest_auth_password_ = password;
4531 }
4532 #endif
4533 
set_logger(Logger logger)4534 inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); }
4535 
4536 /*
4537  * SSL Implementation
4538  */
4539 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
4540 namespace detail {
4541 
4542 template <typename U, typename V, typename T>
process_and_close_socket_ssl(bool is_client_request,socket_t sock,size_t keep_alive_max_count,time_t read_timeout_sec,time_t read_timeout_usec,SSL_CTX * ctx,std::mutex & ctx_mutex,U SSL_connect_or_accept,V setup,T callback)4543 inline bool process_and_close_socket_ssl(
4544     bool is_client_request, socket_t sock, size_t keep_alive_max_count,
4545     time_t read_timeout_sec, time_t read_timeout_usec, SSL_CTX *ctx,
4546     std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, T callback) {
4547   assert(keep_alive_max_count > 0);
4548 
4549   SSL *ssl = nullptr;
4550   {
4551     std::lock_guard<std::mutex> guard(ctx_mutex);
4552     ssl = SSL_new(ctx);
4553   }
4554 
4555   if (!ssl) {
4556     close_socket(sock);
4557     return false;
4558   }
4559 
4560   auto bio = BIO_new_socket(static_cast<int>(sock), BIO_NOCLOSE);
4561   SSL_set_bio(ssl, bio, bio);
4562 
4563   if (!setup(ssl)) {
4564     SSL_shutdown(ssl);
4565     {
4566       std::lock_guard<std::mutex> guard(ctx_mutex);
4567       SSL_free(ssl);
4568     }
4569 
4570     close_socket(sock);
4571     return false;
4572   }
4573 
4574   auto ret = false;
4575 
4576   if (SSL_connect_or_accept(ssl) == 1) {
4577     if (keep_alive_max_count > 1) {
4578       auto count = keep_alive_max_count;
4579       while (count > 0 &&
4580              (is_client_request ||
4581               select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
4582                           CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
4583         SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec);
4584         auto last_connection = count == 1;
4585         auto connection_close = false;
4586 
4587         ret = callback(ssl, strm, last_connection, connection_close);
4588         if (!ret || connection_close) { break; }
4589 
4590         count--;
4591       }
4592     } else {
4593       SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec);
4594       auto dummy_connection_close = false;
4595       ret = callback(ssl, strm, true, dummy_connection_close);
4596     }
4597   }
4598 
4599   if (ret) {
4600     SSL_shutdown(ssl); // shutdown only if not already closed by remote
4601   }
4602   {
4603     std::lock_guard<std::mutex> guard(ctx_mutex);
4604     SSL_free(ssl);
4605   }
4606 
4607   close_socket(sock);
4608 
4609   return ret;
4610 }
4611 
4612 #if OPENSSL_VERSION_NUMBER < 0x10100000L
4613 static std::shared_ptr<std::vector<std::mutex>> openSSL_locks_;
4614 
4615 class SSLThreadLocks {
4616 public:
SSLThreadLocks()4617   SSLThreadLocks() {
4618     openSSL_locks_ =
4619         std::make_shared<std::vector<std::mutex>>(CRYPTO_num_locks());
4620     CRYPTO_set_locking_callback(locking_callback);
4621   }
4622 
~SSLThreadLocks()4623   ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); }
4624 
4625 private:
locking_callback(int mode,int type,const char *,int)4626   static void locking_callback(int mode, int type, const char * /*file*/,
4627                                int /*line*/) {
4628     auto &lk = (*openSSL_locks_)[static_cast<size_t>(type)];
4629     if (mode & CRYPTO_LOCK) {
4630       lk.lock();
4631     } else {
4632       lk.unlock();
4633     }
4634   }
4635 };
4636 
4637 #endif
4638 
4639 class SSLInit {
4640 public:
SSLInit()4641   SSLInit() {
4642 #if OPENSSL_VERSION_NUMBER < 0x1010001fL
4643     SSL_load_error_strings();
4644     SSL_library_init();
4645 #else
4646     OPENSSL_init_ssl(
4647         OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL);
4648 #endif
4649   }
4650 
~SSLInit()4651   ~SSLInit() {
4652 #if OPENSSL_VERSION_NUMBER < 0x1010001fL
4653     ERR_free_strings();
4654 #endif
4655   }
4656 
4657 private:
4658 #if OPENSSL_VERSION_NUMBER < 0x10100000L
4659   SSLThreadLocks thread_init_;
4660 #endif
4661 };
4662 
4663 // SSL socket stream implementation
SSLSocketStream(socket_t sock,SSL * ssl,time_t read_timeout_sec,time_t read_timeout_usec)4664 inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl,
4665                                         time_t read_timeout_sec,
4666                                         time_t read_timeout_usec)
4667     : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec),
4668       read_timeout_usec_(read_timeout_usec) {}
4669 
~SSLSocketStream()4670 inline SSLSocketStream::~SSLSocketStream() {}
4671 
is_readable()4672 inline bool SSLSocketStream::is_readable() const {
4673   return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
4674 }
4675 
is_writable()4676 inline bool SSLSocketStream::is_writable() const {
4677   return detail::select_write(sock_, 0, 0) > 0;
4678 }
4679 
read(char * ptr,size_t size)4680 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
4681   if (SSL_pending(ssl_) > 0 ||
4682       select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) {
4683     return SSL_read(ssl_, ptr, static_cast<int>(size));
4684   }
4685   return -1;
4686 }
4687 
write(const char * ptr,size_t size)4688 inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
4689   if (is_writable()) { return SSL_write(ssl_, ptr, static_cast<int>(size)); }
4690   return -1;
4691 }
4692 
get_remote_ip_and_port(std::string & ip,int & port)4693 inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip,
4694                                                     int &port) const {
4695   detail::get_remote_ip_and_port(sock_, ip, port);
4696 }
4697 
// File-scope SSLInit instance: its constructor and destructor run the
// OpenSSL setup and teardown defined in SSLInit.
static SSLInit sslinit_;
4699 
4700 } // namespace detail
4701 
4702 // SSL HTTP server implementation
SSLServer(const char * cert_path,const char * private_key_path,const char * client_ca_cert_file_path,const char * client_ca_cert_dir_path)4703 inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path,
4704                             const char *client_ca_cert_file_path,
4705                             const char *client_ca_cert_dir_path) {
4706   ctx_ = SSL_CTX_new(SSLv23_server_method());
4707 
4708   if (ctx_) {
4709     SSL_CTX_set_options(ctx_,
4710                         SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
4711                             SSL_OP_NO_COMPRESSION |
4712                             SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION);
4713 
4714     // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
4715     // SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
4716     // EC_KEY_free(ecdh);
4717 
4718     if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 ||
4719         SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) !=
4720             1) {
4721       SSL_CTX_free(ctx_);
4722       ctx_ = nullptr;
4723     } else if (client_ca_cert_file_path || client_ca_cert_dir_path) {
4724       // if (client_ca_cert_file_path) {
4725       //   auto list = SSL_load_client_CA_file(client_ca_cert_file_path);
4726       //   SSL_CTX_set_client_CA_list(ctx_, list);
4727       // }
4728 
4729       SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path,
4730                                     client_ca_cert_dir_path);
4731 
4732       SSL_CTX_set_verify(
4733           ctx_,
4734           SSL_VERIFY_PEER |
4735               SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE,
4736           nullptr);
4737     }
4738   }
4739 }
4740 
SSLServer(X509 * cert,EVP_PKEY * private_key,X509_STORE * client_ca_cert_store)4741 inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key,
4742                             X509_STORE *client_ca_cert_store) {
4743   ctx_ = SSL_CTX_new(SSLv23_server_method());
4744 
4745   if (ctx_) {
4746     SSL_CTX_set_options(ctx_,
4747                         SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
4748                             SSL_OP_NO_COMPRESSION |
4749                             SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION);
4750 
4751     if (SSL_CTX_use_certificate(ctx_, cert) != 1 ||
4752         SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) {
4753       SSL_CTX_free(ctx_);
4754       ctx_ = nullptr;
4755     } else if (client_ca_cert_store) {
4756 
4757       SSL_CTX_set_cert_store(ctx_, client_ca_cert_store);
4758 
4759       SSL_CTX_set_verify(
4760           ctx_,
4761           SSL_VERIFY_PEER |
4762               SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE,
4763           nullptr);
4764     }
4765   }
4766 }
4767 
~SSLServer()4768 inline SSLServer::~SSLServer() {
4769   if (ctx_) { SSL_CTX_free(ctx_); }
4770 }
4771 
is_valid()4772 inline bool SSLServer::is_valid() const { return ctx_; }
4773 
process_and_close_socket(socket_t sock)4774 inline bool SSLServer::process_and_close_socket(socket_t sock) {
4775   return detail::process_and_close_socket_ssl(
4776       false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
4777       ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; },
4778       [this](SSL *ssl, Stream &strm, bool last_connection,
4779              bool &connection_close) {
4780         return process_request(strm, last_connection, connection_close,
4781                                [&](Request &req) { req.ssl = ssl; });
4782       });
4783 }
4784 
4785 // SSL HTTP client implementation
SSLClient(const std::string & host,int port,const std::string & client_cert_path,const std::string & client_key_path)4786 inline SSLClient::SSLClient(const std::string &host, int port,
4787                             const std::string &client_cert_path,
4788                             const std::string &client_key_path)
4789     : Client(host, port, client_cert_path, client_key_path) {
4790   ctx_ = SSL_CTX_new(SSLv23_client_method());
4791 
4792   detail::split(&host_[0], &host_[host_.size()], '.',
4793                 [&](const char *b, const char *e) {
4794                   host_components_.emplace_back(std::string(b, e));
4795                 });
4796   if (!client_cert_path.empty() && !client_key_path.empty()) {
4797     if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(),
4798                                      SSL_FILETYPE_PEM) != 1 ||
4799         SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(),
4800                                     SSL_FILETYPE_PEM) != 1) {
4801       SSL_CTX_free(ctx_);
4802       ctx_ = nullptr;
4803     }
4804   }
4805 }
4806 
SSLClient(const std::string & host,int port,X509 * client_cert,EVP_PKEY * client_key)4807 inline SSLClient::SSLClient(const std::string &host, int port,
4808                             X509 *client_cert, EVP_PKEY *client_key)
4809     : Client(host, port) {
4810   ctx_ = SSL_CTX_new(SSLv23_client_method());
4811 
4812   detail::split(&host_[0], &host_[host_.size()], '.',
4813                 [&](const char *b, const char *e) {
4814                   host_components_.emplace_back(std::string(b, e));
4815                 });
4816   if (client_cert != nullptr && client_key != nullptr) {
4817     if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 ||
4818         SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) {
4819       SSL_CTX_free(ctx_);
4820       ctx_ = nullptr;
4821     }
4822   }
4823 }
4824 
~SSLClient()4825 inline SSLClient::~SSLClient() {
4826   if (ctx_) { SSL_CTX_free(ctx_); }
4827 }
4828 
is_valid()4829 inline bool SSLClient::is_valid() const { return ctx_; }
4830 
set_ca_cert_path(const char * ca_cert_file_path,const char * ca_cert_dir_path)4831 inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path,
4832                                         const char *ca_cert_dir_path) {
4833   if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; }
4834   if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; }
4835 }
4836 
set_ca_cert_store(X509_STORE * ca_cert_store)4837 inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) {
4838   if (ca_cert_store) { ca_cert_store_ = ca_cert_store; }
4839 }
4840 
enable_server_certificate_verification(bool enabled)4841 inline void SSLClient::enable_server_certificate_verification(bool enabled) {
4842   server_certificate_verification_ = enabled;
4843 }
4844 
get_openssl_verify_result()4845 inline long SSLClient::get_openssl_verify_result() const {
4846   return verify_result_;
4847 }
4848 
ssl_context()4849 inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
4850 
process_and_close_socket(socket_t sock,size_t request_count,std::function<bool (Stream & strm,bool last_connection,bool & connection_close)> callback)4851 inline bool SSLClient::process_and_close_socket(
4852     socket_t sock, size_t request_count,
4853     std::function<bool(Stream &strm, bool last_connection,
4854                        bool &connection_close)>
4855         callback) {
4856 
4857   request_count = std::min(request_count, keep_alive_max_count_);
4858 
4859   return is_valid() &&
4860          detail::process_and_close_socket_ssl(
4861              true, sock, request_count, read_timeout_sec_, read_timeout_usec_,
4862              ctx_, ctx_mutex_,
4863              [&](SSL *ssl) {
4864                if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) {
4865                  SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
4866                } else if (!ca_cert_file_path_.empty()) {
4867                  if (!SSL_CTX_load_verify_locations(
4868                          ctx_, ca_cert_file_path_.c_str(), nullptr)) {
4869                    return false;
4870                  }
4871                  SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr);
4872                } else if (ca_cert_store_ != nullptr) {
4873                  if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) {
4874                    SSL_CTX_set_cert_store(ctx_, ca_cert_store_);
4875                  }
4876                  SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr);
4877                }
4878 
4879                if (SSL_connect(ssl) != 1) { return false; }
4880 
4881                if (server_certificate_verification_) {
4882                  verify_result_ = SSL_get_verify_result(ssl);
4883 
4884                  if (verify_result_ != X509_V_OK) { return false; }
4885 
4886                  auto server_cert = SSL_get_peer_certificate(ssl);
4887 
4888                  if (server_cert == nullptr) { return false; }
4889 
4890                  if (!verify_host(server_cert)) {
4891                    X509_free(server_cert);
4892                    return false;
4893                  }
4894                  X509_free(server_cert);
4895                }
4896 
4897                return true;
4898              },
4899              [&](SSL *ssl) {
4900                SSL_set_tlsext_host_name(ssl, host_.c_str());
4901                return true;
4902              },
4903              [&](SSL * /*ssl*/, Stream &strm, bool last_connection,
4904                  bool &connection_close) {
4905                return callback(strm, last_connection, connection_close);
4906              });
4907 }
4908 
is_ssl()4909 inline bool SSLClient::is_ssl() const { return true; }
4910 
verify_host(X509 * server_cert)4911 inline bool SSLClient::verify_host(X509 *server_cert) const {
4912   /* Quote from RFC2818 section 3.1 "Server Identity"
4913 
4914      If a subjectAltName extension of type dNSName is present, that MUST
4915      be used as the identity. Otherwise, the (most specific) Common Name
4916      field in the Subject field of the certificate MUST be used. Although
4917      the use of the Common Name is existing practice, it is deprecated and
4918      Certification Authorities are encouraged to use the dNSName instead.
4919 
4920      Matching is performed using the matching rules specified by
4921      [RFC2459].  If more than one identity of a given type is present in
4922      the certificate (e.g., more than one dNSName name, a match in any one
4923      of the set is considered acceptable.) Names may contain the wildcard
4924      character * which is considered to match any single domain name
4925      component or component fragment. E.g., *.a.com matches foo.a.com but
4926      not bar.foo.a.com. f*.com matches foo.com but not bar.com.
4927 
4928      In some cases, the URI is specified as an IP address rather than a
4929      hostname. In this case, the iPAddress subjectAltName must be present
4930      in the certificate and must exactly match the IP in the URI.
4931 
4932   */
4933   return verify_host_with_subject_alt_name(server_cert) ||
4934          verify_host_with_common_name(server_cert);
4935 }
4936 
4937 inline bool
verify_host_with_subject_alt_name(X509 * server_cert)4938 SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const {
4939   auto ret = false;
4940 
4941   auto type = GEN_DNS;
4942 
4943   struct in6_addr addr6;
4944   struct in_addr addr;
4945   size_t addr_len = 0;
4946 
4947 #ifndef __MINGW32__
4948   if (inet_pton(AF_INET6, host_.c_str(), &addr6)) {
4949     type = GEN_IPADD;
4950     addr_len = sizeof(struct in6_addr);
4951   } else if (inet_pton(AF_INET, host_.c_str(), &addr)) {
4952     type = GEN_IPADD;
4953     addr_len = sizeof(struct in_addr);
4954   }
4955 #endif
4956 
4957   auto alt_names = static_cast<const struct stack_st_GENERAL_NAME *>(
4958       X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr));
4959 
4960   if (alt_names) {
4961     auto dsn_matched = false;
4962     auto ip_mached = false;
4963 
4964     auto count = sk_GENERAL_NAME_num(alt_names);
4965 
4966     for (auto i = 0; i < count && !dsn_matched; i++) {
4967       auto val = sk_GENERAL_NAME_value(alt_names, i);
4968       if (val->type == type) {
4969         auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5);
4970         auto name_len = (size_t)ASN1_STRING_length(val->d.ia5);
4971 
4972         if (strlen(name) == name_len) {
4973           switch (type) {
4974           case GEN_DNS: dsn_matched = check_host_name(name, name_len); break;
4975 
4976           case GEN_IPADD:
4977             if (!memcmp(&addr6, name, addr_len) ||
4978                 !memcmp(&addr, name, addr_len)) {
4979               ip_mached = true;
4980             }
4981             break;
4982           }
4983         }
4984       }
4985     }
4986 
4987     if (dsn_matched || ip_mached) { ret = true; }
4988   }
4989 
4990   GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names);
4991 
4992   return ret;
4993 }
4994 
verify_host_with_common_name(X509 * server_cert)4995 inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const {
4996   const auto subject_name = X509_get_subject_name(server_cert);
4997 
4998   if (subject_name != nullptr) {
4999     char name[BUFSIZ];
5000     auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName,
5001                                               name, sizeof(name));
5002 
5003     if (name_len != -1) {
5004       return check_host_name(name, static_cast<size_t>(name_len));
5005     }
5006   }
5007 
5008   return false;
5009 }
5010 
check_host_name(const char * pattern,size_t pattern_len)5011 inline bool SSLClient::check_host_name(const char *pattern,
5012                                        size_t pattern_len) const {
5013   if (host_.size() == pattern_len && host_ == pattern) { return true; }
5014 
5015   // Wildcard match
5016   // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484
5017   std::vector<std::string> pattern_components;
5018   detail::split(&pattern[0], &pattern[pattern_len], '.',
5019                 [&](const char *b, const char *e) {
5020                   pattern_components.emplace_back(std::string(b, e));
5021                 });
5022 
5023   if (host_components_.size() != pattern_components.size()) { return false; }
5024 
5025   auto itr = pattern_components.begin();
5026   for (const auto &h : host_components_) {
5027     auto &p = *itr;
5028     if (p != h && p != "*") {
5029       auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' &&
5030                             !p.compare(0, p.size() - 1, h));
5031       if (!partial_match) { return false; }
5032     }
5033     ++itr;
5034   }
5035 
5036   return true;
5037 }
5038 #endif
5039 
5040 namespace url {
5041 
// Settings consumed by url::Get() when it configures the client it creates
// for a single request.
struct Options {
  // TODO: support more options...

  // Forwarded to set_follow_location() on the client; off by default.
  bool follow_location{false};

  // Client certificate and private key paths handed to the client
  // constructor; empty by default.
  std::string client_cert_path{};
  std::string client_key_path{};

  // CA certificate file / directory forwarded to set_ca_cert_path() for
  // https URLs; empty by default.
  std::string ca_cert_file_path{};
  std::string ca_cert_dir_path{};

  // Forwarded to enable_server_certificate_verification() for https URLs;
  // off by default.
  bool server_certificate_verification{false};
};
5052 
Get(const char * url,Options & options)5053 inline std::shared_ptr<Response> Get(const char *url, Options &options) {
5054   const static std::regex re(
5055       R"(^(https?)://([^:/?#]+)(?::(\d+))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)");
5056 
5057   std::cmatch m;
5058   if (!std::regex_match(url, m, re)) { return nullptr; }
5059 
5060   auto next_scheme = m[1].str();
5061   auto next_host = m[2].str();
5062   auto port_str = m[3].str();
5063   auto next_path = m[4].str();
5064 
5065   auto next_port = !port_str.empty() ? std::stoi(port_str)
5066                                      : (next_scheme == "https" ? 443 : 80);
5067 
5068   if (next_path.empty()) { next_path = "/"; }
5069 
5070   if (next_scheme == "https") {
5071 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
5072     SSLClient cli(next_host.c_str(), next_port, options.client_cert_path,
5073                   options.client_key_path);
5074     cli.set_follow_location(options.follow_location);
5075     cli.set_ca_cert_path(options.ca_cert_file_path.c_str(),
5076                          options.ca_cert_dir_path.c_str());
5077     cli.enable_server_certificate_verification(
5078         options.server_certificate_verification);
5079     return cli.Get(next_path.c_str());
5080 #else
5081     return nullptr;
5082 #endif
5083   } else {
5084     Client cli(next_host.c_str(), next_port, options.client_cert_path,
5085                options.client_key_path);
5086     cli.set_follow_location(options.follow_location);
5087     return cli.Get(next_path.c_str());
5088   }
5089 }
5090 
Get(const char * url)5091 inline std::shared_ptr<Response> Get(const char *url) {
5092   Options options;
5093   return Get(url, options);
5094 }
5095 
5096 } // namespace url
5097 
5098 // ----------------------------------------------------------------------------
5099 
5100 } // namespace httplib
5101 
5102 #endif // CPPHTTPLIB_HTTPLIB_H
5103