1 /*
2  * Copyright (C) 2001-2012 Jacek Sieka, arnetheduck on gmail point com
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation; either version 2 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17  */
18 
19 #pragma once
20 
21 #ifdef _WIN32
22 #include "w.h"
23 typedef int socklen_t;
24 typedef SOCKET socket_t;
25 #else
26 #include <sys/ioctl.h>
27 #include <sys/socket.h>
28 #include <netinet/in.h>
29 #include <arpa/inet.h>
30 #include <netdb.h>
31 #include <fcntl.h>
32 #include <errno.h>
33 #include <vector>
34 typedef int socket_t;
35 const int INVALID_SOCKET = -1;
36 #define SOCKET_ERROR -1
37 #endif
38 #include "Util.h"
39 #include "Exception.h"
40 
41 namespace dcpp {
42 
43 class SocketException : public Exception {
44 public:
45 #ifdef _DEBUG
SocketException(const string & aError)46     SocketException(const string& aError) noexcept : Exception("SocketException: " + aError) { }
47 #else //_DEBUG
48     SocketException(const string& aError) noexcept : Exception(aError) { }
49 #endif // _DEBUG
50 
51     SocketException(int aError) noexcept;
~SocketException()52     virtual ~SocketException() noexcept { }
53 private:
54     static string errorToString(int aError) noexcept;
55 };
56 
57 class Socket
58 {
59 public:
60     enum {
61         WAIT_NONE = 0x00,
62         WAIT_CONNECT = 0x01,
63         WAIT_READ = 0x02,
64         WAIT_WRITE = 0x04
65     };
66 
67     enum {
68         TYPE_TCP,
69         TYPE_UDP
70     };
71 
Socket()72     Socket() : sock(INVALID_SOCKET), connected(false) { }
Socket(const string & aIp,uint16_t aPort)73     Socket(const string& aIp, uint16_t aPort) : sock(INVALID_SOCKET), connected(false) { connect(aIp, aPort); }
~Socket()74     virtual ~Socket() { disconnect(); }
75 
76     /**
77      * Connects a socket to an address/ip, closing any other connections made with
78      * this instance.
79      * @param aAddr Server address, in dns or xxx.xxx.xxx.xxx format.
80      * @param aPort Server port.
81      * @throw SocketException If any connection error occurs.
82      */
83     virtual void connect(const string& aIp, uint16_t aPort);
connect(const string & aIp,const string & aPort)84     void connect(const string& aIp, const string& aPort) { connect(aIp, static_cast<uint16_t>(Util::toInt(aPort))); }
85     /**
86      * Same as connect(), but through the SOCKS5 server
87      */
88     void socksConnect(const string& aIp, uint16_t aPort, uint32_t timeout = 0);
89 
90     /**
91      * Sends data, will block until all data has been sent or an exception occurs
92      * @param aBuffer Buffer with data
93      * @param aLen Data length
94      * @throw SocketExcpetion Send failed.
95      */
96     void writeAll(const void* aBuffer, int aLen, uint32_t timeout = 0);
97     virtual int write(const void* aBuffer, int aLen);
write(const string & aData)98     int write(const string& aData) { return write(aData.data(), (int)aData.length()); }
99     virtual void writeTo(const string& aIp, uint16_t aPort, const void* aBuffer, int aLen, bool proxy = true);
writeTo(const string & aIp,uint16_t aPort,const string & aData)100     void writeTo(const string& aIp, uint16_t aPort, const string& aData) { writeTo(aIp, aPort, aData.data(), (int)aData.length()); }
101     virtual void shutdown() noexcept;
102     virtual void close() noexcept;
103     void disconnect() noexcept;
104 
105     virtual bool waitConnected(uint32_t millis);
106     virtual bool waitAccepted(uint32_t millis);
107 
108     /**
109      * Reads zero to aBufLen characters from this socket,
110      * @param aBuffer A buffer to store the data in.
111      * @param aBufLen Size of the buffer.
112      * @return Number of bytes read, 0 if disconnected and -1 if the call would block.
113      * @throw SocketException On any failure.
114      */
115     virtual int read(void* aBuffer, int aBufLen);
116     /**
117      * Reads zero to aBufLen characters from this socket,
118      * @param aBuffer A buffer to store the data in.
119      * @param aBufLen Size of the buffer.
120      * @param aIP Remote IP address
121      * @return Number of bytes read, 0 if disconnected and -1 if the call would block.
122      * @throw SocketException On any failure.
123      */
124     virtual int read(void* aBuffer, int aBufLen, sockaddr_in& remote);
125     /**
126      * Reads data until aBufLen bytes have been read or an error occurs.
127      * If the socket is closed, or the timeout is reached, the number of bytes read
128      * actually read is returned.
129      * On exception, an unspecified amount of bytes might have already been read.
130      */
131     int readAll(void* aBuffer, int aBufLen, uint32_t timeout = 0);
132 
133     virtual int wait(uint32_t millis, int waitFor);
isConnected()134     bool isConnected() { return connected; }
135 
136     static string resolve(const string& aDns);
getTotalDown()137     static uint64_t getTotalDown() { return stats.totalDown; }
getTotalUp()138     static uint64_t getTotalUp() { return stats.totalUp; }
139 
140 #ifdef _WIN32
setBlocking(bool block)141     void setBlocking(bool block) noexcept {
142         u_long b = block ? 0 : 1;
143         ioctlsocket(sock, FIONBIO, &b);
144     }
145 #else
setBlocking(bool block)146     void setBlocking(bool block) noexcept {
147         int flags = fcntl(sock, F_GETFL, 0);
148         if(block) {
149             fcntl(sock, F_SETFL, flags & (~O_NONBLOCK));
150         } else {
151             fcntl(sock, F_SETFL, flags | O_NONBLOCK);
152         }
153     }
154 #endif
155 
156     string getLocalIp() noexcept;
157     uint16_t getLocalPort() noexcept;
158 
159     // Low level interface
160     virtual void create(int aType = TYPE_TCP);
161 
162     /** Binds a socket to a certain local port and possibly IP. */
163     virtual uint16_t bind(uint16_t aPort = 0, const string& aIp = "0.0.0.0");
164     virtual void listen();
165     virtual void accept(const Socket& listeningSocket);
166 
167     int getSocketOptInt(int option);
168     void setSocketOpt(int option, int value);
169 
isSecure()170     virtual bool isSecure() const noexcept { return false; }
isTrusted()171     virtual bool isTrusted() const noexcept { return false; }
getCipherName()172     virtual std::string getCipherName() const noexcept { return Util::emptyString; }
getKeyprint()173     virtual std::vector<uint8_t> getKeyprint() const noexcept { return std::vector<uint8_t>(); }
174 
175     /** When socks settings are updated, this has to be called... */
176     static void socksUpdated();
177     string getIfaceI4 (const string &iface);
178 
179     GETSET(string, ip, Ip);
180     socket_t sock;
181 protected:
182     int type;
183     bool connected;
184 
185     class Stats {
186     public:
187         uint64_t totalDown;
188         uint64_t totalUp;
189     };
190     static Stats stats;
191 
192     static string udpServer;
193     static uint16_t udpPort;
194 
195 private:
196     Socket(const Socket&);
197     Socket& operator=(const Socket&);
198 
199 
200     void socksAuth(uint32_t timeout);
201 
202 #ifdef _WIN32
getLastError()203     static int getLastError() { return ::WSAGetLastError(); }
checksocket(int ret)204     static int checksocket(int ret) {
205         if(ret == SOCKET_ERROR) {
206             throw SocketException(getLastError());
207         }
208         return ret;
209     }
210     static int check(int ret, bool blockOk = false) {
211         if(ret == SOCKET_ERROR) {
212             int error = getLastError();
213             if(blockOk && error == WSAEWOULDBLOCK) {
214                 return -1;
215             } else {
216                 throw SocketException(error);
217             }
218         }
219         return ret;
220     }
221 #else
getLastError()222     static int getLastError() { return errno; }
checksocket(int ret)223     static int checksocket(int ret) {
224         if(ret < 0) {
225             throw SocketException(getLastError());
226         }
227         return ret;
228     }
229     static int check(int ret, bool blockOk = false) {
230         if(ret == -1) {
231             int error = getLastError();
232             if(blockOk && (error == EWOULDBLOCK || error == ENOBUFS || error == EINPROGRESS || error == EAGAIN) ) {
233                 return -1;
234             } else {
235                 throw SocketException(error);
236             }
237         }
238         return ret;
239     }
240 #endif
241 
242 };
243 
244 } // namespace dcpp
245