1 // Copyright (C) 2006  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_SOCKETS_EXTENSIONs_CPP
4 #define DLIB_SOCKETS_EXTENSIONs_CPP
5 
6 #include <string>
7 #include <sstream>
8 #include "../sockets.h"
9 #include "../error.h"
10 #include "sockets_extensions.h"
11 #include "../timer.h"
12 #include "../algs.h"
13 #include "../timeout.h"
14 #include "../misc_api.h"
15 #include "../serialize.h"
16 #include "../string.h"
17 
18 namespace dlib
19 {
20 
21 // ----------------------------------------------------------------------------------------
22 
23     network_address::
network_address(const std::string & full_address)24     network_address(
25         const std::string& full_address
26     )
27     {
28         std::istringstream sin(full_address);
29         sin >> *this;
30         if (!sin || sin.peek() != EOF)
31             throw invalid_network_address("invalid network address: " + full_address);
32     }
33 
34 // ----------------------------------------------------------------------------------------
35 
serialize(const network_address & item,std::ostream & out)36     void serialize(
37         const network_address& item,
38         std::ostream& out
39     )
40     {
41         serialize(item.host_address, out);
42         serialize(item.port, out);
43     }
44 
45 // ----------------------------------------------------------------------------------------
46 
deserialize(network_address & item,std::istream & in)47     void deserialize(
48         network_address& item,
49         std::istream& in
50     )
51     {
52         deserialize(item.host_address, in);
53         deserialize(item.port, in);
54     }
55 
56 // ----------------------------------------------------------------------------------------
57 
operator <<(std::ostream & out,const network_address & item)58     std::ostream& operator<< (
59         std::ostream& out,
60         const network_address& item
61     )
62     {
63         out << item.host_address << ":" << item.port;
64         return out;
65     }
66 
67 // ----------------------------------------------------------------------------------------
68 
operator >>(std::istream & in,network_address & item)69     std::istream& operator>> (
70         std::istream& in,
71         network_address& item
72     )
73     {
74         std::string temp;
75         in >> temp;
76 
77         std::string::size_type pos = temp.find_last_of(":");
78         if (pos == std::string::npos)
79         {
80             in.setstate(std::ios::badbit);
81             return in;
82         }
83 
84         item.host_address = temp.substr(0, pos);
85         try
86         {
87             item.port = sa = temp.substr(pos+1);
88         } catch (std::exception& )
89         {
90             in.setstate(std::ios::badbit);
91             return in;
92         }
93 
94 
95         return in;
96     }
97 
98 // ----------------------------------------------------------------------------------------
99 // ----------------------------------------------------------------------------------------
100 
connect(const std::string & host_or_ip,unsigned short port)101     connection* connect (
102         const std::string& host_or_ip,
103         unsigned short port
104     )
105     {
106         std::string ip;
107         connection* con;
108         if (is_ip_address(host_or_ip))
109         {
110             ip = host_or_ip;
111         }
112         else
113         {
114             if( hostname_to_ip(host_or_ip,ip))
115                 throw socket_error(ERESOLVE,"unable to resolve '" + host_or_ip + "' in connect()");
116         }
117 
118         if(create_connection(con,port,ip))
119         {
120             std::ostringstream sout;
121             sout << "unable to connect to '" << host_or_ip << ":" << port << "'";
122             throw socket_error(sout.str());
123         }
124 
125         return con;
126     }
127 
128 // ----------------------------------------------------------------------------------------
129 
connect(const network_address & addr)130     connection* connect (
131         const network_address& addr
132     )
133     {
134         return connect(addr.host_address, addr.port);
135     }
136 
137 // ----------------------------------------------------------------------------------------
138 
139     namespace connect_timeout_helpers
140     {
141         mutex connect_mutex;
142         signaler connect_signaler(connect_mutex);
143         timestamper ts;
144         long outstanding_connects = 0;
145 
146         struct thread_data
147         {
148             std::string host_or_ip;
149             unsigned short port;
150             connection* con;
151             bool connect_ended;
152             bool error_occurred;
153         };
154 
thread(void * param)155         void thread(void* param)
156         {
157             thread_data p = *static_cast<thread_data*>(param);
158             try
159             {
160                 p.con = connect(p.host_or_ip, p.port);
161             }
162             catch (...)
163             {
164                 p.error_occurred = true;
165             }
166 
167             auto_mutex M(connect_mutex);
168             // report the results back to the connect() call that spawned this
169             // thread.
170             static_cast<thread_data*>(param)->con = p.con;
171             static_cast<thread_data*>(param)->error_occurred = p.error_occurred;
172             connect_signaler.broadcast();
173 
174             // wait for the call to connect() that spawned this thread to terminate
175             // before we delete the thread_data struct.
176             while (static_cast<thread_data*>(param)->connect_ended == false)
177                 connect_signaler.wait();
178 
179             connect_signaler.broadcast();
180             --outstanding_connects;
181             delete static_cast<thread_data*>(param);
182         }
183     }
184 
connect(const std::string & host_or_ip,unsigned short port,unsigned long timeout)185     connection* connect (
186         const std::string& host_or_ip,
187         unsigned short port,
188         unsigned long timeout
189     )
190     {
191         using namespace connect_timeout_helpers;
192 
193         auto_mutex M(connect_mutex);
194 
195         const uint64 end_time = ts.get_timestamp() + timeout*1000;
196 
197 
198         // wait until there are less than 100 outstanding connections
199         while (outstanding_connects > 100)
200         {
201             uint64 cur_time = ts.get_timestamp();
202             if (end_time > cur_time)
203             {
204                 timeout = static_cast<unsigned long>((end_time - cur_time)/1000);
205             }
206             else
207             {
208                 throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out");
209             }
210 
211             connect_signaler.wait_or_timeout(timeout);
212         }
213 
214 
215         thread_data* data = new thread_data;
216         data->host_or_ip = host_or_ip.c_str();
217         data->port = port;
218         data->con = 0;
219         data->connect_ended = false;
220         data->error_occurred = false;
221 
222 
223         if (create_new_thread(thread, data) == false)
224         {
225             delete data;
226             throw socket_error("unable to connect to '" + host_or_ip);
227         }
228 
229         ++outstanding_connects;
230 
231         // wait until we have a connection object
232         while (data->con == 0)
233         {
234             uint64 cur_time = ts.get_timestamp();
235             if (end_time > cur_time && data->error_occurred == false)
236             {
237                 timeout = static_cast<unsigned long>((end_time - cur_time)/1000);
238             }
239             else
240             {
241                 // let the thread know that it should terminate
242                 data->connect_ended = true;
243                 connect_signaler.broadcast();
244                 if (data->error_occurred)
245                     throw socket_error("unable to connect to '" + host_or_ip);
246                 else
247                     throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out");
248             }
249 
250             connect_signaler.wait_or_timeout(timeout);
251         }
252 
253         // let the thread know that it should terminate
254         data->connect_ended = true;
255         connect_signaler.broadcast();
256         return data->con;
257     }
258 
259 // ----------------------------------------------------------------------------------------
260 
is_ip_address(std::string ip)261     bool is_ip_address (
262         std::string ip
263     )
264     {
265         for (std::string::size_type i = 0; i < ip.size(); ++i)
266         {
267             if (ip[i] == '.')
268                 ip[i] = ' ';
269         }
270         std::istringstream sin(ip);
271 
272         bool bad = false;
273         int num;
274         for (int i = 0; i < 4; ++i)
275         {
276             sin >> num;
277             if (!sin || num < 0 || num > 255)
278             {
279                 bad = true;
280                 break;
281             }
282         }
283 
284         if (sin.get() != EOF)
285             bad = true;
286 
287         return !bad;
288     }
289 
290 // ----------------------------------------------------------------------------------------
291 
close_gracefully(connection * con,unsigned long timeout)292     void close_gracefully (
293         connection* con,
294         unsigned long timeout
295     )
296     {
297         std::unique_ptr<connection> ptr(con);
298         close_gracefully(ptr,timeout);
299     }
300 
301 // ----------------------------------------------------------------------------------------
302 
close_gracefully(std::unique_ptr<connection> & con,unsigned long timeout)303     void close_gracefully (
304         std::unique_ptr<connection>& con,
305         unsigned long timeout
306     )
307     {
308         if (!con)
309             return;
310 
311         if(con->shutdown_outgoing())
312         {
313             // there was an error so just close it now and return
314             con.reset();
315             return;
316         }
317 
318         try
319         {
320             dlib::timeout t(*con,&connection::shutdown,timeout);
321 
322             char junk[100];
323             // wait for the other end to close their side
324             while (con->read(junk,sizeof(junk)) > 0) ;
325         }
326         catch (...)
327         {
328             con.reset();
329             throw;
330         }
331 
332         con.reset();
333     }
334 
335 // ----------------------------------------------------------------------------------------
336 
337 }
338 
339 #endif // DLIB_SOCKETS_EXTENSIONs_CPP
340 
341 
342