1 /* 2 * This file is part of nzbget. See <http://nzbget.net>. 3 * 4 * Copyright (C) 2004 Sven Henkel <sidddy@users.sourceforge.net> 5 * Copyright (C) 2007-2017 Andrey Prygunkov <hugbug@users.sourceforge.net> 6 * 7 * This program is free software; you can redistribute it and/or modify 8 * it under the terms of the GNU General Public License as published by 9 * the Free Software Foundation; either version 2 of the License, or 10 * (at your option) any later version. 11 * 12 * This program is distributed in the hope that it will be useful, 13 * but WITHOUT ANY WARRANTY; without even the implied warranty of 14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 * GNU General Public License for more details. 16 * 17 * You should have received a copy of the GNU General Public License 18 * along with this program. If not, see <http://www.gnu.org/licenses/>. 19 */ 20 21 22 #ifndef CONNECTION_H 23 #define CONNECTION_H 24 25 #include "NString.h" 26 27 #ifndef HAVE_GETADDRINFO 28 #ifndef HAVE_GETHOSTBYNAME_R 29 #include "Thread.h" 30 #endif 31 #endif 32 #ifndef DISABLE_TLS 33 #include "TlsSocket.h" 34 #endif 35 36 class Connection 37 { 38 public: 39 enum EStatus 40 { 41 csConnected, 42 csDisconnected, 43 csListening, 44 csCancelled, 45 csBroken 46 }; 47 48 enum EIPVersion 49 { 50 ipAuto, 51 ipV4, 52 ipV6 53 }; 54 55 Connection(const char* host, int port, bool tls); 56 Connection(SOCKET socket, bool tls); 57 virtual ~Connection(); 58 static void Init(); 59 static void Final(); 60 virtual bool Connect(); 61 virtual bool Disconnect(); 62 bool Bind(); 63 bool Send(const char* buffer, int size); 64 bool Recv(char* buffer, int size); 65 int TryRecv(char* buffer, int size); 66 char* ReadLine(char* buffer, int size, int* bytesRead); 67 void ReadBuffer(char** buffer, int *bufLen); 68 int WriteLine(const char* buffer); 69 std::unique_ptr<Connection> Accept(); 70 void Cancel(); GetHost()71 const char* GetHost() { return m_host; } GetPort()72 int GetPort() { return m_port; } GetTls()73 bool GetTls() { return m_tls; } GetCipher()74 const char* GetCipher() { return m_cipher; } SetCipher(const char * cipher)75 void SetCipher(const char* cipher) { m_cipher = cipher; } SetTimeout(int timeout)76 void SetTimeout(int timeout) { m_timeout = timeout; } SetIPVersion(EIPVersion ipVersion)77 void SetIPVersion(EIPVersion ipVersion) { m_ipVersion = ipVersion; } GetStatus()78 EStatus GetStatus() { return m_status; } 79 void SetSuppressErrors(bool suppressErrors); GetSuppressErrors()80 bool GetSuppressErrors() { return m_suppressErrors; } 81 const char* GetRemoteAddr(); GetGracefull()82 bool GetGracefull() { return m_gracefull; } SetGracefull(bool gracefull)83 void SetGracefull(bool gracefull) { m_gracefull = gracefull; } SetForceClose(bool forceClose)84 void SetForceClose(bool forceClose) { m_forceClose = forceClose; } 85 #ifndef DISABLE_TLS 86 bool StartTls(bool isClient, const char* certFile, const char* keyFile); 87 #endif 88 int FetchTotalBytesRead(); 89 90 protected: 91 CString m_host; 92 int m_port; 93 bool m_tls; 94 EIPVersion m_ipVersion = ipAuto; 95 SOCKET m_socket = INVALID_SOCKET; 96 CString m_cipher; 97 CharBuffer m_readBuf; 98 int m_bufAvail = 0; 99 char* m_bufPtr = nullptr; 100 EStatus m_status = csDisconnected; 101 int m_timeout = 60; 102 bool m_suppressErrors = true; 103 BString<100> m_remoteAddr; 104 int m_totalBytesRead = 0; 105 bool m_gracefull = false; 106 bool m_forceClose = false; 107 108 struct SockAddr 109 { 110 int ai_family; 111 int ai_socktype; 112 int ai_protocol; 113 bool operator==(const SockAddr& rhs) const 114 { return memcmp(this, &rhs, sizeof(SockAddr)) == 0; } 115 }; 116 117 #ifndef DISABLE_TLS 118 class ConTlsSocket: public TlsSocket 119 { 120 public: ConTlsSocket(SOCKET socket,bool isClient,const char * host,const char * certFile,const char * keyFile,const char * cipher,Connection * owner)121 ConTlsSocket(SOCKET socket, bool isClient, const char* host, 122 const char* certFile, const char* keyFile, const char* cipher, Connection* owner) : 123 TlsSocket(socket, isClient, host, certFile, keyFile, cipher), m_owner(owner) {} 124 protected: PrintError(const char * errMsg)125 virtual void PrintError(const char* errMsg) { m_owner->PrintError(errMsg); } 126 private: 127 Connection* m_owner; 128 }; 129 130 std::unique_ptr<ConTlsSocket> m_tlsSocket; 131 bool m_tlsError = false; 132 #endif 133 #ifndef HAVE_GETADDRINFO 134 #ifndef HAVE_GETHOSTBYNAME_R 135 static std::unique_ptr<Mutex> m_getHostByNameMutex; 136 #endif 137 #endif 138 139 void ReportError(const char* msgPrefix, const char* msgArg, bool printErrCode, int errCode = 0, 140 const char* errMsg = nullptr); 141 virtual void PrintError(const char* errMsg); 142 int GetLastNetworkError(); 143 bool DoConnect(); 144 bool DoDisconnect(); 145 bool InitSocketOpts(SOCKET socket); 146 bool ConnectWithTimeout(void* address, int address_len); 147 #ifndef HAVE_GETADDRINFO 148 in_addr_t ResolveHostAddr(const char* host); 149 #endif 150 #ifndef DISABLE_TLS 151 int recv(SOCKET s, char* buf, int len, int flags); 152 int send(SOCKET s, const char* buf, int len, int flags); 153 void CloseTls(); 154 #endif 155 }; 156 157 #endif 158