1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <memory>
6 #include <string>
7 
8 #include "base/bind.h"
9 #include "base/memory/ref_counted.h"
10 #include "base/run_loop.h"
11 #include "net/base/address_list.h"
12 #include "net/base/io_buffer.h"
13 #include "net/base/ip_address.h"
14 #include "net/base/net_errors.h"
15 #include "net/base/test_completion_callback.h"
16 #include "net/log/net_log_event_type.h"
17 #include "net/log/net_log_source.h"
18 #include "net/log/net_log_with_source.h"
19 #include "net/log/test_net_log.h"
20 #include "net/log/test_net_log_util.h"
21 #include "net/socket/client_socket_factory.h"
22 #include "net/socket/tcp_client_socket.h"
23 #include "net/socket/tcp_server_socket.h"
24 #include "net/test/gtest_util.h"
25 #include "net/test/test_with_task_environment.h"
26 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
27 #include "testing/gmock/include/gmock/gmock.h"
28 #include "testing/gtest/include/gtest/gtest.h"
29 #include "testing/platform_test.h"
30 
31 using net::test::IsError;
32 using net::test::IsOk;
33 
34 namespace net {
35 
namespace {

// Reply that the test server writes back to the client after each request.
constexpr char kServerReply[] = "HTTP/1.1 404 Not Found";

// Transport flavors the parameterized fixture can be instantiated with.
enum ClientSocketTestTypes { TCP, SCTP };

}  // namespace
43 
44 class TransportClientSocketTest
45     : public ::testing::TestWithParam<ClientSocketTestTypes>,
46       public WithTaskEnvironment {
47  public:
TransportClientSocketTest()48   TransportClientSocketTest()
49       : listen_port_(0),
50         socket_factory_(ClientSocketFactory::GetDefaultFactory()),
51         close_server_socket_on_next_send_(false) {}
52 
53   virtual ~TransportClientSocketTest() = default;
54 
55   // Testcase hooks
56   void SetUp() override;
57 
CloseServerSocket()58   void CloseServerSocket() {
59     // delete the connected_sock_, which will close it.
60     connected_sock_.reset();
61   }
62 
AcceptCallback(int res)63   void AcceptCallback(int res) {
64     ASSERT_THAT(res, IsOk());
65     connect_loop_.Quit();
66   }
67 
68   int DrainClientSocket(IOBuffer* buf,
69                         uint32_t buf_len,
70                         uint32_t bytes_to_read,
71                         TestCompletionCallback* callback);
72 
73   // Establishes a connection to the server.
74   void EstablishConnection(TestCompletionCallback* callback);
75 
76   // Sends a request from the client to the server socket. Makes the server read
77   // the request and send a response.
78   void SendRequestAndResponse();
79 
80   // Makes |connected_sock_| to read |expected_bytes_read| bytes. Returns the
81   // the data read as a string.
82   std::string ReadServerData(int expected_bytes_read);
83 
84   // Sends server response.
85   void SendServerResponse();
86 
set_close_server_socket_on_next_send(bool close)87   void set_close_server_socket_on_next_send(bool close) {
88     close_server_socket_on_next_send_ = close;
89   }
90 
91  protected:
92   base::RunLoop connect_loop_;
93   uint16_t listen_port_;
94   RecordingTestNetLog net_log_;
95   ClientSocketFactory* const socket_factory_;
96   std::unique_ptr<StreamSocket> sock_;
97   std::unique_ptr<StreamSocket> connected_sock_;
98 
99  private:
100   std::unique_ptr<TCPServerSocket> listen_sock_;
101   bool close_server_socket_on_next_send_;
102 };
103 
SetUp()104 void TransportClientSocketTest::SetUp() {
105   ::testing::TestWithParam<ClientSocketTestTypes>::SetUp();
106 
107   // Open a server socket on an ephemeral port.
108   listen_sock_.reset(new TCPServerSocket(nullptr, NetLogSource()));
109   IPEndPoint local_address(IPAddress::IPv4Localhost(), 0);
110   ASSERT_THAT(listen_sock_->Listen(local_address, 1), IsOk());
111   // Get the server's address (including the actual port number).
112   ASSERT_THAT(listen_sock_->GetLocalAddress(&local_address), IsOk());
113   listen_port_ = local_address.port();
114   listen_sock_->Accept(
115       &connected_sock_,
116       base::BindOnce(&TransportClientSocketTest::AcceptCallback,
117                      base::Unretained(this)));
118 
119   AddressList addr = AddressList::CreateFromIPAddress(
120       IPAddress::IPv4Localhost(), listen_port_);
121   sock_ = socket_factory_->CreateTransportClientSocket(
122       addr, nullptr, nullptr, &net_log_, NetLogSource());
123 }
124 
DrainClientSocket(IOBuffer * buf,uint32_t buf_len,uint32_t bytes_to_read,TestCompletionCallback * callback)125 int TransportClientSocketTest::DrainClientSocket(
126     IOBuffer* buf,
127     uint32_t buf_len,
128     uint32_t bytes_to_read,
129     TestCompletionCallback* callback) {
130   int rv = OK;
131   uint32_t bytes_read = 0;
132 
133   while (bytes_read < bytes_to_read) {
134     rv = sock_->Read(buf, buf_len, callback->callback());
135     EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
136     rv = callback->GetResult(rv);
137     EXPECT_GT(rv, 0);
138     bytes_read += rv;
139   }
140 
141   return static_cast<int>(bytes_read);
142 }
143 
EstablishConnection(TestCompletionCallback * callback)144 void TransportClientSocketTest::EstablishConnection(
145     TestCompletionCallback* callback) {
146   int rv = sock_->Connect(callback->callback());
147   // Wait for |listen_sock_| to accept a connection.
148   connect_loop_.Run();
149   // Now wait for the client socket to accept the connection.
150   EXPECT_THAT(callback->GetResult(rv), IsOk());
151 }
152 
SendRequestAndResponse()153 void TransportClientSocketTest::SendRequestAndResponse() {
154   // Send client request.
155   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
156   int request_len = strlen(request_text);
157   scoped_refptr<DrainableIOBuffer> request_buffer =
158       base::MakeRefCounted<DrainableIOBuffer>(
159           base::MakeRefCounted<IOBuffer>(request_len), request_len);
160   memcpy(request_buffer->data(), request_text, request_len);
161 
162   int bytes_written = 0;
163   while (request_buffer->BytesRemaining() > 0) {
164     TestCompletionCallback write_callback;
165     int write_result =
166         sock_->Write(request_buffer.get(), request_buffer->BytesRemaining(),
167                      write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
168     write_result = write_callback.GetResult(write_result);
169     ASSERT_GT(write_result, 0);
170     ASSERT_LE(bytes_written + write_result, request_len);
171     request_buffer->DidConsume(write_result);
172     bytes_written += write_result;
173   }
174   ASSERT_EQ(request_len, bytes_written);
175 
176   // Confirm that the server receives what client sent.
177   std::string data_received = ReadServerData(bytes_written);
178   ASSERT_TRUE(connected_sock_->IsConnectedAndIdle());
179   ASSERT_EQ(request_text, data_received);
180 
181   // Write server response.
182   SendServerResponse();
183 }
184 
SendServerResponse()185 void TransportClientSocketTest::SendServerResponse() {
186   // TODO(dkegel): this might not be long enough to tickle some bugs.
187   int reply_len = strlen(kServerReply);
188   scoped_refptr<DrainableIOBuffer> write_buffer =
189       base::MakeRefCounted<DrainableIOBuffer>(
190           base::MakeRefCounted<IOBuffer>(reply_len), reply_len);
191   memcpy(write_buffer->data(), kServerReply, reply_len);
192   int bytes_written = 0;
193   while (write_buffer->BytesRemaining() > 0) {
194     TestCompletionCallback write_callback;
195     int write_result = connected_sock_->Write(
196         write_buffer.get(), write_buffer->BytesRemaining(),
197         write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
198     write_result = write_callback.GetResult(write_result);
199     ASSERT_GE(write_result, 0);
200     ASSERT_LE(bytes_written + write_result, reply_len);
201     write_buffer->DidConsume(write_result);
202     bytes_written += write_result;
203   }
204   if (close_server_socket_on_next_send_)
205     CloseServerSocket();
206 }
207 
ReadServerData(int expected_bytes_read)208 std::string TransportClientSocketTest::ReadServerData(int expected_bytes_read) {
209   int bytes_read = 0;
210   scoped_refptr<IOBufferWithSize> read_buffer =
211       base::MakeRefCounted<IOBufferWithSize>(expected_bytes_read);
212   while (bytes_read < expected_bytes_read) {
213     TestCompletionCallback read_callback;
214     int rv = connected_sock_->Read(read_buffer.get(),
215                                    expected_bytes_read - bytes_read,
216                                    read_callback.callback());
217     EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
218     rv = read_callback.GetResult(rv);
219     EXPECT_GE(rv, 0);
220     bytes_read += rv;
221   }
222   EXPECT_EQ(expected_bytes_read, bytes_read);
223   return std::string(read_buffer->data(), bytes_read);
224 }
225 
226 // TODO(leighton):  Add SCTP to this list when it is ready.
227 INSTANTIATE_TEST_SUITE_P(StreamSocket,
228                          TransportClientSocketTest,
229                          ::testing::Values(TCP));
230 
// Connects the client, checks the net log around the connect attempt, and
// verifies IsConnected() before connecting, after connecting, and after
// Disconnect().
TEST_P(TransportClientSocketTest, Connect) {
  TestCompletionCallback callback;
  EXPECT_FALSE(sock_->IsConnected());

  int result = sock_->Connect(callback.callback());
  // Block until |listen_sock_| has accepted the incoming connection.
  connect_loop_.Run();

  // By now the log is expected to show the socket coming alive and the TCP
  // connect attempt starting.
  const auto entries_after_accept = net_log_.GetEntries();
  EXPECT_TRUE(LogContainsBeginEvent(entries_after_accept, 0,
                                    NetLogEventType::SOCKET_ALIVE));
  EXPECT_TRUE(LogContainsBeginEvent(entries_after_accept, 1,
                                    NetLogEventType::TCP_CONNECT));

  // A Connect() that did not finish synchronously must be pending; wait for
  // it to complete successfully.
  if (result != OK) {
    ASSERT_THAT(result, IsError(ERR_IO_PENDING));
    result = callback.WaitForResult();
    EXPECT_THAT(result, IsOk());
  }

  EXPECT_TRUE(sock_->IsConnected());
  const auto entries_after_connect = net_log_.GetEntries();
  EXPECT_TRUE(LogContainsEndEvent(entries_after_connect, -1,
                                  NetLogEventType::TCP_CONNECT));

  sock_->Disconnect();
  EXPECT_FALSE(sock_->IsConnected());
}
259 
TEST_P(TransportClientSocketTest,IsConnected)260 TEST_P(TransportClientSocketTest, IsConnected) {
261   scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(4096);
262   TestCompletionCallback callback;
263   uint32_t bytes_read;
264 
265   EXPECT_FALSE(sock_->IsConnected());
266   EXPECT_FALSE(sock_->IsConnectedAndIdle());
267 
268   EstablishConnection(&callback);
269 
270   EXPECT_TRUE(sock_->IsConnected());
271   EXPECT_TRUE(sock_->IsConnectedAndIdle());
272 
273   // Send the request and wait for the server to respond.
274   SendRequestAndResponse();
275 
276   // Drain a single byte so we know we've received some data.
277   bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback);
278   ASSERT_EQ(bytes_read, 1u);
279 
280   // Socket should be considered connected, but not idle, due to
281   // pending data.
282   EXPECT_TRUE(sock_->IsConnected());
283   EXPECT_FALSE(sock_->IsConnectedAndIdle());
284 
285   bytes_read =
286       DrainClientSocket(buf.get(), 4096, strlen(kServerReply) - 1, &callback);
287   ASSERT_EQ(bytes_read, strlen(kServerReply) - 1);
288 
289   // After draining the data, the socket should be back to connected
290   // and idle.
291   EXPECT_TRUE(sock_->IsConnected());
292   EXPECT_TRUE(sock_->IsConnectedAndIdle());
293 
294   // This time close the server socket immediately after the server response.
295   set_close_server_socket_on_next_send(true);
296   SendRequestAndResponse();
297 
298   bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback);
299   ASSERT_EQ(bytes_read, 1u);
300 
301   // As above because of data.
302   EXPECT_TRUE(sock_->IsConnected());
303   EXPECT_FALSE(sock_->IsConnectedAndIdle());
304 
305   bytes_read =
306       DrainClientSocket(buf.get(), 4096, strlen(kServerReply) - 1, &callback);
307   ASSERT_EQ(bytes_read, strlen(kServerReply) - 1);
308 
309   // Once the data is drained, the socket should now be seen as not
310   // connected.
311   if (sock_->IsConnected()) {
312     // In the unlikely event that the server's connection closure is not
313     // processed in time, wait for the connection to be closed.
314     int rv = sock_->Read(buf.get(), 4096, callback.callback());
315     EXPECT_EQ(0, callback.GetResult(rv));
316     EXPECT_FALSE(sock_->IsConnected());
317   }
318   EXPECT_FALSE(sock_->IsConnectedAndIdle());
319 }
320 
// Reads the whole server reply, checks its contents, then verifies that a
// further read pends and completes with EOF once the server closes.
TEST_P(TransportClientSocketTest, Read) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);
  SendRequestAndResponse();

  const int kBufLen = 4096;
  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(kBufLen);
  const uint32_t reply_len = strlen(kServerReply);
  const uint32_t bytes_read =
      DrainClientSocket(buf.get(), kBufLen, reply_len, &callback);
  ASSERT_EQ(bytes_read, reply_len);
  ASSERT_EQ(std::string(kServerReply), std::string(buf->data(), bytes_read));

  // Nothing is left to read, so the next Read() must pend. Closing the server
  // socket must then complete it with EOF (0).
  const int pending_rv = sock_->Read(buf.get(), kBufLen, callback.callback());
  ASSERT_THAT(pending_rv, IsError(ERR_IO_PENDING));
  CloseServerSocket();
  EXPECT_EQ(0, callback.WaitForResult());
}
341 
// Consumes the server reply one byte per Read(), then verifies that a further
// read pends and completes with EOF once the server closes.
TEST_P(TransportClientSocketTest, Read_SmallChunks) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);
  SendRequestAndResponse();

  scoped_refptr<IOBuffer> byte_buf = base::MakeRefCounted<IOBuffer>(1);
  const uint32_t reply_len = strlen(kServerReply);
  for (uint32_t received = 0; received < reply_len;) {
    int rv = sock_->Read(byte_buf.get(), 1, callback.callback());
    EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
    rv = callback.GetResult(rv);
    ASSERT_EQ(1, rv);
    received += rv;
  }

  // Everything has been consumed, so another Read() must pend. Closing the
  // server socket must then complete it with EOF (0).
  const int pending_rv = sock_->Read(byte_buf.get(), 1, callback.callback());
  ASSERT_THAT(pending_rv, IsError(ERR_IO_PENDING));
  CloseServerSocket();
  EXPECT_EQ(0, callback.WaitForResult());
}
368 
// Reads only part of the server reply (the 16-byte buffer is shorter than
// kServerReply) and then ends the test with data still unread. The test
// should not crash!
TEST_P(TransportClientSocketTest, Read_Interrupted) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  SendRequestAndResponse();

  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(16);
  int rv = sock_->Read(buf.get(), 16, callback.callback());
  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);

  rv = callback.GetResult(rv);

  // The server has already sent its reply, so the read must yield data.
  // Neither EOF (0) nor a network error (negative) is acceptable.
  // EXPECT_NE(0, rv) would have let error codes pass.
  EXPECT_GT(rv, 0);
}
384 
// Starts a pending Read() first, then writes until a Write() also pends, and
// checks that both operations complete successfully (full duplex).
TEST_P(TransportClientSocketTest, FullDuplex_ReadFirst) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  // Read first.  There's no data, so it should return ERR_IO_PENDING.
  const int kBufLen = 4096;
  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(kBufLen);
  int rv = sock_->Read(buf.get(), kBufLen, callback.callback());
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  const int kWriteBufLen = 64 * 1024;
  scoped_refptr<IOBuffer> request_buffer =
      base::MakeRefCounted<IOBuffer>(kWriteBufLen);
  char* request_data = request_buffer->data();
  memset(request_data, 'A', kWriteBufLen);
  TestCompletionCallback write_callback;

  // Write until a Write() pends. At that point the server drains every byte
  // written synchronously so far and sends its reply, which should let the
  // pending write finish.
  int bytes_written = 0;
  while (true) {
    rv = sock_->Write(request_buffer.get(), kWriteBufLen,
                      write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
    ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
    if (rv == ERR_IO_PENDING) {
      ReadServerData(bytes_written);
      SendServerResponse();
      rv = write_callback.WaitForResult();
      // The pending write must complete successfully, not with an error.
      EXPECT_GE(rv, 0);
      break;
    }
    bytes_written += rv;
  }

  // At this point, both read and write have returned ERR_IO_PENDING, and the
  // write callback has executed.  We wait for the read callback to run now to
  // make sure that the socket can handle full duplex communications.

  rv = callback.WaitForResult();
  // The server replied and kept its socket open, so the read must return
  // data rather than EOF (0).
  EXPECT_GT(rv, 0);
}
423 
// Writes until a Write() pends, then reads until a Read() also pends, and
// finally lets the server unblock both. Both completions must succeed (full
// duplex).
TEST_P(TransportClientSocketTest, FullDuplex_WriteFirst) {
  TestCompletionCallback callback;
  EstablishConnection(&callback);

  const int kWriteBufLen = 64 * 1024;
  scoped_refptr<IOBuffer> request_buffer =
      base::MakeRefCounted<IOBuffer>(kWriteBufLen);
  memset(request_buffer->data(), 'A', kWriteBufLen);
  TestCompletionCallback write_callback;

  // Keep writing until the socket refuses to take more synchronously.
  int bytes_written = 0;
  for (;;) {
    const int rv =
        sock_->Write(request_buffer.get(), kWriteBufLen,
                     write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
    if (rv == ERR_IO_PENDING)
      break;
    ASSERT_GE(rv, 0);
    bytes_written += rv;
  }

  // With the Write() blocked, force a Read() to block as well.
  const int kBufLen = 4096;
  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(kBufLen);
  for (;;) {
    const int rv = sock_->Read(buf.get(), kBufLen, callback.callback());
    if (rv == ERR_IO_PENDING)
      break;
    ASSERT_GE(rv, 0);
  }

  // Both directions are now pending. The server drains the written data and
  // replies, then both callbacks must complete.
  ReadServerData(bytes_written);
  SendServerResponse();

  const int write_rv = write_callback.WaitForResult();
  EXPECT_GE(write_rv, 0);

  const int read_rv = callback.WaitForResult();
  EXPECT_GT(read_rv, 0);
}
471 
472 }  // namespace net
473