1 /*
2  *  This file is part of nzbget. See <http://nzbget.net>.
3  *
4  *  Copyright (C) 2004 Sven Henkel <sidddy@users.sourceforge.net>
5  *  Copyright (C) 2007-2017 Andrey Prygunkov <hugbug@users.sourceforge.net>
6  *
7  *  This program is free software; you can redistribute it and/or modify
8  *  it under the terms of the GNU General Public License as published by
9  *  the Free Software Foundation; either version 2 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This program is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU General Public License for more details.
16  *
17  *  You should have received a copy of the GNU General Public License
18  *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
19  */
20 
21 
22 #ifndef CONNECTION_H
23 #define CONNECTION_H
24 
25 #include "NString.h"
26 
27 #ifndef HAVE_GETADDRINFO
28 #ifndef HAVE_GETHOSTBYNAME_R
29 #include "Thread.h"
30 #endif
31 #endif
32 #ifndef DISABLE_TLS
33 #include "TlsSocket.h"
34 #endif
35 
36 class Connection
37 {
38 public:
39 	enum EStatus
40 	{
41 		csConnected,
42 		csDisconnected,
43 		csListening,
44 		csCancelled,
45 		csBroken
46 	};
47 
48 	enum EIPVersion
49 	{
50 		ipAuto,
51 		ipV4,
52 		ipV6
53 	};
54 
55 	Connection(const char* host, int port, bool tls);
56 	Connection(SOCKET socket, bool tls);
57 	virtual ~Connection();
58 	static void Init();
59 	static void Final();
60 	virtual bool Connect();
61 	virtual bool Disconnect();
62 	bool Bind();
63 	bool Send(const char* buffer, int size);
64 	bool Recv(char* buffer, int size);
65 	int TryRecv(char* buffer, int size);
66 	char* ReadLine(char* buffer, int size, int* bytesRead);
67 	void ReadBuffer(char** buffer, int *bufLen);
68 	int WriteLine(const char* buffer);
69 	std::unique_ptr<Connection> Accept();
70 	void Cancel();
GetHost()71 	const char* GetHost() { return m_host; }
GetPort()72 	int GetPort() { return m_port; }
GetTls()73 	bool GetTls() { return m_tls; }
GetCipher()74 	const char* GetCipher() { return m_cipher; }
SetCipher(const char * cipher)75 	void SetCipher(const char* cipher) { m_cipher = cipher; }
SetTimeout(int timeout)76 	void SetTimeout(int timeout) { m_timeout = timeout; }
SetIPVersion(EIPVersion ipVersion)77 	void SetIPVersion(EIPVersion ipVersion) { m_ipVersion = ipVersion; }
GetStatus()78 	EStatus GetStatus() { return m_status; }
79 	void SetSuppressErrors(bool suppressErrors);
GetSuppressErrors()80 	bool GetSuppressErrors() { return m_suppressErrors; }
81 	const char* GetRemoteAddr();
GetGracefull()82 	bool GetGracefull() { return m_gracefull; }
SetGracefull(bool gracefull)83 	void SetGracefull(bool gracefull) { m_gracefull = gracefull; }
SetForceClose(bool forceClose)84 	void SetForceClose(bool forceClose) { m_forceClose = forceClose; }
85 #ifndef DISABLE_TLS
86 	bool StartTls(bool isClient, const char* certFile, const char* keyFile);
87 #endif
88 	int FetchTotalBytesRead();
89 
90 protected:
91 	CString m_host;
92 	int m_port;
93 	bool m_tls;
94 	EIPVersion m_ipVersion = ipAuto;
95 	SOCKET m_socket = INVALID_SOCKET;
96 	CString m_cipher;
97 	CharBuffer m_readBuf;
98 	int m_bufAvail = 0;
99 	char* m_bufPtr = nullptr;
100 	EStatus m_status = csDisconnected;
101 	int m_timeout = 60;
102 	bool m_suppressErrors = true;
103 	BString<100> m_remoteAddr;
104 	int m_totalBytesRead = 0;
105 	bool m_gracefull = false;
106 	bool m_forceClose = false;
107 
108 	struct SockAddr
109 	{
110 		int ai_family;
111 		int ai_socktype;
112 		int ai_protocol;
113 		bool operator==(const SockAddr& rhs) const
114 			{ return memcmp(this, &rhs, sizeof(SockAddr)) == 0; }
115 	};
116 
117 #ifndef DISABLE_TLS
118 	class ConTlsSocket: public TlsSocket
119 	{
120 	public:
ConTlsSocket(SOCKET socket,bool isClient,const char * host,const char * certFile,const char * keyFile,const char * cipher,Connection * owner)121 		ConTlsSocket(SOCKET socket, bool isClient, const char* host,
122 			const char* certFile, const char* keyFile, const char* cipher, Connection* owner) :
123 			TlsSocket(socket, isClient, host, certFile, keyFile, cipher), m_owner(owner) {}
124 	protected:
PrintError(const char * errMsg)125 		virtual void PrintError(const char* errMsg) { m_owner->PrintError(errMsg); }
126 	private:
127 		Connection* m_owner;
128 	};
129 
130 	std::unique_ptr<ConTlsSocket> m_tlsSocket;
131 	bool m_tlsError = false;
132 #endif
133 #ifndef HAVE_GETADDRINFO
134 #ifndef HAVE_GETHOSTBYNAME_R
135 	static std::unique_ptr<Mutex> m_getHostByNameMutex;
136 #endif
137 #endif
138 
139 	void ReportError(const char* msgPrefix, const char* msgArg, bool printErrCode, int errCode = 0,
140 		const char* errMsg = nullptr);
141 	virtual void PrintError(const char* errMsg);
142 	int GetLastNetworkError();
143 	bool DoConnect();
144 	bool DoDisconnect();
145 	bool InitSocketOpts(SOCKET socket);
146 	bool ConnectWithTimeout(void* address, int address_len);
147 #ifndef HAVE_GETADDRINFO
148 	in_addr_t ResolveHostAddr(const char* host);
149 #endif
150 #ifndef DISABLE_TLS
151 	int recv(SOCKET s, char* buf, int len, int flags);
152 	int send(SOCKET s, const char* buf, int len, int flags);
153 	void CloseTls();
154 #endif
155 };
156 
157 #endif
158