1 //===-- RNBSocketTest.cpp ---------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "gtest/gtest.h"

#include <arpa/inet.h>
#include <memory>
#include <signal.h>
#include <sys/sysctl.h>
#include <sys/wait.h>
#include <unistd.h>

#include "RNBDefs.h"
#include "RNBSocket.h"
#include "lldb/Host/Socket.h"
#include "lldb/Host/StringConvert.h"
#include "lldb/Host/common/TCPSocket.h"
20 
21 using namespace lldb_private;
22 
// Messages exchanged in every test below: the RNBSocket under test sends
// `hello`, and the forked peer process verifies it and replies with `goodbye`.
// They are read-only fixtures local to this file, so give them internal
// linkage to avoid clashing with other test files linked into the same binary.
static const std::string hello = "Hello, world!";
static const std::string goodbye = "Goodbye!";
25 
ServerCallbackv4(const void * baton,in_port_t port)26 static void ServerCallbackv4(const void *baton, in_port_t port) {
27   auto child_pid = fork();
28   if (child_pid == 0) {
29     Socket *client_socket;
30     char addr_buffer[256];
31     sprintf(addr_buffer, "%s:%d", baton, port);
32     Status err = Socket::TcpConnect(addr_buffer, false, client_socket);
33     if (err.Fail())
34       abort();
35     char buffer[32];
36     size_t read_size = 32;
37     err = client_socket->Read((void *)&buffer[0], read_size);
38     if (err.Fail())
39       abort();
40     std::string Recv(&buffer[0], read_size);
41     if (Recv != hello)
42       abort();
43     size_t write_size = goodbye.length();
44     err = client_socket->Write(goodbye.c_str(), write_size);
45     if (err.Fail())
46       abort();
47     if (write_size != goodbye.length())
48       abort();
49     delete client_socket;
50     exit(0);
51   }
52 }
53 
TestSocketListen(const char * addr)54 void TestSocketListen(const char *addr) {
55   // Skip IPv6 tests if there isn't a valid interafce
56   auto addresses = lldb_private::SocketAddress::GetAddressInfo(
57       addr, NULL, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
58   if (addresses.size() == 0)
59     return;
60 
61   char addr_wrap[256];
62   if (addresses.front().GetFamily() == AF_INET6)
63     sprintf(addr_wrap, "[%s]", addr);
64   else
65     sprintf(addr_wrap, "%s", addr);
66 
67   RNBSocket server_socket;
68   auto result =
69       server_socket.Listen(addr, 0, ServerCallbackv4, (const void *)addr_wrap);
70   ASSERT_TRUE(result == rnb_success);
71   result = server_socket.Write(hello.c_str(), hello.length());
72   ASSERT_TRUE(result == rnb_success);
73   std::string bye;
74   result = server_socket.Read(bye);
75   ASSERT_TRUE(result == rnb_success);
76   ASSERT_EQ(bye, goodbye);
77 
78   int exit_status;
79   wait(&exit_status);
80   ASSERT_EQ(exit_status, 0);
81 }
82 
TEST(RNBSocket, LoopBackListenIPv4) {
  // Exercise RNBSocket::Listen on the IPv4 loopback address.
  TestSocketListen("127.0.0.1");
}
84 
TEST(RNBSocket, LoopBackListenIPv6) {
  // Exercise RNBSocket::Listen on the IPv6 loopback address; TestSocketListen
  // returns early when "::1" does not resolve on this host.
  TestSocketListen("::1");
}
86 
TEST(RNBSocket, AnyListen) {
  // Exercise RNBSocket::Listen with the "*" host string; TestSocketListen
  // returns early when "*" does not resolve to any address.
  TestSocketListen("*");
}
88 
TestSocketConnect(const char * addr)89 void TestSocketConnect(const char *addr) {
90   // Skip IPv6 tests if there isn't a valid interafce
91   auto addresses = lldb_private::SocketAddress::GetAddressInfo(
92       addr, NULL, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
93   if (addresses.size() == 0)
94     return;
95 
96   char addr_wrap[256];
97   if (addresses.front().GetFamily() == AF_INET6)
98     sprintf(addr_wrap, "[%s]:0", addr);
99   else
100     sprintf(addr_wrap, "%s:0", addr);
101 
102   Socket *server_socket;
103   Predicate<uint16_t> port_predicate;
104   port_predicate.SetValue(0, eBroadcastNever);
105   Status err =
106       Socket::TcpListen(addr_wrap, false, server_socket, &port_predicate);
107   ASSERT_FALSE(err.Fail());
108 
109   auto port = ((TCPSocket *)server_socket)->GetLocalPortNumber();
110   auto child_pid = fork();
111   if (child_pid != 0) {
112     RNBSocket client_socket;
113     auto result = client_socket.Connect(addr, port);
114     ASSERT_TRUE(result == rnb_success);
115     result = client_socket.Write(hello.c_str(), hello.length());
116     ASSERT_TRUE(result == rnb_success);
117     std::string bye;
118     result = client_socket.Read(bye);
119     ASSERT_TRUE(result == rnb_success);
120     ASSERT_EQ(bye, goodbye);
121   } else {
122     Socket *connected_socket;
123     err = server_socket->Accept(connected_socket);
124     if (err.Fail()) {
125       llvm::errs() << err.AsCString();
126       abort();
127     }
128     char buffer[32];
129     size_t read_size = 32;
130     err = connected_socket->Read((void *)&buffer[0], read_size);
131     if (err.Fail()) {
132       llvm::errs() << err.AsCString();
133       abort();
134     }
135     std::string Recv(&buffer[0], read_size);
136     if (Recv != hello) {
137       llvm::errs() << err.AsCString();
138       abort();
139     }
140     size_t write_size = goodbye.length();
141     err = connected_socket->Write(goodbye.c_str(), write_size);
142     if (err.Fail()) {
143       llvm::errs() << err.AsCString();
144       abort();
145     }
146     if (write_size != goodbye.length()) {
147       llvm::errs() << err.AsCString();
148       abort();
149     }
150     exit(0);
151   }
152   int exit_status;
153   wait(&exit_status);
154   ASSERT_EQ(exit_status, 0);
155 }
156 
TEST(RNBSocket, LoopBackConnectIPv4) {
  // Exercise RNBSocket::Connect against a listener on the IPv4 loopback.
  TestSocketConnect("127.0.0.1");
}
158 
TEST(RNBSocket, LoopBackConnectIPv6) {
  // Exercise RNBSocket::Connect against a listener on the IPv6 loopback;
  // TestSocketConnect returns early when "::1" does not resolve on this host.
  TestSocketConnect("::1");
}
160