1 //===-- TCPSocket.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #if defined(_MSC_VER)
10 #define _WINSOCK_DEPRECATED_NO_WARNINGS
11 #endif
12 
13 #include "lldb/Host/common/TCPSocket.h"
14 
15 #include "lldb/Host/Config.h"
16 #include "lldb/Host/MainLoop.h"
17 #include "lldb/Utility/Log.h"
18 
19 #include "llvm/Config/llvm-config.h"
20 #include "llvm/Support/Errno.h"
21 #include "llvm/Support/WindowsError.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 #if LLDB_ENABLE_POSIX
25 #include <arpa/inet.h>
26 #include <netinet/tcp.h>
27 #include <sys/socket.h>
28 #endif
29 
30 #if defined(_WIN32)
31 #include <winsock2.h>
32 #endif
33 
34 #ifdef _WIN32
35 #define CLOSE_SOCKET closesocket
36 typedef const char *set_socket_option_arg_type;
37 #else
38 #include <unistd.h>
39 #define CLOSE_SOCKET ::close
40 typedef const void *set_socket_option_arg_type;
41 #endif
42 
43 using namespace lldb;
44 using namespace lldb_private;
45 
46 static Status GetLastSocketError() {
47   std::error_code EC;
48 #ifdef _WIN32
49   EC = llvm::mapWindowsError(WSAGetLastError());
50 #else
51   EC = std::error_code(errno, std::generic_category());
52 #endif
53   return EC;
54 }
55 
// Socket type used for every descriptor this file creates (byte stream).
static constexpr int kType = SOCK_STREAM;
57 
58 TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit)
59     : Socket(ProtocolTcp, should_close, child_processes_inherit) {}
60 
61 TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
62     : Socket(ProtocolTcp, listen_socket.m_should_close_fd,
63              listen_socket.m_child_processes_inherit) {
64   m_socket = socket;
65 }
66 
67 TCPSocket::TCPSocket(NativeSocket socket, bool should_close,
68                      bool child_processes_inherit)
69     : Socket(ProtocolTcp, should_close, child_processes_inherit) {
70   m_socket = socket;
71 }
72 
73 TCPSocket::~TCPSocket() { CloseListenSockets(); }
74 
75 bool TCPSocket::IsValid() const {
76   return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
77 }
78 
79 // Return the port number that is being used by the socket.
80 uint16_t TCPSocket::GetLocalPortNumber() const {
81   if (m_socket != kInvalidSocketValue) {
82     SocketAddress sock_addr;
83     socklen_t sock_addr_len = sock_addr.GetMaxLength();
84     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
85       return sock_addr.GetPort();
86   } else if (!m_listen_sockets.empty()) {
87     SocketAddress sock_addr;
88     socklen_t sock_addr_len = sock_addr.GetMaxLength();
89     if (::getsockname(m_listen_sockets.begin()->first, sock_addr,
90                       &sock_addr_len) == 0)
91       return sock_addr.GetPort();
92   }
93   return 0;
94 }
95 
96 std::string TCPSocket::GetLocalIPAddress() const {
97   // We bound to port zero, so we need to figure out which port we actually
98   // bound to
99   if (m_socket != kInvalidSocketValue) {
100     SocketAddress sock_addr;
101     socklen_t sock_addr_len = sock_addr.GetMaxLength();
102     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
103       return sock_addr.GetIPAddress();
104   }
105   return "";
106 }
107 
108 uint16_t TCPSocket::GetRemotePortNumber() const {
109   if (m_socket != kInvalidSocketValue) {
110     SocketAddress sock_addr;
111     socklen_t sock_addr_len = sock_addr.GetMaxLength();
112     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
113       return sock_addr.GetPort();
114   }
115   return 0;
116 }
117 
118 std::string TCPSocket::GetRemoteIPAddress() const {
119   // We bound to port zero, so we need to figure out which port we actually
120   // bound to
121   if (m_socket != kInvalidSocketValue) {
122     SocketAddress sock_addr;
123     socklen_t sock_addr_len = sock_addr.GetMaxLength();
124     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
125       return sock_addr.GetIPAddress();
126   }
127   return "";
128 }
129 
130 std::string TCPSocket::GetRemoteConnectionURI() const {
131   if (m_socket != kInvalidSocketValue) {
132     return std::string(llvm::formatv(
133         "connect://[{0}]:{1}", GetRemoteIPAddress(), GetRemotePortNumber()));
134   }
135   return "";
136 }
137 
138 Status TCPSocket::CreateSocket(int domain) {
139   Status error;
140   if (IsValid())
141     error = Close();
142   if (error.Fail())
143     return error;
144   m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP,
145                                   m_child_processes_inherit, error);
146   return error;
147 }
148 
149 Status TCPSocket::Connect(llvm::StringRef name) {
150 
151   Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_COMMUNICATION));
152   LLDB_LOGF(log, "TCPSocket::%s (host/port = %s)", __FUNCTION__, name.data());
153 
154   Status error;
155   llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
156   if (!host_port)
157     return Status(host_port.takeError());
158 
159   std::vector<SocketAddress> addresses =
160       SocketAddress::GetAddressInfo(host_port->hostname.c_str(), nullptr,
161                                     AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
162   for (SocketAddress &address : addresses) {
163     error = CreateSocket(address.GetFamily());
164     if (error.Fail())
165       continue;
166 
167     address.SetPort(host_port->port);
168 
169     if (-1 == llvm::sys::RetryAfterSignal(-1, ::connect, GetNativeSocket(),
170                                           &address.sockaddr(),
171                                           address.GetLength())) {
172       CLOSE_SOCKET(GetNativeSocket());
173       continue;
174     }
175 
176     SetOptionNoDelay();
177 
178     error.Clear();
179     return error;
180   }
181 
182   error.SetErrorString("Failed to connect port");
183   return error;
184 }
185 
186 Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
187   Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION));
188   LLDB_LOGF(log, "TCPSocket::%s (%s)", __FUNCTION__, name.data());
189 
190   Status error;
191   llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
192   if (!host_port)
193     return Status(host_port.takeError());
194 
195   if (host_port->hostname == "*")
196     host_port->hostname = "0.0.0.0";
197   std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
198       host_port->hostname.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
199   for (SocketAddress &address : addresses) {
200     int fd = Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP,
201                                   m_child_processes_inherit, error);
202     if (error.Fail())
203       continue;
204 
205     // enable local address reuse
206     int option_value = 1;
207     set_socket_option_arg_type option_value_p =
208         reinterpret_cast<set_socket_option_arg_type>(&option_value);
209     ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, option_value_p,
210                  sizeof(option_value));
211 
212     SocketAddress listen_address = address;
213     if(!listen_address.IsLocalhost())
214       listen_address.SetToAnyAddress(address.GetFamily(), host_port->port);
215     else
216       listen_address.SetPort(host_port->port);
217 
218     int err =
219         ::bind(fd, &listen_address.sockaddr(), listen_address.GetLength());
220     if (-1 != err)
221       err = ::listen(fd, backlog);
222 
223     if (-1 == err) {
224       error = GetLastSocketError();
225       CLOSE_SOCKET(fd);
226       continue;
227     }
228 
229     if (host_port->port == 0) {
230       socklen_t sa_len = address.GetLength();
231       if (getsockname(fd, &address.sockaddr(), &sa_len) == 0)
232         host_port->port = address.GetPort();
233     }
234     m_listen_sockets[fd] = address;
235   }
236 
237   if (m_listen_sockets.empty()) {
238     assert(error.Fail());
239     return error;
240   }
241   return Status();
242 }
243 
244 void TCPSocket::CloseListenSockets() {
245   for (auto socket : m_listen_sockets)
246     CLOSE_SOCKET(socket.first);
247   m_listen_sockets.clear();
248 }
249 
250 Status TCPSocket::Accept(Socket *&conn_socket) {
251   Status error;
252   if (m_listen_sockets.size() == 0) {
253     error.SetErrorString("No open listening sockets!");
254     return error;
255   }
256 
257   int sock = -1;
258   int listen_sock = -1;
259   lldb_private::SocketAddress AcceptAddr;
260   MainLoop accept_loop;
261   std::vector<MainLoopBase::ReadHandleUP> handles;
262   for (auto socket : m_listen_sockets) {
263     auto fd = socket.first;
264     auto inherit = this->m_child_processes_inherit;
265     auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit));
266     handles.emplace_back(accept_loop.RegisterReadObject(
267         io_sp, [fd, inherit, &sock, &AcceptAddr, &error,
268                         &listen_sock](MainLoopBase &loop) {
269           socklen_t sa_len = AcceptAddr.GetMaxLength();
270           sock = AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, inherit,
271                               error);
272           listen_sock = fd;
273           loop.RequestTermination();
274         }, error));
275     if (error.Fail())
276       return error;
277   }
278 
279   bool accept_connection = false;
280   std::unique_ptr<TCPSocket> accepted_socket;
281   // Loop until we are happy with our connection
282   while (!accept_connection) {
283     accept_loop.Run();
284 
285     if (error.Fail())
286         return error;
287 
288     lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock];
289     if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
290       CLOSE_SOCKET(sock);
291       llvm::errs() << llvm::formatv(
292           "error: rejecting incoming connection from {0} (expecting {1})",
293           AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress());
294       continue;
295     }
296     accept_connection = true;
297     accepted_socket.reset(new TCPSocket(sock, *this));
298   }
299 
300   if (!accepted_socket)
301     return error;
302 
303   // Keep our TCP packets coming without any delays.
304   accepted_socket->SetOptionNoDelay();
305   error.Clear();
306   conn_socket = accepted_socket.release();
307   return error;
308 }
309 
310 int TCPSocket::SetOptionNoDelay() {
311   return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
312 }
313 
314 int TCPSocket::SetOptionReuseAddress() {
315   return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
316 }
317