1 //===-- TCPSocket.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #if defined(_MSC_VER)
10 #define _WINSOCK_DEPRECATED_NO_WARNINGS
11 #endif
12 
13 #include "lldb/Host/common/TCPSocket.h"
14 
15 #include "lldb/Host/Config.h"
16 #include "lldb/Host/MainLoop.h"
17 #include "lldb/Utility/LLDBLog.h"
18 #include "lldb/Utility/Log.h"
19 
20 #include "llvm/Config/llvm-config.h"
21 #include "llvm/Support/Errno.h"
22 #include "llvm/Support/WindowsError.h"
23 #include "llvm/Support/raw_ostream.h"
24 
25 #if LLDB_ENABLE_POSIX
26 #include <arpa/inet.h>
27 #include <netinet/tcp.h>
28 #include <sys/socket.h>
29 #endif
30 
31 #if defined(_WIN32)
32 #include <winsock2.h>
33 #endif
34 
35 #ifdef _WIN32
36 #define CLOSE_SOCKET closesocket
37 typedef const char *set_socket_option_arg_type;
38 #else
39 #include <unistd.h>
40 #define CLOSE_SOCKET ::close
41 typedef const void *set_socket_option_arg_type;
42 #endif
43 
44 using namespace lldb;
45 using namespace lldb_private;
46 
GetLastSocketError()47 static Status GetLastSocketError() {
48   std::error_code EC;
49 #ifdef _WIN32
50   EC = llvm::mapWindowsError(WSAGetLastError());
51 #else
52   EC = std::error_code(errno, std::generic_category());
53 #endif
54   return EC;
55 }
56 
// Socket type used for every descriptor this file creates (TCP => stream).
static constexpr int kType = SOCK_STREAM;
58 
TCPSocket(bool should_close,bool child_processes_inherit)59 TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit)
60     : Socket(ProtocolTcp, should_close, child_processes_inherit) {}
61 
TCPSocket(NativeSocket socket,const TCPSocket & listen_socket)62 TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
63     : Socket(ProtocolTcp, listen_socket.m_should_close_fd,
64              listen_socket.m_child_processes_inherit) {
65   m_socket = socket;
66 }
67 
TCPSocket(NativeSocket socket,bool should_close,bool child_processes_inherit)68 TCPSocket::TCPSocket(NativeSocket socket, bool should_close,
69                      bool child_processes_inherit)
70     : Socket(ProtocolTcp, should_close, child_processes_inherit) {
71   m_socket = socket;
72 }
73 
~TCPSocket()74 TCPSocket::~TCPSocket() { CloseListenSockets(); }
75 
IsValid() const76 bool TCPSocket::IsValid() const {
77   return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
78 }
79 
80 // Return the port number that is being used by the socket.
GetLocalPortNumber() const81 uint16_t TCPSocket::GetLocalPortNumber() const {
82   if (m_socket != kInvalidSocketValue) {
83     SocketAddress sock_addr;
84     socklen_t sock_addr_len = sock_addr.GetMaxLength();
85     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
86       return sock_addr.GetPort();
87   } else if (!m_listen_sockets.empty()) {
88     SocketAddress sock_addr;
89     socklen_t sock_addr_len = sock_addr.GetMaxLength();
90     if (::getsockname(m_listen_sockets.begin()->first, sock_addr,
91                       &sock_addr_len) == 0)
92       return sock_addr.GetPort();
93   }
94   return 0;
95 }
96 
GetLocalIPAddress() const97 std::string TCPSocket::GetLocalIPAddress() const {
98   // We bound to port zero, so we need to figure out which port we actually
99   // bound to
100   if (m_socket != kInvalidSocketValue) {
101     SocketAddress sock_addr;
102     socklen_t sock_addr_len = sock_addr.GetMaxLength();
103     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
104       return sock_addr.GetIPAddress();
105   }
106   return "";
107 }
108 
GetRemotePortNumber() const109 uint16_t TCPSocket::GetRemotePortNumber() const {
110   if (m_socket != kInvalidSocketValue) {
111     SocketAddress sock_addr;
112     socklen_t sock_addr_len = sock_addr.GetMaxLength();
113     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
114       return sock_addr.GetPort();
115   }
116   return 0;
117 }
118 
GetRemoteIPAddress() const119 std::string TCPSocket::GetRemoteIPAddress() const {
120   // We bound to port zero, so we need to figure out which port we actually
121   // bound to
122   if (m_socket != kInvalidSocketValue) {
123     SocketAddress sock_addr;
124     socklen_t sock_addr_len = sock_addr.GetMaxLength();
125     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
126       return sock_addr.GetIPAddress();
127   }
128   return "";
129 }
130 
GetRemoteConnectionURI() const131 std::string TCPSocket::GetRemoteConnectionURI() const {
132   if (m_socket != kInvalidSocketValue) {
133     return std::string(llvm::formatv(
134         "connect://[{0}]:{1}", GetRemoteIPAddress(), GetRemotePortNumber()));
135   }
136   return "";
137 }
138 
CreateSocket(int domain)139 Status TCPSocket::CreateSocket(int domain) {
140   Status error;
141   if (IsValid())
142     error = Close();
143   if (error.Fail())
144     return error;
145   m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP,
146                                   m_child_processes_inherit, error);
147   return error;
148 }
149 
Connect(llvm::StringRef name)150 Status TCPSocket::Connect(llvm::StringRef name) {
151 
152   Log *log = GetLog(LLDBLog::Communication);
153   LLDB_LOGF(log, "TCPSocket::%s (host/port = %s)", __FUNCTION__, name.data());
154 
155   Status error;
156   llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
157   if (!host_port)
158     return Status(host_port.takeError());
159 
160   std::vector<SocketAddress> addresses =
161       SocketAddress::GetAddressInfo(host_port->hostname.c_str(), nullptr,
162                                     AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
163   for (SocketAddress &address : addresses) {
164     error = CreateSocket(address.GetFamily());
165     if (error.Fail())
166       continue;
167 
168     address.SetPort(host_port->port);
169 
170     if (llvm::sys::RetryAfterSignal(-1, ::connect, GetNativeSocket(),
171                                     &address.sockaddr(),
172                                     address.GetLength()) == -1) {
173       Close();
174       continue;
175     }
176 
177     if (SetOptionNoDelay() == -1) {
178       Close();
179       continue;
180     }
181 
182     error.Clear();
183     return error;
184   }
185 
186   error.SetErrorString("Failed to connect port");
187   return error;
188 }
189 
Listen(llvm::StringRef name,int backlog)190 Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
191   Log *log = GetLog(LLDBLog::Connection);
192   LLDB_LOGF(log, "TCPSocket::%s (%s)", __FUNCTION__, name.data());
193 
194   Status error;
195   llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
196   if (!host_port)
197     return Status(host_port.takeError());
198 
199   if (host_port->hostname == "*")
200     host_port->hostname = "0.0.0.0";
201   std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
202       host_port->hostname.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
203   for (SocketAddress &address : addresses) {
204     int fd = Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP,
205                                   m_child_processes_inherit, error);
206     if (error.Fail() || fd < 0)
207       continue;
208 
209     // enable local address reuse
210     int option_value = 1;
211     set_socket_option_arg_type option_value_p =
212         reinterpret_cast<set_socket_option_arg_type>(&option_value);
213     if (::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, option_value_p,
214                      sizeof(option_value)) == -1) {
215       CLOSE_SOCKET(fd);
216       continue;
217     }
218 
219     SocketAddress listen_address = address;
220     if(!listen_address.IsLocalhost())
221       listen_address.SetToAnyAddress(address.GetFamily(), host_port->port);
222     else
223       listen_address.SetPort(host_port->port);
224 
225     int err =
226         ::bind(fd, &listen_address.sockaddr(), listen_address.GetLength());
227     if (err != -1)
228       err = ::listen(fd, backlog);
229 
230     if (err == -1) {
231       error = GetLastSocketError();
232       CLOSE_SOCKET(fd);
233       continue;
234     }
235 
236     if (host_port->port == 0) {
237       socklen_t sa_len = address.GetLength();
238       if (getsockname(fd, &address.sockaddr(), &sa_len) == 0)
239         host_port->port = address.GetPort();
240     }
241     m_listen_sockets[fd] = address;
242   }
243 
244   if (m_listen_sockets.empty()) {
245     assert(error.Fail());
246     return error;
247   }
248   return Status();
249 }
250 
CloseListenSockets()251 void TCPSocket::CloseListenSockets() {
252   for (auto socket : m_listen_sockets)
253     CLOSE_SOCKET(socket.first);
254   m_listen_sockets.clear();
255 }
256 
Accept(Socket * & conn_socket)257 Status TCPSocket::Accept(Socket *&conn_socket) {
258   Status error;
259   if (m_listen_sockets.size() == 0) {
260     error.SetErrorString("No open listening sockets!");
261     return error;
262   }
263 
264   NativeSocket sock = kInvalidSocketValue;
265   NativeSocket listen_sock = kInvalidSocketValue;
266   lldb_private::SocketAddress AcceptAddr;
267   MainLoop accept_loop;
268   std::vector<MainLoopBase::ReadHandleUP> handles;
269   for (auto socket : m_listen_sockets) {
270     auto fd = socket.first;
271     auto inherit = this->m_child_processes_inherit;
272     auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit));
273     handles.emplace_back(accept_loop.RegisterReadObject(
274         io_sp, [fd, inherit, &sock, &AcceptAddr, &error,
275                         &listen_sock](MainLoopBase &loop) {
276           socklen_t sa_len = AcceptAddr.GetMaxLength();
277           sock = AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, inherit,
278                               error);
279           listen_sock = fd;
280           loop.RequestTermination();
281         }, error));
282     if (error.Fail())
283       return error;
284   }
285 
286   bool accept_connection = false;
287   std::unique_ptr<TCPSocket> accepted_socket;
288   // Loop until we are happy with our connection
289   while (!accept_connection) {
290     accept_loop.Run();
291 
292     if (error.Fail())
293         return error;
294 
295     lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock];
296     if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
297       if (sock != kInvalidSocketValue) {
298         CLOSE_SOCKET(sock);
299         sock = kInvalidSocketValue;
300       }
301       llvm::errs() << llvm::formatv(
302           "error: rejecting incoming connection from {0} (expecting {1})",
303           AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress());
304       continue;
305     }
306     accept_connection = true;
307     accepted_socket.reset(new TCPSocket(sock, *this));
308   }
309 
310   if (!accepted_socket)
311     return error;
312 
313   // Keep our TCP packets coming without any delays.
314   accepted_socket->SetOptionNoDelay();
315   error.Clear();
316   conn_socket = accepted_socket.release();
317   return error;
318 }
319 
SetOptionNoDelay()320 int TCPSocket::SetOptionNoDelay() {
321   return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
322 }
323 
SetOptionReuseAddress()324 int TCPSocket::SetOptionReuseAddress() {
325   return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
326 }
327