1 /*
2  * This file is part of PowerDNS or dnsdist.
3  * Copyright -- PowerDNS.COM B.V. and its contributors
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of version 2 of the GNU General Public License as
7  * published by the Free Software Foundation.
8  *
9  * In addition, for the avoidance of any doubt, permission is granted to
10  * link this program with OpenSSL and to (re)distribute the binaries
11  * produced as the result of such linking.
12  *
13  * This program is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21  */
22 #pragma once
#include <string>
#include <sstream>
#include <iostream>
#include <memory>
#include "iputils.hh"
#include <errno.h>
#include <sys/types.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <sys/select.h>
#include <fcntl.h>
#include <stdexcept>

#include <boost/utility.hpp>
#include <csignal>
#include "namespaces.hh"
41 
42 
using ProtocolType = int; //!< Protocol number handed to socket(2); 0 selects the default for the family/type
44 
45 //! Representation of a Socket and many of the Berkeley functions available
46 class Socket : public boost::noncopyable
47 {
48 public:
Socket(int fd)49   Socket(int fd): d_socket(fd)
50   {
51   }
52 
53   //! Construct a socket of specified address family and socket type.
Socket(int af,int st,ProtocolType pt=0)54   Socket(int af, int st, ProtocolType pt=0)
55   {
56     if((d_socket=socket(af, st, pt))<0)
57       throw NetworkError(stringerror());
58     setCloseOnExec(d_socket);
59   }
60 
Socket(Socket && rhs)61   Socket(Socket&& rhs): d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket)
62   {
63     rhs.d_socket = -1;
64   }
65 
~Socket()66   ~Socket()
67   {
68     try {
69       if (d_socket != -1) {
70         closesocket(d_socket);
71       }
72     }
73     catch(const PDNSException& e) {
74     }
75   }
76 
77   //! If the socket is capable of doing so, this function will wait for a connection
accept()78   std::unique_ptr<Socket> accept()
79   {
80     struct sockaddr_in remote;
81     socklen_t remlen=sizeof(remote);
82     memset(&remote, 0, sizeof(remote));
83     int s=::accept(d_socket, reinterpret_cast<sockaddr *>(&remote), &remlen);
84     if(s<0) {
85       if(errno==EAGAIN)
86         return nullptr;
87 
88       throw NetworkError("Accepting a connection: "+stringerror());
89     }
90 
91     return std::unique_ptr<Socket>(new Socket(s));
92   }
93 
94   //! Get remote address
getRemote(ComboAddress & remote)95   bool getRemote(ComboAddress &remote) {
96     socklen_t remotelen=sizeof(remote);
97     return (getpeername(d_socket, reinterpret_cast<struct sockaddr *>(&remote), &remotelen) >= 0);
98   }
99 
100   //! Check remote address against netmaskgroup ng
acl(const NetmaskGroup & ng)101   bool acl(const NetmaskGroup &ng)
102   {
103     ComboAddress remote;
104     if (getRemote(remote))
105       return ng.match(remote);
106 
107     return false;
108   }
109 
110   //! Set the socket to non-blocking
setNonBlocking()111   void setNonBlocking()
112   {
113     ::setNonBlocking(d_socket);
114   }
115 
116   //! Set the socket to blocking
setBlocking()117   void setBlocking()
118   {
119     ::setBlocking(d_socket);
120   }
121 
setReuseAddr()122   void setReuseAddr()
123   {
124     try {
125       ::setReuseAddr(d_socket);
126     } catch (const PDNSException &e) {
127       throw NetworkError(e.reason);
128     }
129   }
130 
setFastOpenConnect()131   void setFastOpenConnect()
132   {
133 #ifdef TCP_FASTOPEN_CONNECT
134     int on = 1;
135     if (setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &on, sizeof(on)) < 0) {
136       throw NetworkError("While setting TCP_FASTOPEN_CONNECT: " + stringerror());
137     }
138 #else
139    throw NetworkError("While setting TCP_FASTOPEN_CONNECT: not compiled in");
140 #endif
141   }
142 
143   //! Bind the socket to a specified endpoint
bind(const ComboAddress & local,bool reuseaddr=true)144   void bind(const ComboAddress &local, bool reuseaddr=true)
145   {
146     int tmp=1;
147     if(reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp)<0)
148       throw NetworkError("Setsockopt failed: "+stringerror());
149 
150     if(::bind(d_socket, reinterpret_cast<const struct sockaddr *>(&local), local.getSocklen())<0)
151       throw NetworkError("While binding: "+stringerror());
152   }
153 
154   //! Connect the socket to a specified endpoint
connect(const ComboAddress & ep,int timeout=0)155   void connect(const ComboAddress &ep, int timeout=0)
156   {
157     SConnectWithTimeout(d_socket, ep, timeout);
158   }
159 
160 
161   //! For datagram sockets, receive a datagram and learn where it came from
162   /** For datagram sockets, receive a datagram and learn where it came from
163       \param dgram Will be filled with the datagram
164       \param ep Will be filled with the origin of the datagram */
recvFrom(string & dgram,ComboAddress & ep)165   void recvFrom(string &dgram, ComboAddress &ep)
166   {
167     socklen_t remlen = sizeof(ep);
168     ssize_t bytes;
169     d_buffer.resize(s_buflen);
170     if((bytes=recvfrom(d_socket, &d_buffer[0], s_buflen, 0, reinterpret_cast<sockaddr *>(&ep) , &remlen)) <0)
171       throw NetworkError("After recvfrom: "+stringerror());
172 
173     dgram.assign(d_buffer, 0, static_cast<size_t>(bytes));
174   }
175 
recvFromAsync(string & dgram,ComboAddress & ep)176   bool recvFromAsync(string &dgram, ComboAddress &ep)
177   {
178     struct sockaddr_in remote;
179     socklen_t remlen = sizeof(remote);
180     ssize_t bytes;
181     d_buffer.resize(s_buflen);
182     if((bytes=recvfrom(d_socket, &d_buffer[0], s_buflen, 0, reinterpret_cast<sockaddr *>(&remote), &remlen))<0) {
183       if(errno!=EAGAIN) {
184         throw NetworkError("After async recvfrom: "+stringerror());
185       }
186       else {
187         return false;
188       }
189     }
190     dgram.assign(d_buffer, 0, static_cast<size_t>(bytes));
191     return true;
192   }
193 
194 
195   //! For datagram sockets, send a datagram to a destination
sendTo(const char * msg,size_t len,const ComboAddress & ep)196   void sendTo(const char* msg, size_t len, const ComboAddress &ep)
197   {
198     if(sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr *>(&ep), ep.getSocklen())<0)
199       throw NetworkError("After sendto: "+stringerror());
200   }
201 
202   //! For connected datagram sockets, send a datagram
send(const std::string & msg)203   void send(const std::string& msg)
204   {
205     if(::send(d_socket, msg.c_str(), msg.size(), 0)<0)
206       throw NetworkError("After send: "+stringerror());
207   }
208 
209 
210   /** For datagram sockets, send a datagram to a destination
211       \param dgram The datagram
212       \param ep The intended destination of the datagram */
sendTo(const string & dgram,const ComboAddress & ep)213   void sendTo(const string &dgram, const ComboAddress &ep)
214   {
215     sendTo(dgram.c_str(), dgram.length(), ep);
216   }
217 
218 
219   //! Write this data to the socket, taking care that all bytes are written out
writen(const string & data)220   void writen(const string &data)
221   {
222     if(data.empty())
223       return;
224 
225     size_t toWrite=data.length();
226     ssize_t res;
227     const char *ptr=data.c_str();
228 
229     do {
230       res=::send(d_socket, ptr, toWrite, 0);
231       if(res<0)
232         throw NetworkError("Writing to a socket: "+stringerror());
233       if(!res)
234         throw NetworkError("EOF on socket");
235       toWrite -= static_cast<size_t>(res);
236       ptr += static_cast<size_t>(res);
237     } while(toWrite);
238 
239   }
240 
241   //! tries to write toWrite bytes from ptr to the socket
242   /** tries to write toWrite bytes from ptr to the socket, but does not make sure they al get written out
243       \param ptr Location to write from
244       \param toWrite number of bytes to try
245   */
tryWrite(const char * ptr,size_t toWrite)246   size_t tryWrite(const char *ptr, size_t toWrite)
247   {
248     ssize_t res;
249     res=::send(d_socket,ptr,toWrite,0);
250     if(res==0)
251       throw NetworkError("EOF on writing to a socket");
252 
253     if(res>0)
254       return res;
255 
256     if(errno==EAGAIN)
257       return 0;
258 
259     throw NetworkError("Writing to a socket: "+stringerror());
260   }
261 
262   //! Writes toWrite bytes from ptr to the socket
263   /** Writes toWrite bytes from ptr to the socket. Returns how many bytes were written */
write(const char * ptr,size_t toWrite)264   size_t write(const char *ptr, size_t toWrite)
265   {
266     ssize_t res;
267     res=::send(d_socket,ptr,toWrite,0);
268     if(res<0) {
269       throw NetworkError("Writing to a socket: "+stringerror());
270     }
271     return res;
272   }
273 
writenWithTimeout(const void * buffer,size_t n,int timeout)274   void writenWithTimeout(const void *buffer, size_t n, int timeout)
275   {
276     size_t bytes=n;
277     const char *ptr = reinterpret_cast<const char*>(buffer);
278     ssize_t ret;
279     while(bytes) {
280       ret=::write(d_socket, ptr, bytes);
281       if(ret < 0) {
282         if(errno == EAGAIN) {
283           ret=waitForRWData(d_socket, false, timeout, 0);
284           if(ret < 0)
285             throw NetworkError("Waiting for data write");
286           if(!ret)
287             throw NetworkError("Timeout writing data");
288           continue;
289         }
290         else
291           throw NetworkError("Writing data: "+stringerror());
292       }
293       if(!ret) {
294         throw NetworkError("Did not fulfill TCP write due to EOF");
295       }
296 
297       ptr += static_cast<size_t>(ret);
298       bytes -= static_cast<size_t>(ret);
299     }
300   }
301 
302   //! reads one character from the socket
getChar()303   int getChar()
304   {
305     char c;
306 
307     ssize_t res=::recv(d_socket,&c,1,0);
308     if(res)
309       return c;
310     return -1;
311   }
312 
getline(string & data)313   void getline(string &data)
314   {
315     data="";
316     int c;
317     while((c=getChar())!=-1) {
318       data+=(char)c;
319       if(c=='\n')
320         break;
321     }
322   }
323 
324   //! Reads a block of data from the socket to a string
read(string & data)325   void read(string &data)
326   {
327     d_buffer.resize(s_buflen);
328     ssize_t res=::recv(d_socket, &d_buffer[0], s_buflen, 0);
329     if(res<0)
330       throw NetworkError("Reading from a socket: "+stringerror());
331     data.assign(d_buffer, 0, static_cast<size_t>(res));
332   }
333 
334   //! Reads a block of data from the socket to a block of memory
read(char * buffer,size_t bytes)335   size_t read(char *buffer, size_t bytes)
336   {
337     ssize_t res=::recv(d_socket, buffer, bytes, 0);
338     if(res<0)
339       throw NetworkError("Reading from a socket: "+stringerror());
340     return static_cast<size_t>(res);
341   }
342 
readWithTimeout(char * buffer,size_t n,int timeout)343   ssize_t readWithTimeout(char* buffer, size_t n, int timeout)
344   {
345     int err = waitForRWData(d_socket, true, timeout, 0);
346 
347     if(err == 0)
348       throw NetworkError("timeout reading");
349     if(err < 0)
350       throw NetworkError("nonblocking read failed: "+stringerror());
351 
352     return read(buffer, n);
353   }
354 
355   //! Sets the socket to listen with a default listen backlog of 10 pending connections
listen(unsigned int length=10)356   void listen(unsigned int length=10)
357   {
358     if(::listen(d_socket,length)<0)
359       throw NetworkError("Setting socket to listen: "+stringerror());
360   }
361 
362   //! Returns the internal file descriptor of the socket
getHandle() const363   int getHandle() const
364   {
365     return d_socket;
366   }
367 
releaseHandle()368   int releaseHandle()
369   {
370     int ret = d_socket;
371     d_socket = -1;
372     return ret;
373   }
374 
375 private:
376   static const size_t s_buflen{4096};
377   std::string d_buffer;
378   int d_socket;
379 };
380