1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socket_test_util.h"
6
7 #include <inttypes.h> // For SCNx64
8 #include <stdint.h>
9 #include <stdio.h>
10
11 #include <algorithm>
12 #include <string>
13 #include <utility>
14 #include <vector>
15
16 #include "base/bind.h"
17 #include "base/bind_helpers.h"
18 #include "base/callback_helpers.h"
19 #include "base/compiler_specific.h"
20 #include "base/files/file_util.h"
21 #include "base/location.h"
22 #include "base/logging.h"
23 #include "base/rand_util.h"
24 #include "base/run_loop.h"
25 #include "base/single_thread_task_runner.h"
26 #include "base/stl_util.h"
27 #include "base/threading/thread_task_runner_handle.h"
28 #include "base/time/time.h"
29 #include "net/base/address_family.h"
30 #include "net/base/address_list.h"
31 #include "net/base/auth.h"
32 #include "net/base/hex_utils.h"
33 #include "net/base/ip_address.h"
34 #include "net/base/load_timing_info.h"
35 #include "net/base/proxy_server.h"
36 #include "net/http/http_network_session.h"
37 #include "net/http/http_request_headers.h"
38 #include "net/http/http_response_headers.h"
39 #include "net/log/net_log_source.h"
40 #include "net/log/net_log_source_type.h"
41 #include "net/socket/connect_job.h"
42 #include "net/socket/socket.h"
43 #include "net/socket/stream_socket.h"
44 #include "net/socket/websocket_endpoint_lock_manager.h"
45 #include "net/ssl/ssl_cert_request_info.h"
46 #include "net/ssl/ssl_connection_status_flags.h"
47 #include "net/ssl/ssl_info.h"
48 #include "net/traffic_annotation/network_traffic_annotation.h"
49 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
50 #include "testing/gtest/include/gtest/gtest.h"
51
52 #if defined(OS_ANDROID)
53 #include "base/android/build_info.h"
54 #endif
55
// Starts a VLOG statement at verbosity |level| that begins with the prefix |s|
// followed by the enclosing function's name and "() ". Callers stream the rest
// of the message onto the expression this macro expands to.
#define NET_TRACE(level, s) VLOG(level) << s << __FUNCTION__ << "() "
57
58 namespace net {
59 namespace {
60
// Returns the uppercase hexadecimal digit for the upper four bits of |x|.
inline char AsciifyHigh(char x) {
  static const char kHexDigits[] = "0123456789ABCDEF";
  return kHexDigits[(x >> 4) & 0x0F];
}
65
// Returns the uppercase hexadecimal digit for the lower four bits of |x|.
inline char AsciifyLow(char x) {
  static const char kHexDigits[] = "0123456789ABCDEF";
  return kHexDigits[x & 0x0F];
}
70
// Returns |x| unchanged when it is a non-negative printable character, and '.'
// otherwise.
inline char Asciify(char x) {
  const bool printable = (x >= 0) && isprint(x);
  return printable ? x : '.';
}
76
DumpData(const char * data,int data_len)77 void DumpData(const char* data, int data_len) {
78 if (logging::LOG_INFO < logging::GetMinLogLevel())
79 return;
80 DVLOG(1) << "Length: " << data_len;
81 const char* pfx = "Data: ";
82 if (!data || (data_len <= 0)) {
83 DVLOG(1) << pfx << "<None>";
84 } else {
85 int i;
86 for (i = 0; i <= (data_len - 4); i += 4) {
87 DVLOG(1) << pfx
88 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
89 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
90 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
91 << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3])
92 << " '"
93 << Asciify(data[i + 0])
94 << Asciify(data[i + 1])
95 << Asciify(data[i + 2])
96 << Asciify(data[i + 3])
97 << "'";
98 pfx = " ";
99 }
100 // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
101 switch (data_len - i) {
102 case 3:
103 DVLOG(1) << pfx
104 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
105 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
106 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
107 << " '"
108 << Asciify(data[i + 0])
109 << Asciify(data[i + 1])
110 << Asciify(data[i + 2])
111 << " '";
112 break;
113 case 2:
114 DVLOG(1) << pfx
115 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
116 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
117 << " '"
118 << Asciify(data[i + 0])
119 << Asciify(data[i + 1])
120 << " '";
121 break;
122 case 1:
123 DVLOG(1) << pfx
124 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
125 << " '"
126 << Asciify(data[i + 0])
127 << " '";
128 break;
129 }
130 }
131 }
132
133 template <MockReadWriteType type>
DumpMockReadWrite(const MockReadWrite<type> & r)134 void DumpMockReadWrite(const MockReadWrite<type>& r) {
135 if (logging::LOG_INFO < logging::GetMinLogLevel())
136 return;
137 DVLOG(1) << "Async: " << (r.mode == ASYNC)
138 << "\nResult: " << r.result;
139 DumpData(r.data, r.data_len);
140 const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
141 DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop;
142 }
143
RunClosureIfNonNull(base::OnceClosure closure)144 void RunClosureIfNonNull(base::OnceClosure closure) {
145 if (!closure.is_null()) {
146 std::move(closure).Run();
147 }
148 }
149
150 } // namespace
151
MockConnect()152 MockConnect::MockConnect() : mode(ASYNC), result(OK) {
153 peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
154 }
155
MockConnect(IoMode io_mode,int r)156 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
157 peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
158 }
159
MockConnect(IoMode io_mode,int r,IPEndPoint addr)160 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) :
161 mode(io_mode),
162 result(r),
163 peer_addr(addr) {
164 }
165
166 MockConnect::~MockConnect() = default;
167
MockConfirm()168 MockConfirm::MockConfirm() : mode(SYNCHRONOUS), result(OK) {}
169
MockConfirm(IoMode io_mode,int r)170 MockConfirm::MockConfirm(IoMode io_mode, int r) : mode(io_mode), result(r) {}
171
172 MockConfirm::~MockConfirm() = default;
173
// This implementation always reports the provider as idle.
bool SocketDataProvider::IsIdle() const {
  return true;
}
177
// Attaches |socket| to this provider and calls Reset(). CHECK-fails if a
// socket is already attached or if |socket| is null.
void SocketDataProvider::Initialize(AsyncSocket* socket) {
  CHECK(!socket_);
  CHECK(socket);
  socket_ = socket;
  Reset();
}
184
// Clears the attached socket pointer. CHECK-fails if no socket is attached.
void SocketDataProvider::DetachSocket() {
  CHECK(socket_);
  socket_ = nullptr;
}
189
SocketDataProvider()190 SocketDataProvider::SocketDataProvider() {}
191
~SocketDataProvider()192 SocketDataProvider::~SocketDataProvider() {
193 if (socket_)
194 socket_->OnDataProviderDestroyed();
195 }
196
// Stores the |reads| and |writes| spans and starts both cursors at index 0.
StaticSocketDataHelper::StaticSocketDataHelper(
    base::span<const MockRead> reads,
    base::span<const MockWrite> writes)
    : reads_(reads), read_index_(0), writes_(writes), write_index_(0) {}

StaticSocketDataHelper::~StaticSocketDataHelper() = default;
203
// Returns the read at the current read index without advancing. CHECK-fails
// if all read data has been consumed.
const MockRead& StaticSocketDataHelper::PeekRead() const {
  CHECK(!AllReadDataConsumed());
  return reads_[read_index_];
}
208
// Returns the write at the current write index without advancing. CHECK-fails
// if all write data has been consumed.
const MockWrite& StaticSocketDataHelper::PeekWrite() const {
  CHECK(!AllWriteDataConsumed());
  return writes_[write_index_];
}
213
AdvanceRead()214 const MockRead& StaticSocketDataHelper::AdvanceRead() {
215 CHECK(!AllReadDataConsumed());
216 return reads_[read_index_++];
217 }
218
AdvanceWrite()219 const MockWrite& StaticSocketDataHelper::AdvanceWrite() {
220 CHECK(!AllWriteDataConsumed());
221 return writes_[write_index_++];
222 }
223
Reset()224 void StaticSocketDataHelper::Reset() {
225 read_index_ = 0;
226 write_index_ = 0;
227 }
228
// Compares |data| against the next non-pause expected write and returns
// whether the leading bytes of |data| match it. Mismatches are also reported
// as gtest expectation failures; when |printer| is non-null, a second failure
// formatted with |printer| is reported as well. Returns true without comparing
// anything when the expected write carries no data. CHECK-fails if all write
// data has been consumed.
bool StaticSocketDataHelper::VerifyWriteData(const std::string& data,
                                             SocketDataPrinter* printer) {
  CHECK(!AllWriteDataConsumed());
  // Check that the actual data matches the expectations, skipping over any
  // pause events.
  const MockWrite& next_write = PeekRealWrite();
  if (!next_write.data)
    return true;

  // Note: Partial writes are supported here. If the expected data
  // is a match, but shorter than the write actually written, that is legal.
  // Example:
  // Application writes "foobarbaz" (9 bytes)
  // Expected write was "foo" (3 bytes)
  // This is a success, and the function returns true.
  std::string expected_data(next_write.data, next_write.data_len);
  // Only the first |data_len| bytes of the actual write are compared.
  std::string actual_data(data.substr(0, next_write.data_len));
  EXPECT_GE(data.length(), expected_data.length());
  EXPECT_TRUE(actual_data == expected_data)
      << "Actual write data:\n" << HexDump(data)
      << "Expected write data:\n" << HexDump(expected_data);
  if (printer) {
    // Repeat the expectation using the caller-supplied formatter.
    EXPECT_TRUE(actual_data == expected_data)
        << "Actual write data:\n"
        << printer->PrintWrite(data) << "Expected write data:\n"
        << printer->PrintWrite(expected_data);
  }
  return expected_data == actual_data;
}
258
PeekRealWrite() const259 const MockWrite& StaticSocketDataHelper::PeekRealWrite() const {
260 for (size_t i = write_index_; i < write_count(); i++) {
261 if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING)
262 return writes_[i];
263 }
264
265 CHECK(false) << "No write data available.";
266 return writes_[0]; // Avoid warning about unreachable missing return.
267 }
268
// Creates a provider with empty read and write spans.
StaticSocketDataProvider::StaticSocketDataProvider()
    : StaticSocketDataProvider(base::span<const MockRead>(),
                               base::span<const MockWrite>()) {}

// Creates a provider whose helper serves |reads| and |writes|.
StaticSocketDataProvider::StaticSocketDataProvider(
    base::span<const MockRead> reads,
    base::span<const MockWrite> writes)
    : helper_(reads, writes) {}

StaticSocketDataProvider::~StaticSocketDataProvider() = default;
279
// Sets the paused flag. While it is set, AllReadDataConsumed() returns true,
// so OnRead() answers with ERR_IO_PENDING.
void StaticSocketDataProvider::Pause() {
  paused_ = true;
}
283
// Clears the paused flag set by Pause().
void StaticSocketDataProvider::Resume() {
  paused_ = false;
}
287
OnRead()288 MockRead StaticSocketDataProvider::OnRead() {
289 if (AllReadDataConsumed()) {
290 const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING);
291 return pending_read;
292 }
293
294 return helper_.AdvanceRead();
295 }
296
OnWrite(const std::string & data)297 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
298 if (helper_.write_count() == 0) {
299 // Not using mock writes; succeed synchronously.
300 return MockWriteResult(SYNCHRONOUS, data.length());
301 }
302 EXPECT_FALSE(helper_.AllWriteDataConsumed())
303 << "No more mock data to match write:\n"
304 << HexDump(data);
305 if (helper_.AllWriteDataConsumed()) {
306 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
307 }
308
309 // Check that what we are writing matches the expectation.
310 // Then give the mocked return value.
311 if (!helper_.VerifyWriteData(data, printer_))
312 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
313
314 const MockWrite& next_write = helper_.AdvanceWrite();
315 // In the case that the write was successful, return the number of bytes
316 // written. Otherwise return the error code.
317 int result =
318 next_write.result == OK ? next_write.data_len : next_write.result;
319 return MockWriteResult(next_write.mode, result);
320 }
321
AllReadDataConsumed() const322 bool StaticSocketDataProvider::AllReadDataConsumed() const {
323 return paused_ || helper_.AllReadDataConsumed();
324 }
325
// Forwards to the helper's AllWriteDataConsumed(). Unlike the read variant,
// the paused flag is not consulted.
bool StaticSocketDataProvider::AllWriteDataConsumed() const {
  return helper_.AllWriteDataConsumed();
}
329
// Resets the helper. The paused flag is left untouched.
void StaticSocketDataProvider::Reset() {
  helper_.Reset();
}
333
// Initializes the |connect| member from |mode| and |result|.
ProxyClientSocketDataProvider::ProxyClientSocketDataProvider(IoMode mode,
                                                             int result)
    : connect(mode, result) {}

ProxyClientSocketDataProvider::ProxyClientSocketDataProvider(
    const ProxyClientSocketDataProvider& other) = default;

ProxyClientSocketDataProvider::~ProxyClientSocketDataProvider() = default;
342
// Initializes the |connect| member from |mode| and |result|, leaves the
// negotiated protocol unknown and the certificate request info null, expects
// the default SSL version range, and marks |ssl_info| as a TLS 1.3 connection
// using cipher suite 0x1301.
SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
    : connect(mode, result),
      next_proto(kProtoUnknown),
      cert_request_info(nullptr),
      expected_ssl_version_min(kDefaultSSLVersionMin),
      expected_ssl_version_max(kDefaultSSLVersionMax) {
  SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
                                &ssl_info.connection_status);
  // Set to TLS_AES_128_GCM_SHA256 (0x1301). TLS_CHACHA20_POLY1305_SHA256, which
  // an earlier version of this comment named, is 0x1303.
  SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
}

SSLSocketDataProvider::SSLSocketDataProvider(
    const SSLSocketDataProvider& other) = default;

SSLSocketDataProvider::~SSLSocketDataProvider() = default;
359
// Creates sequenced socket data with empty read and write spans.
SequencedSocketData::SequencedSocketData()
    : SequencedSocketData(base::span<const MockRead>(),
                          base::span<const MockWrite>()) {}
363
// Creates sequenced socket data from |reads| and |writes| and validates the
// combined event sequence with CHECKs: sequence numbers must run contiguously
// from 0 across both spans; a pause (ASYNC with ERR_IO_PENDING) may not follow
// another pause, must be followed by an ASYNC event, and may not be the last
// event.
SequencedSocketData::SequencedSocketData(base::span<const MockRead> reads,
                                         base::span<const MockWrite> writes)
    : helper_(reads, writes),
      sequence_number_(0),
      read_state_(IDLE),
      write_state_(IDLE),
      busy_before_sync_reads_(false) {
  // Check that reads and writes have a contiguous set of sequence numbers
  // starting from 0 and working their way up, with no repeats and skipping
  // no values.
  int next_sequence_number = 0;
  bool last_event_was_pause = false;

  // Walk both spans in lockstep, consuming whichever one holds the next
  // expected sequence number.
  auto next_read = reads.begin();
  auto next_write = writes.begin();
  while (next_read != reads.end() || next_write != writes.end()) {
    if (next_read != reads.end() &&
        next_read->sequence_number == next_sequence_number) {
      // Check if this is a pause.
      if (next_read->mode == ASYNC && next_read->result == ERR_IO_PENDING) {
        CHECK(!last_event_was_pause) << "Two pauses in a row are not allowed: "
                                     << next_sequence_number;
        last_event_was_pause = true;
      } else if (last_event_was_pause) {
        CHECK_EQ(ASYNC, next_read->mode)
            << "A sync event after a pause makes no sense: "
            << next_sequence_number;
        CHECK_NE(ERR_IO_PENDING, next_read->result)
            << "A pause event after a pause makes no sense: "
            << next_sequence_number;
        last_event_was_pause = false;
      }

      ++next_read;
      ++next_sequence_number;
      continue;
    }
    if (next_write != writes.end() &&
        next_write->sequence_number == next_sequence_number) {
      // Check if this is a pause.
      if (next_write->mode == ASYNC && next_write->result == ERR_IO_PENDING) {
        CHECK(!last_event_was_pause) << "Two pauses in a row are not allowed: "
                                     << next_sequence_number;
        last_event_was_pause = true;
      } else if (last_event_was_pause) {
        CHECK_EQ(ASYNC, next_write->mode)
            << "A sync event after a pause makes no sense: "
            << next_sequence_number;
        CHECK_NE(ERR_IO_PENDING, next_write->result)
            << "A pause event after a pause makes no sense: "
            << next_sequence_number;
        last_event_was_pause = false;
      }

      ++next_write;
      ++next_sequence_number;
      continue;
    }
    // Neither span supplies the expected sequence number: the sequence has a
    // gap or a repeat.
    if (next_write != writes.end()) {
      CHECK(false) << "Sequence number " << next_write->sequence_number
                   << " not found where expected: " << next_sequence_number;
    } else {
      CHECK(false) << "Too few writes, next expected sequence number: "
                   << next_sequence_number;
    }
    return;
  }

  // Last event must not be a pause. For the final event to indicate the
  // operation never completes, it should be SYNCHRONOUS and return
  // ERR_IO_PENDING.
  CHECK(!last_event_was_pause);

  CHECK(next_read == reads.end());
  CHECK(next_write == writes.end());
}
440
// Same as the (|reads|, |writes|) constructor, and additionally installs
// |connect| through set_connect_data().
SequencedSocketData::SequencedSocketData(const MockConnect& connect,
                                         base::span<const MockRead> reads,
                                         base::span<const MockWrite> writes)
    : SequencedSocketData(reads, writes) {
  set_connect_data(connect);
}
447
// Handles a read request from the socket. CHECK-fails if a read is already in
// progress, if no read data is left, or if the next read's sequence number is
// behind the current one.
//
// When the next read is due at the current sequence number:
//  - a SYNCHRONOUS read is consumed and returned directly;
//  - an ASYNC read with ERR_IO_PENDING pauses the sequence;
//  - any other ASYNC read is completed later through a posted task.
// When it is not yet due, a SYNCHRONOUS read is a test failure
// (ERR_UNEXPECTED) and an ASYNC read is left pending. All non-synchronous
// paths return ERR_IO_PENDING.
MockRead SequencedSocketData::OnRead() {
  CHECK_EQ(IDLE, read_state_);
  CHECK(!helper_.AllReadDataConsumed())
      << "Application tried to read but there is no read data left";

  NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
  const MockRead& next_read = helper_.PeekRead();
  NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number;
  CHECK_GE(next_read.sequence_number, sequence_number_);

  if (next_read.sequence_number <= sequence_number_) {
    if (next_read.mode == SYNCHRONOUS) {
      NET_TRACE(1, " *** ") << "Returning synchronously";
      DumpMockReadWrite(next_read);
      helper_.AdvanceRead();
      ++sequence_number_;
      // Advancing the sequence may make a pending write due.
      MaybePostWriteCompleteTask();
      return next_read;
    }

    // If the result is ERR_IO_PENDING, then pause.
    if (next_read.result == ERR_IO_PENDING) {
      NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
      read_state_ = PAUSED;
      // Wake up RunUntilPaused(), if it is waiting.
      if (run_until_paused_run_loop_)
        run_until_paused_run_loop_->Quit();
      return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
    }
    // The weak pointer keeps the task from running after Reset() or
    // destruction.
    base::ThreadTaskRunnerHandle::Get()->PostTask(
        FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
                                  weak_factory_.GetWeakPtr()));
    CHECK_NE(COMPLETING, write_state_);
    read_state_ = COMPLETING;
  } else if (next_read.mode == SYNCHRONOUS) {
    ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
    return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
  } else {
    NET_TRACE(1, " *** ") << "Waiting for write to trigger read";
    read_state_ = PENDING;
  }

  return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
}
491
// Handles a write of |data| from the socket. CHECK-fails if a write is already
// in progress, if no write data is left, or if the next write's sequence
// number is behind the current one. Returns a synchronous ERR_UNEXPECTED when
// |data| does not match the expected write.
//
// When the next write is due at the current sequence number:
//  - a SYNCHRONOUS write is consumed and its result returned directly;
//  - an ASYNC write with ERR_IO_PENDING pauses the sequence;
//  - any other ASYNC write is completed later through a posted task.
// When it is not yet due, a SYNCHRONOUS write is a test failure
// (ERR_UNEXPECTED) and an ASYNC write is left pending. All non-synchronous
// paths return ERR_IO_PENDING.
MockWriteResult SequencedSocketData::OnWrite(const std::string& data) {
  CHECK_EQ(IDLE, write_state_);
  CHECK(!helper_.AllWriteDataConsumed())
      << "\nNo more mock data to match write:\n"
      << HexDump(data);

  NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
  const MockWrite& next_write = helper_.PeekWrite();
  NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number;
  CHECK_GE(next_write.sequence_number, sequence_number_);

  if (!helper_.VerifyWriteData(data, printer_))
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);

  if (next_write.sequence_number <= sequence_number_) {
    if (next_write.mode == SYNCHRONOUS) {
      helper_.AdvanceWrite();
      ++sequence_number_;
      // Advancing the sequence may make a pending read due.
      MaybePostReadCompleteTask();
      // In the case that the write was successful, return the number of bytes
      // written. Otherwise return the error code.
      int rv =
          next_write.result != OK ? next_write.result : next_write.data_len;
      NET_TRACE(1, " *** ") << "Returning synchronously";
      return MockWriteResult(SYNCHRONOUS, rv);
    }

    // If the result is ERR_IO_PENDING, then pause.
    if (next_write.result == ERR_IO_PENDING) {
      NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
      write_state_ = PAUSED;
      // Wake up RunUntilPaused(), if it is waiting.
      if (run_until_paused_run_loop_)
        run_until_paused_run_loop_->Quit();
      return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
    }

    NET_TRACE(1, " *** ") << "Posting task to complete write";
    // The weak pointer keeps the task from running after Reset() or
    // destruction.
    base::ThreadTaskRunnerHandle::Get()->PostTask(
        FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
                                  weak_factory_.GetWeakPtr()));
    CHECK_NE(COMPLETING, read_state_);
    write_state_ = COMPLETING;
  } else if (next_write.mode == SYNCHRONOUS) {
    ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
  } else {
    NET_TRACE(1, " *** ") << "Waiting for read to trigger write";
    write_state_ = PENDING;
  }

  return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
}
544
// Forwards to the helper's AllReadDataConsumed().
bool SequencedSocketData::AllReadDataConsumed() const {
  return helper_.AllReadDataConsumed();
}
548
// Puts the read state back to IDLE. DCHECKs that a read is currently PENDING.
void SequencedSocketData::CancelPendingRead() {
  DCHECK_EQ(PENDING, read_state_);

  read_state_ = IDLE;
}
554
// Forwards to the helper's AllWriteDataConsumed().
bool SequencedSocketData::AllWriteDataConsumed() const {
  return helper_.AllWriteDataConsumed();
}
558
IsIdle() const559 bool SequencedSocketData::IsIdle() const {
560 // If |busy_before_sync_reads_| is not set, always considered idle. If
561 // no reads left, or the next operation is a write, also consider it idle.
562 if (!busy_before_sync_reads_ || helper_.AllReadDataConsumed() ||
563 helper_.PeekRead().sequence_number != sequence_number_) {
564 return true;
565 }
566
567 // If the next operation is synchronous read, treat the socket as not idle.
568 if (helper_.PeekRead().mode == SYNCHRONOUS)
569 return false;
570 return true;
571 }
572
// Returns true if either the read side or the write side is PAUSED.
bool SequencedSocketData::IsPaused() const {
  // Both states should not be paused.
  DCHECK(read_state_ != PAUSED || write_state_ != PAUSED);
  return write_state_ == PAUSED || read_state_ == PAUSED;
}
578
// Resumes a paused sequence: steps past the pause event, marks the paused
// side PENDING, and, if the event now due belongs to a side that is PENDING,
// completes it immediately. Reports a test failure and does nothing when the
// sequence is not paused.
void SequencedSocketData::Resume() {
  if (!IsPaused()) {
    ADD_FAILURE() << "Unable to Resume when not paused.";
    return;
  }

  // Consume the pause event itself.
  sequence_number_++;
  if (read_state_ == PAUSED) {
    read_state_ = PENDING;
    helper_.AdvanceRead();
  } else {  // write_state_ == PAUSED
    write_state_ = PENDING;
    helper_.AdvanceWrite();
  }

  // If the event after the pause is a write, handle the write side.
  if (!helper_.AllWriteDataConsumed() &&
      helper_.PeekWrite().sequence_number == sequence_number_) {
    // The next event hasn't even started yet. Pausing isn't really needed in
    // that case, but may as well support it.
    if (write_state_ != PENDING)
      return;
    write_state_ = COMPLETING;
    OnWriteComplete();
    return;
  }

  // Otherwise the event after the pause has to be a read.
  CHECK(!helper_.AllReadDataConsumed());

  // The next event hasn't even started yet. Pausing isn't really needed in
  // that case, but may as well support it.
  if (read_state_ != PENDING)
    return;
  read_state_ = COMPLETING;
  OnReadComplete();
}
614
RunUntilPaused()615 void SequencedSocketData::RunUntilPaused() {
616 CHECK(!run_until_paused_run_loop_);
617
618 if (IsPaused())
619 return;
620
621 run_until_paused_run_loop_.reset(new base::RunLoop());
622 run_until_paused_run_loop_->Run();
623 run_until_paused_run_loop_.reset();
624 DCHECK(IsPaused());
625 }
626
// If a read is PENDING and is due at the current sequence number, either
// pauses (when its result is ERR_IO_PENDING) or posts a task that completes
// it. Does nothing otherwise.
void SequencedSocketData::MaybePostReadCompleteTask() {
  NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
  // Only trigger the next read to complete if there is already a read pending
  // which should complete at the current sequence number.
  if (read_state_ != PENDING ||
      helper_.PeekRead().sequence_number != sequence_number_) {
    return;
  }

  // If the result is ERR_IO_PENDING, then pause.
  if (helper_.PeekRead().result == ERR_IO_PENDING) {
    NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
    read_state_ = PAUSED;
    // Wake up RunUntilPaused(), if it is waiting.
    if (run_until_paused_run_loop_)
      run_until_paused_run_loop_->Quit();
    return;
  }

  NET_TRACE(1, " ****** ") << "Posting task to complete read: "
                           << sequence_number_;
  base::ThreadTaskRunnerHandle::Get()->PostTask(
      FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
                                weak_factory_.GetWeakPtr()));
  CHECK_NE(COMPLETING, write_state_);
  read_state_ = COMPLETING;
}
653
// If a write is PENDING and is due at the current sequence number, either
// pauses (when its result is ERR_IO_PENDING) or posts a task that completes
// it. Does nothing otherwise.
void SequencedSocketData::MaybePostWriteCompleteTask() {
  NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
  // Only trigger the next write to complete if there is already a write pending
  // which should complete at the current sequence number.
  if (write_state_ != PENDING ||
      helper_.PeekWrite().sequence_number != sequence_number_) {
    return;
  }

  // If the result is ERR_IO_PENDING, then pause.
  if (helper_.PeekWrite().result == ERR_IO_PENDING) {
    NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
    write_state_ = PAUSED;
    // Wake up RunUntilPaused(), if it is waiting.
    if (run_until_paused_run_loop_)
      run_until_paused_run_loop_->Quit();
    return;
  }

  NET_TRACE(1, " ****** ") << "Posting task to complete write: "
                           << sequence_number_;
  base::ThreadTaskRunnerHandle::Get()->PostTask(
      FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
                                weak_factory_.GetWeakPtr()));
  CHECK_NE(COMPLETING, read_state_);
  write_state_ = COMPLETING;
}
680
// Returns to the initial state: rewinds the helper, restarts the sequence at
// 0, sets both sides to IDLE and invalidates outstanding weak pointers, which
// cancels completion tasks that were posted but have not yet run.
void SequencedSocketData::Reset() {
  helper_.Reset();
  sequence_number_ = 0;
  read_state_ = IDLE;
  write_state_ = IDLE;
  weak_factory_.InvalidateWeakPtrs();
}
688
// Completes the read that is currently COMPLETING: consumes it, advances the
// sequence number, returns the read side to IDLE and, if a socket is attached,
// delivers the read to it. CHECK-fails if no read is COMPLETING.
void SequencedSocketData::OnReadComplete() {
  CHECK_EQ(COMPLETING, read_state_);
  NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_;

  MockRead data = helper_.AdvanceRead();
  DCHECK_EQ(sequence_number_, data.sequence_number);
  sequence_number_++;
  read_state_ = IDLE;

  // The result of this read completing might trigger the completion
  // of a pending write. If so, post a task to complete the write later.
  // Since the socket may call back into the SequencedSocketData
  // from socket()->OnReadComplete(), trigger the write task to be posted
  // before calling that.
  MaybePostWriteCompleteTask();

  if (!socket()) {
    NET_TRACE(1, " *** ") << "No socket available to complete read";
    return;
  }

  NET_TRACE(1, " *** ") << "Completing socket read for: "
                        << data.sequence_number;
  DumpMockReadWrite(data);
  socket()->OnReadComplete(data);
  NET_TRACE(1, " *** ") << "Done";
}
716
// Completes the write that is currently COMPLETING: consumes it, advances the
// sequence number, returns the write side to IDLE and, if a socket is
// attached, reports the result to it. CHECK-fails if no write is COMPLETING.
void SequencedSocketData::OnWriteComplete() {
  CHECK_EQ(COMPLETING, write_state_);
  NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_;

  const MockWrite& data = helper_.AdvanceWrite();
  DCHECK_EQ(sequence_number_, data.sequence_number);
  sequence_number_++;
  write_state_ = IDLE;
  // A successful write reports the byte count; otherwise the error code.
  int rv = data.result == OK ? data.data_len : data.result;

  // The result of this write completing might trigger the completion
  // of a pending read. If so, post a task to complete the read later.
  // Since the socket may call back into the SequencedSocketData
  // from socket()->OnWriteComplete(), trigger the read task to be posted
  // before calling that.
  MaybePostReadCompleteTask();

  if (!socket()) {
    NET_TRACE(1, " *** ") << "No socket available to complete write";
    return;
  }

  NET_TRACE(1, " *** ") << " Completing socket write for: "
                        << data.sequence_number;
  socket()->OnWriteComplete(rv);
  NET_TRACE(1, " *** ") << "Done";
}
744
// Defaulted destructor, defined out of line.
SequencedSocketData::~SequencedSocketData() = default;
746
// Creates a factory with read-if-ready support disabled for the sockets it
// hands out.
MockClientSocketFactory::MockClientSocketFactory()
    : enable_read_if_ready_(false) {}

MockClientSocketFactory::~MockClientSocketFactory() = default;
751
// Adds |data| to the socket data providers that the datagram and transport
// socket creation methods draw from.
void MockClientSocketFactory::AddSocketDataProvider(
    SocketDataProvider* data) {
  mock_data_.Add(data);
}
756
// Adds |data| to the SSL data providers that CreateSSLClientSocket() draws
// from.
void MockClientSocketFactory::AddSSLSocketDataProvider(
    SSLSocketDataProvider* data) {
  mock_ssl_data_.Add(data);
}
761
// Adds |data| to the proxy data providers that CreateProxyClientSocket() draws
// from when mock proxy client sockets are in use.
void MockClientSocketFactory::AddProxyClientSocketDataProvider(
    ProxyClientSocketDataProvider* data) {
  mock_proxy_data_.Add(data);
}
766
// Resets the next-provider index of the socket and SSL data provider lists.
// NOTE(review): |mock_proxy_data_| is not reset here - confirm whether that
// omission is intentional.
void MockClientSocketFactory::ResetNextMockIndexes() {
  mock_data_.ResetNextIndex();
  mock_ssl_data_.ResetNextIndex();
}
771
772 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)773 MockClientSocketFactory::CreateDatagramClientSocket(
774 DatagramSocket::BindType bind_type,
775 NetLog* net_log,
776 const NetLogSource& source) {
777 SocketDataProvider* data_provider = mock_data_.GetNext();
778 std::unique_ptr<MockUDPClientSocket> socket(
779 new MockUDPClientSocket(data_provider, net_log));
780 if (bind_type == DatagramSocket::RANDOM_BIND)
781 socket->set_source_port(static_cast<uint16_t>(base::RandInt(1025, 65535)));
782 udp_client_socket_ports_.push_back(socket->source_port());
783 return std::move(socket);
784 }
785
786 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetLog * net_log,const NetLogSource & source)787 MockClientSocketFactory::CreateTransportClientSocket(
788 const AddressList& addresses,
789 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
790 NetLog* net_log,
791 const NetLogSource& source) {
792 SocketDataProvider* data_provider = mock_data_.GetNext();
793 std::unique_ptr<MockTCPClientSocket> socket(
794 new MockTCPClientSocket(addresses, net_log, data_provider));
795 if (enable_read_if_ready_)
796 socket->set_enable_read_if_ready(enable_read_if_ready_);
797 return std::move(socket);
798 }
799
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)800 std::unique_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
801 SSLClientContext* context,
802 std::unique_ptr<StreamSocket> stream_socket,
803 const HostPortPair& host_and_port,
804 const SSLConfig& ssl_config) {
805 SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext();
806 if (next_ssl_data->next_protos_expected_in_ssl_config.has_value()) {
807 EXPECT_EQ(next_ssl_data->next_protos_expected_in_ssl_config.value().size(),
808 ssl_config.alpn_protos.size());
809 EXPECT_TRUE(std::equal(
810 next_ssl_data->next_protos_expected_in_ssl_config.value().begin(),
811 next_ssl_data->next_protos_expected_in_ssl_config.value().end(),
812 ssl_config.alpn_protos.begin()));
813 }
814
815 // The protocol version used is a combination of the per-socket SSLConfig and
816 // the SSLConfigService.
817 EXPECT_EQ(
818 next_ssl_data->expected_ssl_version_min,
819 ssl_config.version_min_override.value_or(context->config().version_min));
820 EXPECT_EQ(
821 next_ssl_data->expected_ssl_version_max,
822 ssl_config.version_max_override.value_or(context->config().version_max));
823
824 if (next_ssl_data->expected_send_client_cert) {
825 // Client certificate preferences come from |context|.
826 scoped_refptr<X509Certificate> client_cert;
827 scoped_refptr<SSLPrivateKey> client_private_key;
828 bool send_client_cert = context->GetClientCertificate(
829 host_and_port, &client_cert, &client_private_key);
830
831 EXPECT_EQ(*next_ssl_data->expected_send_client_cert, send_client_cert);
832 // Note |send_client_cert| may be true while |client_cert| is null if the
833 // socket is configured to continue without a certificate, as opposed to
834 // surfacing the certificate challenge.
835 EXPECT_EQ(!!next_ssl_data->expected_client_cert, !!client_cert);
836 if (next_ssl_data->expected_client_cert && client_cert) {
837 EXPECT_TRUE(next_ssl_data->expected_client_cert->EqualsIncludingChain(
838 client_cert.get()));
839 }
840 }
841 if (next_ssl_data->expected_host_and_port) {
842 EXPECT_EQ(*next_ssl_data->expected_host_and_port, host_and_port);
843 }
844 if (next_ssl_data->expected_network_isolation_key) {
845 EXPECT_EQ(*next_ssl_data->expected_network_isolation_key,
846 ssl_config.network_isolation_key);
847 }
848 if (next_ssl_data->expected_disable_legacy_crypto) {
849 EXPECT_EQ(*next_ssl_data->expected_disable_legacy_crypto,
850 ssl_config.disable_legacy_crypto);
851 }
852 return std::unique_ptr<SSLClientSocket>(new MockSSLClientSocket(
853 std::move(stream_socket), host_and_port, ssl_config, next_ssl_data));
854 }
855
856 std::unique_ptr<ProxyClientSocket>
CreateProxyClientSocket(std::unique_ptr<StreamSocket> stream_socket,const std::string & user_agent,const HostPortPair & endpoint,const ProxyServer & proxy_server,HttpAuthController * http_auth_controller,bool tunnel,bool using_spdy,NextProto negotiated_protocol,ProxyDelegate * proxy_delegate,const NetworkTrafficAnnotationTag & traffic_annotation)857 MockClientSocketFactory::CreateProxyClientSocket(
858 std::unique_ptr<StreamSocket> stream_socket,
859 const std::string& user_agent,
860 const HostPortPair& endpoint,
861 const ProxyServer& proxy_server,
862 HttpAuthController* http_auth_controller,
863 bool tunnel,
864 bool using_spdy,
865 NextProto negotiated_protocol,
866 ProxyDelegate* proxy_delegate,
867 const NetworkTrafficAnnotationTag& traffic_annotation) {
868 if (use_mock_proxy_client_sockets_) {
869 ProxyClientSocketDataProvider* next_proxy_data = mock_proxy_data_.GetNext();
870 return std::make_unique<MockProxyClientSocket>(
871 std::move(stream_socket), http_auth_controller, next_proxy_data);
872 } else {
873 return GetDefaultFactory()->CreateProxyClientSocket(
874 std::move(stream_socket), user_agent, endpoint, proxy_server,
875 http_auth_controller, tunnel, using_spdy, negotiated_protocol,
876 proxy_delegate, traffic_annotation);
877 }
878 }
879
MockClientSocket(const NetLogWithSource & net_log)880 MockClientSocket::MockClientSocket(const NetLogWithSource& net_log)
881 : connected_(false), net_log_(net_log) {
882 local_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
883 peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
884 }
885
// Test stub: |size| is ignored and OK is always returned.
int MockClientSocket::SetReceiveBufferSize(int32_t size) {
  return OK;
}
889
// Test stub: |size| is ignored and OK is always returned.
int MockClientSocket::SetSendBufferSize(int32_t size) {
  return OK;
}
893
Bind(const net::IPEndPoint & local_addr)894 int MockClientSocket::Bind(const net::IPEndPoint& local_addr) {
895 local_addr_ = local_addr;
896 return net::OK;
897 }
898
// Test stub: |no_delay| is ignored and success (true) is always reported.
bool MockClientSocket::SetNoDelay(bool no_delay) {
  return true;
}
902
// Test stub: |enable| and |delay| are ignored and success (true) is always
// reported.
bool MockClientSocket::SetKeepAlive(bool enable, int delay) {
  return true;
}
906
// Marks the socket as disconnected by clearing |connected_|.
void MockClientSocket::Disconnect() {
  connected_ = false;
}
910
// Reports the current value of |connected_|.
bool MockClientSocket::IsConnected() const {
  return connected_;
}
914
// The base mock tracks no separate idle state: this returns |connected_|,
// exactly like IsConnected().
bool MockClientSocket::IsConnectedAndIdle() const {
  return connected_;
}
918
GetPeerAddress(IPEndPoint * address) const919 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
920 if (!IsConnected())
921 return ERR_SOCKET_NOT_CONNECTED;
922 *address = peer_addr_;
923 return OK;
924 }
925
// Copies the stored local endpoint into |address| and returns OK. Unlike
// GetPeerAddress(), this does not check whether the socket is connected.
int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
  *address = local_addr_;
  return OK;
}
930
// Returns the NetLogWithSource supplied at construction.
const NetLogWithSource& MockClientSocket::NetLog() const {
  return net_log_;
}
934
// The base mock never reports ALPN negotiation.
bool MockClientSocket::WasAlpnNegotiated() const {
  return false;
}
938
// The base mock has no negotiated protocol; always returns kProtoUnknown.
NextProto MockClientSocket::GetNegotiatedProtocol() const {
  return kProtoUnknown;
}
942
// The base mock records no connection attempts; |out| is emptied.
void MockClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
  out->clear();
}
946
// Out-of-line defaulted destructor; no custom teardown is performed here.
MockClientSocket::~MockClientSocket() = default;
948
RunCallbackAsync(CompletionOnceCallback callback,int result)949 void MockClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
950 int result) {
951 base::ThreadTaskRunnerHandle::Get()->PostTask(
952 FROM_HERE,
953 base::BindOnce(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(),
954 std::move(callback), result));
955 }
956
// Consumes |callback| by running it synchronously with |result|.
void MockClientSocket::RunCallback(CompletionOnceCallback callback,
                                   int result) {
  std::move(callback).Run(result);
}
961
// Builds a mock TCP socket scripted by |data|, which must be non-null
// (DCHECKed). The socket starts with no read in flight, takes its peer address
// from |data|'s connect data, and registers itself with the provider through
// Initialize().
MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
                                         net::NetLog* net_log,
                                         SocketDataProvider* data)
    : MockClientSocket(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)),
      addresses_(addresses),
      data_(data),
      read_offset_(0),
      // Placeholder value; |need_read_data_| below forces the first read to
      // fetch a real MockRead from |data_|.
      read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
      need_read_data_(true),
      peer_closed_connection_(false),
      pending_read_buf_(nullptr),
      pending_read_buf_len_(0),
      was_used_to_convey_data_(false),
      enable_read_if_ready_(false) {
  DCHECK(data_);
  peer_addr_ = data->connect_data().peer_addr;
  data_->Initialize(this);
}
980
// Detaches from the data provider, if it is still alive (|data_| is cleared by
// OnDataProviderDestroyed()).
MockTCPClientSocket::~MockTCPClientSocket() {
  if (data_)
    data_->DetachSocket();
}
985
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)986 int MockTCPClientSocket::Read(IOBuffer* buf,
987 int buf_len,
988 CompletionOnceCallback callback) {
989 // If the buffer is already in use, a read is already in progress!
990 DCHECK(!pending_read_buf_);
991 // Use base::Unretained() is safe because MockClientSocket::RunCallbackAsync()
992 // takes a weak ptr of the base class, MockClientSocket.
993 int rv = ReadIfReadyImpl(
994 buf, buf_len,
995 base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this)));
996 if (rv == ERR_IO_PENDING) {
997 DCHECK(callback);
998
999 pending_read_buf_ = buf;
1000 pending_read_buf_len_ = buf_len;
1001 pending_read_callback_ = std::move(callback);
1002 }
1003 return rv;
1004 }
1005
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1006 int MockTCPClientSocket::ReadIfReady(IOBuffer* buf,
1007 int buf_len,
1008 CompletionOnceCallback callback) {
1009 DCHECK(!pending_read_if_ready_callback_);
1010
1011 if (!enable_read_if_ready_)
1012 return ERR_READ_IF_READY_NOT_IMPLEMENTED;
1013 return ReadIfReadyImpl(buf, buf_len, std::move(callback));
1014 }
1015
CancelReadIfReady()1016 int MockTCPClientSocket::CancelReadIfReady() {
1017 DCHECK(pending_read_if_ready_callback_);
1018
1019 pending_read_if_ready_callback_.Reset();
1020 data_->CancelPendingRead();
1021 return OK;
1022 }
1023
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1024 int MockTCPClientSocket::Write(
1025 IOBuffer* buf,
1026 int buf_len,
1027 CompletionOnceCallback callback,
1028 const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1029 DCHECK(buf);
1030 DCHECK_GT(buf_len, 0);
1031
1032 if (!connected_ || !data_)
1033 return ERR_UNEXPECTED;
1034
1035 std::string data(buf->data(), buf_len);
1036 MockWriteResult write_result = data_->OnWrite(data);
1037
1038 was_used_to_convey_data_ = true;
1039
1040 if (write_result.result == ERR_CONNECTION_CLOSED) {
1041 // This MockWrite is just a marker to instruct us to set
1042 // peer_closed_connection_.
1043 peer_closed_connection_ = true;
1044 }
1045 // ERR_IO_PENDING is a signal that the socket data will call back
1046 // asynchronously later.
1047 if (write_result.result == ERR_IO_PENDING) {
1048 pending_write_callback_ = std::move(callback);
1049 return ERR_IO_PENDING;
1050 }
1051
1052 if (write_result.mode == ASYNC) {
1053 RunCallbackAsync(std::move(callback), write_result.result);
1054 return ERR_IO_PENDING;
1055 }
1056
1057 return write_result.result;
1058 }
1059
SetReceiveBufferSize(int32_t size)1060 int MockTCPClientSocket::SetReceiveBufferSize(int32_t size) {
1061 if (!connected_)
1062 return net::ERR_UNEXPECTED;
1063 data_->set_receive_buffer_size(size);
1064 return data_->set_receive_buffer_size_result();
1065 }
1066
SetSendBufferSize(int32_t size)1067 int MockTCPClientSocket::SetSendBufferSize(int32_t size) {
1068 if (!connected_)
1069 return net::ERR_UNEXPECTED;
1070 data_->set_send_buffer_size(size);
1071 return data_->set_send_buffer_size_result();
1072 }
1073
SetNoDelay(bool no_delay)1074 bool MockTCPClientSocket::SetNoDelay(bool no_delay) {
1075 if (!connected_)
1076 return false;
1077 data_->set_no_delay(no_delay);
1078 return data_->set_no_delay_result();
1079 }
1080
SetKeepAlive(bool enable,int delay)1081 bool MockTCPClientSocket::SetKeepAlive(bool enable, int delay) {
1082 if (!connected_)
1083 return false;
1084 data_->set_keep_alive(enable, delay);
1085 return data_->set_keep_alive_result();
1086 }
1087
// Copies every recorded connection attempt into |out|, replacing its previous
// contents.
void MockTCPClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
  *out = connection_attempts_;
}
1091
// Forgets all recorded connection attempts.
void MockTCPClientSocket::ClearConnectionAttempts() {
  connection_attempts_.clear();
}
1095
// Prepends |attempts| to the recorded connection attempts: they are inserted
// at the front, in their original order, ahead of any existing entries.
void MockTCPClientSocket::AddConnectionAttempts(
    const ConnectionAttempts& attempts) {
  connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(),
                              attempts.end());
}
1101
// Installs the hook that Connect() runs before consulting the scripted connect
// result. It may only be set once, and only while the socket is not connected
// (both DCHECKed).
void MockTCPClientSocket::SetBeforeConnectCallback(
    const BeforeConnectCallback& before_connect_callback) {
  DCHECK(!before_connect_callback_);
  DCHECK(!connected_);

  before_connect_callback_ = before_connect_callback;
}
1109
// Simulates connecting, using the connect data scripted on |data_|.
//
// Returns ERR_UNEXPECTED if the data provider is gone, and OK if the socket is
// already connected. If a before-connect hook is installed and fails, its
// error is returned and the socket stays disconnected. Otherwise:
//  - a SYNCHRONOUS scripted connect returns its result directly;
//  - any other mode returns ERR_IO_PENDING, and |callback| is either posted
//    with the scripted result or, when that result is itself ERR_IO_PENDING,
//    stored in |pending_connect_callback_| for OnConnectComplete() to run.
int MockTCPClientSocket::Connect(CompletionOnceCallback callback) {
  if (!data_)
    return ERR_UNEXPECTED;

  if (connected_)
    return OK;

  // Setting socket options fails if not connected, so need to set this before
  // calling |before_connect_callback_|.
  connected_ = true;

  if (before_connect_callback_) {
    int result = before_connect_callback_.Run();
    DCHECK_NE(result, ERR_IO_PENDING);
    // A failing hook aborts the connect and undoes the provisional flag.
    if (result != net::OK) {
      connected_ = false;
      return result;
    }
  }

  peer_closed_connection_ = false;

  int result = data_->connect_data().result;
  IoMode mode = data_->connect_data().mode;

  // A definitive scripted failure is recorded as a connection attempt against
  // the peer address, provided one can be obtained.
  // NOTE(review): |connected_| is left set to true on this path even though
  // the scripted result is an error — confirm this is intended.
  if (result != OK && result != ERR_IO_PENDING) {
    IPEndPoint address;
    if (GetPeerAddress(&address) == OK)
      connection_attempts_.push_back(ConnectionAttempt(address, result));
  }

  if (mode == SYNCHRONOUS)
    return result;

  // Asynchronous completion requires a callback to notify.
  DCHECK(callback);

  if (result == ERR_IO_PENDING)
    pending_connect_callback_ = std::move(callback);
  else
    RunCallbackAsync(std::move(callback), result);
  return ERR_IO_PENDING;
}
1152
// Disconnects the socket and drops the pending connect and Read() callbacks
// without running them. The pending ReadIfReady() and write callbacks are not
// touched here.
void MockTCPClientSocket::Disconnect() {
  MockClientSocket::Disconnect();
  pending_connect_callback_.Reset();
  pending_read_callback_.Reset();
}
1158
IsConnected() const1159 bool MockTCPClientSocket::IsConnected() const {
1160 if (!data_)
1161 return false;
1162 return connected_ && !peer_closed_connection_;
1163 }
1164
IsConnectedAndIdle() const1165 bool MockTCPClientSocket::IsConnectedAndIdle() const {
1166 if (!data_)
1167 return false;
1168 return IsConnected() && data_->IsIdle();
1169 }
1170
GetPeerAddress(IPEndPoint * address) const1171 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1172 if (addresses_.empty())
1173 return MockClientSocket::GetPeerAddress(address);
1174
1175 *address = addresses_[0];
1176 return OK;
1177 }
1178
// Returns |was_used_to_convey_data_|, which this socket's read and write paths
// set once they have processed data.
bool MockTCPClientSocket::WasEverUsed() const {
  return was_used_to_convey_data_;
}
1182
// This mock TCP socket has no SSL information: |ssl_info| is left untouched
// and false is returned.
bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
  return false;
}
1186
OnReadComplete(const MockRead & data)1187 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
1188 // If |data_| has been destroyed, safest to just do nothing.
1189 if (!data_)
1190 return;
1191
1192 // There must be a read pending.
1193 DCHECK(pending_read_if_ready_callback_);
1194 // You can't complete a read with another ERR_IO_PENDING status code.
1195 DCHECK_NE(ERR_IO_PENDING, data.result);
1196 // Since we've been waiting for data, need_read_data_ should be true.
1197 DCHECK(need_read_data_);
1198
1199 read_data_ = data;
1200 need_read_data_ = false;
1201
1202 // The caller is simulating that this IO completes right now. Don't
1203 // let CompleteRead() schedule a callback.
1204 read_data_.mode = SYNCHRONOUS;
1205 RunCallback(std::move(pending_read_if_ready_callback_),
1206 read_data_.result > 0 ? OK : read_data_.result);
1207 }
1208
OnWriteComplete(int rv)1209 void MockTCPClientSocket::OnWriteComplete(int rv) {
1210 // If |data_| has been destroyed, safest to just do nothing.
1211 if (!data_)
1212 return;
1213
1214 // There must be a read pending.
1215 DCHECK(!pending_write_callback_.is_null());
1216 RunCallback(std::move(pending_write_callback_), rv);
1217 }
1218
// Called by the data provider to complete a connect it previously left
// pending; runs the stored connect callback with |data.result|.
void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
  // If |data_| has been destroyed, safest to just do nothing.
  if (!data_)
    return;

  RunCallback(std::move(pending_connect_callback_), data.result);
}
1226
// Drops the pointer to the data provider so it is never used after the
// provider's destruction; other methods test |data_| for null.
void MockTCPClientSocket::OnDataProviderDestroyed() {
  data_ = nullptr;
}
1230
RetryRead(int rv)1231 void MockTCPClientSocket::RetryRead(int rv) {
1232 DCHECK(pending_read_callback_);
1233 DCHECK(pending_read_buf_.get());
1234 DCHECK_LT(0, pending_read_buf_len_);
1235
1236 if (rv == OK) {
1237 rv = ReadIfReadyImpl(pending_read_buf_.get(), pending_read_buf_len_,
1238 base::BindOnce(&MockTCPClientSocket::RetryRead,
1239 base::Unretained(this)));
1240 if (rv == ERR_IO_PENDING)
1241 return;
1242 }
1243 pending_read_buf_ = nullptr;
1244 pending_read_buf_len_ = 0;
1245 RunCallback(std::move(pending_read_callback_), rv);
1246 }
1247
// Shared implementation behind Read() and ReadIfReady().
//
// Fetches the next MockRead from |data_| when one is needed, then:
//  - returns ERR_IO_PENDING and stores |callback| if the provider will
//    complete the read manually (MockRead result ERR_IO_PENDING);
//  - returns ERR_IO_PENDING, stores |callback| and posts its invocation if the
//    MockRead is ASYNC;
//  - otherwise, if the MockRead has a payload, copies up to |buf_len| bytes
//    into |buf| and returns the byte count, or 0 when no bytes remain;
//  - otherwise returns the MockRead's result unchanged.
// Returns ERR_UNEXPECTED if the socket is not connected or |data_| is gone.
int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
                                         int buf_len,
                                         CompletionOnceCallback callback) {
  if (!connected_ || !data_)
    return ERR_UNEXPECTED;

  DCHECK(!pending_read_if_ready_callback_);

  if (need_read_data_) {
    read_data_ = data_->OnRead();
    if (read_data_.result == ERR_CONNECTION_CLOSED) {
      // This MockRead is just a marker to instruct us to set
      // peer_closed_connection_.
      peer_closed_connection_ = true;
    }
    if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
      // This MockRead is just a marker to instruct us to set
      // peer_closed_connection_.  Skip it and get the next one.
      read_data_ = data_->OnRead();
      peer_closed_connection_ = true;
    }
    // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
    // to complete the async IO manually later (via OnReadComplete).
    if (read_data_.result == ERR_IO_PENDING) {
      // We need to be using async IO in this case.
      DCHECK(!callback.is_null());
      pending_read_if_ready_callback_ = std::move(callback);
      return ERR_IO_PENDING;
    }
    need_read_data_ = false;
  }

  int result = read_data_.result;
  DCHECK_NE(ERR_IO_PENDING, result);
  if (read_data_.mode == ASYNC) {
    DCHECK(!callback.is_null());
    // Flip the stored MockRead to SYNCHRONOUS so that the next call (with
    // |need_read_data_| still false) serves it through the synchronous path
    // below instead of posting again.
    read_data_.mode = SYNCHRONOUS;
    pending_read_if_ready_callback_ = std::move(callback);
    // base::Unretained() is safe here because RunCallbackAsync will wrap it
    // with a callback associated with a weak ptr.
    RunCallbackAsync(
        base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
                       base::Unretained(this)),
        result);
    return ERR_IO_PENDING;
  }

  was_used_to_convey_data_ = true;
  if (read_data_.data) {
    if (read_data_.data_len - read_offset_ > 0) {
      // Serve as much of the remaining payload as fits in the caller's buffer;
      // |read_offset_| remembers how far a partially consumed MockRead got.
      result = std::min(buf_len, read_data_.data_len - read_offset_);
      memcpy(buf->data(), read_data_.data + read_offset_, result);
      read_offset_ += result;
      // Payload fully consumed: the next call must fetch a new MockRead.
      if (read_offset_ == read_data_.data_len) {
        need_read_data_ = true;
        read_offset_ = 0;
      }
    } else {
      result = 0;  // EOF
    }
  }
  return result;
}
1311
RunReadIfReadyCallback(int result)1312 void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
1313 // If ReadIfReady is already canceled, do nothing.
1314 if (!pending_read_if_ready_callback_)
1315 return;
1316 std::move(pending_read_if_ready_callback_).Run(result);
1317 }
1318
// Wraps |socket| as a mock proxy socket scripted by |data|, which must be
// non-null (DCHECKed). |auth_controller| is stored and later returned by
// GetAuthController().
//
// NOTE(review): |net_log_| is read from |socket| in the initializer list
// before |socket| is moved into |socket_|. Members initialize in declaration
// order, so this presumably relies on |net_log_| being declared before
// |socket_| in the class — confirm against the header.
MockProxyClientSocket::MockProxyClientSocket(
    std::unique_ptr<StreamSocket> socket,
    HttpAuthController* auth_controller,
    ProxyClientSocketDataProvider* data)
    : net_log_(socket->NetLog()),
      socket_(std::move(socket)),
      data_(data),
      auth_controller_(auth_controller) {
  DCHECK(data_);
}
1329
// Disconnects the wrapped socket on destruction.
MockProxyClientSocket::~MockProxyClientSocket() {
  Disconnect();
}
1333
// The mock keeps no CONNECT response information; always returns nullptr.
const HttpResponseInfo* MockProxyClientSocket::GetConnectResponseInfo() const {
  return nullptr;
}
1337
// Returns the auth controller handed to the constructor.
const scoped_refptr<HttpAuthController>&
MockProxyClientSocket::GetAuthController() const {
  return auth_controller_;
}
1342
// Restarting with auth is not supported by the mock: |callback| is never run
// and ERR_NOT_IMPLEMENTED is returned.
int MockProxyClientSocket::RestartWithAuth(CompletionOnceCallback callback) {
  return net::ERR_NOT_IMPLEMENTED;
}
// The mock never reports SPDY usage.
bool MockProxyClientSocket::IsUsingSpdy() const {
  return false;
}
1349
// The mock reports no protocol negotiated with the proxy.
NextProto MockProxyClientSocket::GetProxyNegotiatedProtocol() const {
  return kProtoUnknown;
}
1353
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1354 int MockProxyClientSocket::Read(IOBuffer* buf,
1355 int buf_len,
1356 CompletionOnceCallback callback) {
1357 return socket_->Read(buf, buf_len, std::move(callback));
1358 }
1359
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1360 int MockProxyClientSocket::ReadIfReady(IOBuffer* buf,
1361 int buf_len,
1362 CompletionOnceCallback callback) {
1363 return socket_->ReadIfReady(buf, buf_len, std::move(callback));
1364 }
1365
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1366 int MockProxyClientSocket::Write(
1367 IOBuffer* buf,
1368 int buf_len,
1369 CompletionOnceCallback callback,
1370 const NetworkTrafficAnnotationTag& traffic_annotation) {
1371 return socket_->Write(buf, buf_len, std::move(callback), traffic_annotation);
1372 }
1373
Connect(CompletionOnceCallback callback)1374 int MockProxyClientSocket::Connect(CompletionOnceCallback callback) {
1375 DCHECK(socket_->IsConnected());
1376 if (data_->connect.mode == ASYNC) {
1377 RunCallbackAsync(std::move(callback), data_->connect.result);
1378 return ERR_IO_PENDING;
1379 }
1380 return data_->connect.result;
1381 }
1382
// Disconnects the wrapped socket, if there is one.
void MockProxyClientSocket::Disconnect() {
  if (socket_)
    socket_->Disconnect();
}
1387
// Pass-through to the wrapped socket's IsConnected().
bool MockProxyClientSocket::IsConnected() const {
  return socket_->IsConnected();
}
1391
// Pass-through to the wrapped socket's IsConnectedAndIdle().
bool MockProxyClientSocket::IsConnectedAndIdle() const {
  return socket_->IsConnectedAndIdle();
}
1395
// Pass-through to the wrapped socket's WasEverUsed().
bool MockProxyClientSocket::WasEverUsed() const {
  return socket_->WasEverUsed();
}
1399
// Always reports the fixed endpoint 192.0.2.33:123; the wrapped socket is not
// consulted.
int MockProxyClientSocket::GetLocalAddress(IPEndPoint* address) const {
  *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
  return OK;
}
1404
// Pass-through to the wrapped socket's GetPeerAddress().
int MockProxyClientSocket::GetPeerAddress(IPEndPoint* address) const {
  return socket_->GetPeerAddress(address);
}
1408
// The mock proxy socket never reports ALPN negotiation.
bool MockProxyClientSocket::WasAlpnNegotiated() const {
  return false;
}
1412
// Not implemented by the mock: flags the call via NOTIMPLEMENTED() and returns
// kProtoUnknown.
NextProto MockProxyClientSocket::GetNegotiatedProtocol() const {
  NOTIMPLEMENTED();
  return kProtoUnknown;
}
1417
// Not implemented by the mock: flags the call via NOTIMPLEMENTED(), leaves
// |requested_ssl_info| untouched and returns false.
bool MockProxyClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
  NOTIMPLEMENTED();
  return false;
}
1422
ApplySocketTag(const SocketTag & tag)1423 void MockProxyClientSocket::ApplySocketTag(const SocketTag& tag) {
1424 return socket_->ApplySocketTag(tag);
1425 }
1426
// Returns the NetLogWithSource captured from the wrapped socket at
// construction.
const NetLogWithSource& MockProxyClientSocket::NetLog() const {
  return net_log_;
}
1430
// Not implemented by the mock: flags the call via NOTIMPLEMENTED() and empties
// |out|.
void MockProxyClientSocket::GetConnectionAttempts(
    ConnectionAttempts* out) const {
  NOTIMPLEMENTED();
  out->clear();
}
1436
// Not implemented by the mock: flags the call via NOTIMPLEMENTED() and returns
// 0.
int64_t MockProxyClientSocket::GetTotalReceivedBytes() const {
  NOTIMPLEMENTED();
  return 0;
}
1441
// Test stub: |size| is ignored and OK is always returned.
int MockProxyClientSocket::SetReceiveBufferSize(int32_t size) {
  return OK;
}
1445
// Test stub: |size| is ignored and OK is always returned.
int MockProxyClientSocket::SetSendBufferSize(int32_t size) {
  return OK;
}
1449
// Read-completion notifications are not handled by this mock; the call is only
// flagged via NOTIMPLEMENTED().
void MockProxyClientSocket::OnReadComplete(const MockRead& data) {
  NOTIMPLEMENTED();
}
1453
// Write-completion notifications are not handled by this mock; the call is
// only flagged via NOTIMPLEMENTED().
void MockProxyClientSocket::OnWriteComplete(int rv) {
  NOTIMPLEMENTED();
}
1457
// Connect-completion notifications are not handled by this mock; the call is
// only flagged via NOTIMPLEMENTED().
void MockProxyClientSocket::OnConnectComplete(const MockConnect& data) {
  NOTIMPLEMENTED();
}
1461
RunCallbackAsync(CompletionOnceCallback callback,int result)1462 void MockProxyClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1463 int result) {
1464 base::ThreadTaskRunnerHandle::Get()->PostTask(
1465 FROM_HERE,
1466 base::BindOnce(&MockProxyClientSocket::RunCallback,
1467 weak_factory_.GetWeakPtr(), std::move(callback), result));
1468 }
1469
// Consumes |callback| by running it synchronously with |result|.
void MockProxyClientSocket::RunCallback(CompletionOnceCallback callback,
                                        int result) {
  std::move(callback).Run(result);
}
1474
// Marks |ssl_client_socket| as connected when |rv| is OK, then forwards |rv|
// to |callback| in every case.
// static
void MockSSLClientSocket::ConnectCallback(
    MockSSLClientSocket* ssl_client_socket,
    CompletionOnceCallback callback,
    int rv) {
  if (rv == OK)
    ssl_client_socket->connected_ = true;
  std::move(callback).Run(rv);
}
1484
// Wraps |stream_socket| as a mock SSL socket scripted by |data|, which must be
// non-null (DCHECKed). The peer address is taken from |data|'s connect data.
// |host_and_port| and |ssl_config| are accepted but not used by this
// constructor.
//
// NOTE(review): |net_log_| is read from |stream_socket| in the initializer
// list before it is moved into |stream_socket_|. Members initialize in
// declaration order, so this presumably relies on |net_log_| being declared
// before |stream_socket_| in the class — confirm against the header.
MockSSLClientSocket::MockSSLClientSocket(
    std::unique_ptr<StreamSocket> stream_socket,
    const HostPortPair& host_and_port,
    const SSLConfig& ssl_config,
    SSLSocketDataProvider* data)
    : net_log_(stream_socket->NetLog()),
      stream_socket_(std::move(stream_socket)),
      data_(data) {
  DCHECK(data_);
  peer_addr_ = data->connect.peer_addr;
}
1496
// Disconnects the underlying stream socket on destruction.
MockSSLClientSocket::~MockSSLClientSocket() {
  Disconnect();
}
1500
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1501 int MockSSLClientSocket::Read(IOBuffer* buf,
1502 int buf_len,
1503 CompletionOnceCallback callback) {
1504 return stream_socket_->Read(buf, buf_len, std::move(callback));
1505 }
1506
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1507 int MockSSLClientSocket::ReadIfReady(IOBuffer* buf,
1508 int buf_len,
1509 CompletionOnceCallback callback) {
1510 return stream_socket_->ReadIfReady(buf, buf_len, std::move(callback));
1511 }
1512
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1513 int MockSSLClientSocket::Write(
1514 IOBuffer* buf,
1515 int buf_len,
1516 CompletionOnceCallback callback,
1517 const NetworkTrafficAnnotationTag& traffic_annotation) {
1518 if (!data_->is_confirm_data_consumed)
1519 data_->write_called_before_confirm = true;
1520 return stream_socket_->Write(buf, buf_len, std::move(callback),
1521 traffic_annotation);
1522 }
1523
// Pass-through to the underlying stream socket's CancelReadIfReady().
int MockSSLClientSocket::CancelReadIfReady() {
  return stream_socket_->CancelReadIfReady();
}
1527
Connect(CompletionOnceCallback callback)1528 int MockSSLClientSocket::Connect(CompletionOnceCallback callback) {
1529 DCHECK(stream_socket_->IsConnected());
1530 data_->is_connect_data_consumed = true;
1531 if (data_->connect.result == OK)
1532 connected_ = true;
1533 RunClosureIfNonNull(std::move(data_->connect_callback));
1534 if (data_->connect.mode == ASYNC) {
1535 RunCallbackAsync(std::move(callback), data_->connect.result);
1536 return ERR_IO_PENDING;
1537 }
1538 return data_->connect.result;
1539 }
1540
// Disconnects the underlying stream socket, if there is one. |connected_| is
// not modified here.
void MockSSLClientSocket::Disconnect() {
  if (stream_socket_ != nullptr)
    stream_socket_->Disconnect();
}
1545
// Completion step of an asynchronous ConfirmHandshake(): marks the confirm
// data as consumed, then runs |callback| with |result|.
void MockSSLClientSocket::RunConfirmHandshakeCallback(
    CompletionOnceCallback callback,
    int result) {
  data_->is_confirm_data_consumed = true;
  std::move(callback).Run(result);
}
1552
ConfirmHandshake(CompletionOnceCallback callback)1553 int MockSSLClientSocket::ConfirmHandshake(CompletionOnceCallback callback) {
1554 DCHECK(stream_socket_->IsConnected());
1555 if (data_->is_confirm_data_consumed)
1556 return data_->confirm.result;
1557 RunClosureIfNonNull(std::move(data_->confirm_callback));
1558 if (data_->confirm.mode == ASYNC) {
1559 RunCallbackAsync(
1560 base::BindOnce(&MockSSLClientSocket::RunConfirmHandshakeCallback,
1561 base::Unretained(this), std::move(callback)),
1562 data_->confirm.result);
1563 return ERR_IO_PENDING;
1564 }
1565 data_->is_confirm_data_consumed = true;
1566 return data_->confirm.result;
1567 }
1568
// Pass-through to the underlying stream socket's IsConnected(); the SSL-level
// |connected_| flag is not consulted.
bool MockSSLClientSocket::IsConnected() const {
  return stream_socket_->IsConnected();
}
1572
// Pass-through to the underlying stream socket's IsConnectedAndIdle().
bool MockSSLClientSocket::IsConnectedAndIdle() const {
  return stream_socket_->IsConnectedAndIdle();
}
1576
// Pass-through to the underlying stream socket's WasEverUsed().
bool MockSSLClientSocket::WasEverUsed() const {
  return stream_socket_->WasEverUsed();
}
1580
// Always reports the fixed endpoint 192.0.2.33:123; the underlying stream
// socket is not consulted.
int MockSSLClientSocket::GetLocalAddress(IPEndPoint* address) const {
  *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
  return OK;
}
1585
// Pass-through to the underlying stream socket's GetPeerAddress(); the
// |peer_addr_| set in the constructor is not used here.
int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
  return stream_socket_->GetPeerAddress(address);
}
1589
// ALPN counts as negotiated whenever the data provider scripts a protocol
// other than kProtoUnknown.
bool MockSSLClientSocket::WasAlpnNegotiated() const {
  return data_->next_proto != kProtoUnknown;
}
1593
// Returns the protocol scripted on the data provider.
NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
  return data_->next_proto;
}
1597
// Resets |requested_ssl_info|, overwrites it with the SSLInfo scripted on the
// data provider, and always reports success.
bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
  requested_ssl_info->Reset();
  *requested_ssl_info = data_->ssl_info;
  return true;
}
1603
ApplySocketTag(const SocketTag & tag)1604 void MockSSLClientSocket::ApplySocketTag(const SocketTag& tag) {
1605 return stream_socket_->ApplySocketTag(tag);
1606 }
1607
// Returns the NetLogWithSource captured from the stream socket at
// construction.
const NetLogWithSource& MockSSLClientSocket::NetLog() const {
  return net_log_;
}
1611
// The mock SSL socket records no connection attempts; |out| is emptied.
void MockSSLClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
  out->clear();
}
1615
// Not implemented by the mock: flags the call via NOTIMPLEMENTED() and returns
// 0.
int64_t MockSSLClientSocket::GetTotalReceivedBytes() const {
  NOTIMPLEMENTED();
  return 0;
}
1620
// Not implemented by the base mock: flags the call via NOTIMPLEMENTED() and
// returns 0.
int64_t MockClientSocket::GetTotalReceivedBytes() const {
  NOTIMPLEMENTED();
  return 0;
}
1625
// Test stub: |size| is ignored and OK is always returned.
int MockSSLClientSocket::SetReceiveBufferSize(int32_t size) {
  return OK;
}
1629
// Test stub: |size| is ignored and OK is always returned.
int MockSSLClientSocket::SetSendBufferSize(int32_t size) {
  return OK;
}
1633
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info) const1634 void MockSSLClientSocket::GetSSLCertRequestInfo(
1635 SSLCertRequestInfo* cert_request_info) const {
1636 DCHECK(cert_request_info);
1637 if (data_->cert_request_info) {
1638 cert_request_info->host_and_port =
1639 data_->cert_request_info->host_and_port;
1640 cert_request_info->is_proxy = data_->cert_request_info->is_proxy;
1641 cert_request_info->cert_authorities =
1642 data_->cert_request_info->cert_authorities;
1643 cert_request_info->cert_key_types =
1644 data_->cert_request_info->cert_key_types;
1645 } else {
1646 cert_request_info->Reset();
1647 }
1648 }
1649
ExportKeyingMaterial(const base::StringPiece & label,bool has_context,const base::StringPiece & context,unsigned char * out,unsigned int outlen)1650 int MockSSLClientSocket::ExportKeyingMaterial(const base::StringPiece& label,
1651 bool has_context,
1652 const base::StringPiece& context,
1653 unsigned char* out,
1654 unsigned int outlen) {
1655 memset(out, 'A', outlen);
1656 return OK;
1657 }
1658
RunCallbackAsync(CompletionOnceCallback callback,int result)1659 void MockSSLClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1660 int result) {
1661 base::ThreadTaskRunnerHandle::Get()->PostTask(
1662 FROM_HERE,
1663 base::BindOnce(&MockSSLClientSocket::RunCallback,
1664 weak_factory_.GetWeakPtr(), std::move(callback), result));
1665 }
1666
// Consumes |callback| by running it synchronously with |result|.
void MockSSLClientSocket::RunCallback(CompletionOnceCallback callback,
                                      int result) {
  std::move(callback).Run(result);
}
1671
// Read-completion notifications are not handled by this mock; the call is only
// flagged via NOTIMPLEMENTED().
void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
  NOTIMPLEMENTED();
}
1675
// Write-completion notifications are not handled by this mock; the call is
// only flagged via NOTIMPLEMENTED().
void MockSSLClientSocket::OnWriteComplete(int rv) {
  NOTIMPLEMENTED();
}
1679
// Connect-completion notifications are not handled by this mock; the call is
// only flagged via NOTIMPLEMENTED().
void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
  NOTIMPLEMENTED();
}
1683
// Builds a mock UDP socket scripted by |data|. Unlike the TCP mock, |data| may
// be null; when it is provided, the socket registers itself with it through
// Initialize() and takes its peer address from the provider's connect data.
// The socket starts disconnected with source port 123 and an invalid network
// handle.
MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
                                         net::NetLog* net_log)
    : connected_(false),
      data_(data),
      read_offset_(0),
      // Placeholder value; |need_read_data_| below forces the first read to
      // fetch a real MockRead from |data_|.
      read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
      need_read_data_(true),
      source_port_(123),
      network_(NetworkChangeNotifier::kInvalidNetworkHandle),
      pending_read_buf_(nullptr),
      pending_read_buf_len_(0),
      net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)) {
  if (data_) {
    data_->Initialize(this);
    peer_addr_ = data->connect_data().peer_addr;
  }
}
1701
// Detaches from the data provider, if one is attached.
MockUDPClientSocket::~MockUDPClientSocket() {
  if (data_)
    data_->DetachSocket();
}
1706
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1707 int MockUDPClientSocket::Read(IOBuffer* buf,
1708 int buf_len,
1709 CompletionOnceCallback callback) {
1710 DCHECK(callback);
1711
1712 if (!connected_ || !data_)
1713 return ERR_UNEXPECTED;
1714 data_transferred_ = true;
1715
1716 // If the buffer is already in use, a read is already in progress!
1717 DCHECK(!pending_read_buf_);
1718
1719 // Store our async IO data.
1720 pending_read_buf_ = buf;
1721 pending_read_buf_len_ = buf_len;
1722 pending_read_callback_ = std::move(callback);
1723
1724 if (need_read_data_) {
1725 read_data_ = data_->OnRead();
1726 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1727 // to complete the async IO manually later (via OnReadComplete).
1728 if (read_data_.result == ERR_IO_PENDING) {
1729 // We need to be using async IO in this case.
1730 DCHECK(!pending_read_callback_.is_null());
1731 return ERR_IO_PENDING;
1732 }
1733 need_read_data_ = false;
1734 }
1735
1736 return CompleteRead();
1737 }
1738
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1739 int MockUDPClientSocket::Write(
1740 IOBuffer* buf,
1741 int buf_len,
1742 CompletionOnceCallback callback,
1743 const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1744 DCHECK(buf);
1745 DCHECK_GT(buf_len, 0);
1746 DCHECK(callback);
1747
1748 if (!connected_ || !data_)
1749 return ERR_UNEXPECTED;
1750 data_transferred_ = true;
1751
1752 std::string data(buf->data(), buf_len);
1753 MockWriteResult write_result = data_->OnWrite(data);
1754
1755 // ERR_IO_PENDING is a signal that the socket data will call back
1756 // asynchronously.
1757 if (write_result.result == ERR_IO_PENDING) {
1758 pending_write_callback_ = std::move(callback);
1759 return ERR_IO_PENDING;
1760 }
1761 if (write_result.mode == ASYNC) {
1762 RunCallbackAsync(std::move(callback), write_result.result);
1763 return ERR_IO_PENDING;
1764 }
1765 return write_result.result;
1766 }
1767
WriteAsync(const char * buffer,size_t buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1768 int MockUDPClientSocket::WriteAsync(
1769 const char* buffer,
1770 size_t buf_len,
1771 CompletionOnceCallback callback,
1772 const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1773 DCHECK(buffer);
1774 DCHECK_GT(buf_len, 0u);
1775 DCHECK(callback);
1776
1777 if (!connected_ || !data_)
1778 return ERR_UNEXPECTED;
1779 data_transferred_ = true;
1780
1781 std::string data(buffer, buf_len);
1782 MockWriteResult write_result = data_->OnWrite(data);
1783
1784 // ERR_IO_PENDING is a signal that the socket data will call back
1785 // asynchronously.
1786 if (write_result.result == ERR_IO_PENDING) {
1787 pending_write_callback_ = std::move(callback);
1788 return ERR_IO_PENDING;
1789 }
1790 if (write_result.mode == ASYNC) {
1791 RunCallbackAsync(std::move(callback), write_result.result);
1792 return ERR_IO_PENDING;
1793 }
1794 return write_result.result;
1795 }
1796
WriteAsync(DatagramBuffers buffers,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1797 int MockUDPClientSocket::WriteAsync(
1798 DatagramBuffers buffers,
1799 CompletionOnceCallback callback,
1800 const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1801 DCHECK(!buffers.empty());
1802 DCHECK(callback);
1803
1804 if (!connected_ || !data_)
1805 return ERR_UNEXPECTED;
1806
1807 unwritten_buffers_ = std::move(buffers);
1808
1809 int rv = 0;
1810 size_t buf_len = 0;
1811 do {
1812 auto& buf = unwritten_buffers_.front();
1813
1814 buf_len = buf->length();
1815 std::string data(buf->data(), buf_len);
1816 MockWriteResult write_result = data_->OnWrite(data);
1817 rv = write_result.result;
1818
1819 // ERR_IO_PENDING is a signal that the socket data will call back
1820 // asynchronously.
1821 if (write_result.result == ERR_IO_PENDING) {
1822 pending_write_callback_ = std::move(callback);
1823 return ERR_IO_PENDING;
1824 }
1825 if (write_result.mode == ASYNC) {
1826 RunCallbackAsync(std::move(callback), write_result.result);
1827 return ERR_IO_PENDING;
1828 }
1829
1830 if (rv < 0) {
1831 return rv;
1832 }
1833
1834 unwritten_buffers_.pop_front();
1835 } while (!unwritten_buffers_.empty());
1836
1837 return buf_len;
1838 }
1839
GetUnwrittenBuffers()1840 DatagramBuffers MockUDPClientSocket::GetUnwrittenBuffers() {
1841 return std::move(unwritten_buffers_);
1842 }
1843
SetReceiveBufferSize(int32_t size)1844 int MockUDPClientSocket::SetReceiveBufferSize(int32_t size) {
1845 return OK;
1846 }
1847
SetSendBufferSize(int32_t size)1848 int MockUDPClientSocket::SetSendBufferSize(int32_t size) {
1849 return OK;
1850 }
1851
SetDoNotFragment()1852 int MockUDPClientSocket::SetDoNotFragment() {
1853 return OK;
1854 }
1855
Close()1856 void MockUDPClientSocket::Close() {
1857 connected_ = false;
1858 }
1859
GetPeerAddress(IPEndPoint * address) const1860 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1861 if (!data_)
1862 return ERR_UNEXPECTED;
1863
1864 *address = peer_addr_;
1865 return OK;
1866 }
1867
GetLocalAddress(IPEndPoint * address) const1868 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1869 *address = IPEndPoint(IPAddress(192, 0, 2, 33), source_port_);
1870 return OK;
1871 }
1872
UseNonBlockingIO()1873 void MockUDPClientSocket::UseNonBlockingIO() {}
1874
SetWriteAsyncEnabled(bool enabled)1875 void MockUDPClientSocket::SetWriteAsyncEnabled(bool enabled) {}
WriteAsyncEnabled()1876 bool MockUDPClientSocket::WriteAsyncEnabled() {
1877 return false;
1878 }
SetMaxPacketSize(size_t max_packet_size)1879 void MockUDPClientSocket::SetMaxPacketSize(size_t max_packet_size) {}
SetWriteMultiCoreEnabled(bool enabled)1880 void MockUDPClientSocket::SetWriteMultiCoreEnabled(bool enabled) {}
SetSendmmsgEnabled(bool enabled)1881 void MockUDPClientSocket::SetSendmmsgEnabled(bool enabled) {}
SetWriteBatchingActive(bool active)1882 void MockUDPClientSocket::SetWriteBatchingActive(bool active) {}
SetMulticastInterface(uint32_t interface_index)1883 int MockUDPClientSocket::SetMulticastInterface(uint32_t interface_index) {
1884 return OK;
1885 }
1886
NetLog() const1887 const NetLogWithSource& MockUDPClientSocket::NetLog() const {
1888 return net_log_;
1889 }
1890
Connect(const IPEndPoint & address)1891 int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1892 if (!data_)
1893 return ERR_UNEXPECTED;
1894 connected_ = true;
1895 peer_addr_ = address;
1896 return data_->connect_data().result;
1897 }
1898
ConnectUsingNetwork(NetworkChangeNotifier::NetworkHandle network,const IPEndPoint & address)1899 int MockUDPClientSocket::ConnectUsingNetwork(
1900 NetworkChangeNotifier::NetworkHandle network,
1901 const IPEndPoint& address) {
1902 DCHECK(!connected_);
1903 if (!data_)
1904 return ERR_UNEXPECTED;
1905 network_ = network;
1906 connected_ = true;
1907 peer_addr_ = address;
1908 return data_->connect_data().result;
1909 }
1910
ConnectUsingDefaultNetwork(const IPEndPoint & address)1911 int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) {
1912 DCHECK(!connected_);
1913 if (!data_)
1914 return ERR_UNEXPECTED;
1915 network_ = kDefaultNetworkForTests;
1916 connected_ = true;
1917 peer_addr_ = address;
1918 return data_->connect_data().result;
1919 }
1920
GetBoundNetwork() const1921 NetworkChangeNotifier::NetworkHandle MockUDPClientSocket::GetBoundNetwork()
1922 const {
1923 return network_;
1924 }
1925
ApplySocketTag(const SocketTag & tag)1926 void MockUDPClientSocket::ApplySocketTag(const SocketTag& tag) {
1927 tagged_before_data_transferred_ &= !data_transferred_ || tag == tag_;
1928 tag_ = tag;
1929 }
1930
OnReadComplete(const MockRead & data)1931 void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1932 if (!data_)
1933 return;
1934
1935 // There must be a read pending.
1936 DCHECK(pending_read_buf_.get());
1937 DCHECK(pending_read_callback_);
1938 // You can't complete a read with another ERR_IO_PENDING status code.
1939 DCHECK_NE(ERR_IO_PENDING, data.result);
1940 // Since we've been waiting for data, need_read_data_ should be true.
1941 DCHECK(need_read_data_);
1942
1943 read_data_ = data;
1944 need_read_data_ = false;
1945
1946 // The caller is simulating that this IO completes right now. Don't
1947 // let CompleteRead() schedule a callback.
1948 read_data_.mode = SYNCHRONOUS;
1949
1950 CompletionOnceCallback callback = std::move(pending_read_callback_);
1951 int rv = CompleteRead();
1952 RunCallback(std::move(callback), rv);
1953 }
1954
OnWriteComplete(int rv)1955 void MockUDPClientSocket::OnWriteComplete(int rv) {
1956 if (!data_)
1957 return;
1958
1959 // There must be a read pending.
1960 DCHECK(!pending_write_callback_.is_null());
1961 RunCallback(std::move(pending_write_callback_), rv);
1962 }
1963
OnConnectComplete(const MockConnect & data)1964 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1965 NOTIMPLEMENTED();
1966 }
1967
OnDataProviderDestroyed()1968 void MockUDPClientSocket::OnDataProviderDestroyed() {
1969 data_ = nullptr;
1970 }
1971
CompleteRead()1972 int MockUDPClientSocket::CompleteRead() {
1973 DCHECK(pending_read_buf_.get());
1974 DCHECK(pending_read_buf_len_ > 0);
1975
1976 // Save the pending async IO data and reset our |pending_| state.
1977 scoped_refptr<IOBuffer> buf = pending_read_buf_;
1978 int buf_len = pending_read_buf_len_;
1979 CompletionOnceCallback callback = std::move(pending_read_callback_);
1980 pending_read_buf_ = nullptr;
1981 pending_read_buf_len_ = 0;
1982
1983 int result = read_data_.result;
1984 DCHECK(result != ERR_IO_PENDING);
1985
1986 if (read_data_.data) {
1987 if (read_data_.data_len - read_offset_ > 0) {
1988 result = std::min(buf_len, read_data_.data_len - read_offset_);
1989 memcpy(buf->data(), read_data_.data + read_offset_, result);
1990 read_offset_ += result;
1991 if (read_offset_ == read_data_.data_len) {
1992 need_read_data_ = true;
1993 read_offset_ = 0;
1994 }
1995 } else {
1996 result = 0; // EOF
1997 }
1998 }
1999
2000 if (read_data_.mode == ASYNC) {
2001 DCHECK(!callback.is_null());
2002 RunCallbackAsync(std::move(callback), result);
2003 return ERR_IO_PENDING;
2004 }
2005 return result;
2006 }
2007
RunCallbackAsync(CompletionOnceCallback callback,int result)2008 void MockUDPClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
2009 int result) {
2010 base::ThreadTaskRunnerHandle::Get()->PostTask(
2011 FROM_HERE,
2012 base::BindOnce(&MockUDPClientSocket::RunCallback,
2013 weak_factory_.GetWeakPtr(), std::move(callback), result));
2014 }
2015
RunCallback(CompletionOnceCallback callback,int result)2016 void MockUDPClientSocket::RunCallback(CompletionOnceCallback callback,
2017 int result) {
2018 std::move(callback).Run(result);
2019 }
2020
TestSocketRequest(std::vector<TestSocketRequest * > * request_order,size_t * completion_count)2021 TestSocketRequest::TestSocketRequest(
2022 std::vector<TestSocketRequest*>* request_order,
2023 size_t* completion_count)
2024 : request_order_(request_order), completion_count_(completion_count) {
2025 DCHECK(request_order);
2026 DCHECK(completion_count);
2027 }
2028
2029 TestSocketRequest::~TestSocketRequest() = default;
2030
OnComplete(int result)2031 void TestSocketRequest::OnComplete(int result) {
2032 SetResult(result);
2033 (*completion_count_)++;
2034 request_order_->push_back(this);
2035 }
2036
2037 // static
2038 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
2039
2040 // static
2041 const int ClientSocketPoolTest::kRequestNotFound = -2;
2042
ClientSocketPoolTest()2043 ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {}
2044 ClientSocketPoolTest::~ClientSocketPoolTest() = default;
2045
GetOrderOfRequest(size_t index) const2046 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
2047 index--;
2048 if (index >= requests_.size())
2049 return kIndexOutOfBounds;
2050
2051 for (size_t i = 0; i < request_order_.size(); i++)
2052 if (requests_[index].get() == request_order_[i])
2053 return i + 1;
2054
2055 return kRequestNotFound;
2056 }
2057
ReleaseOneConnection(KeepAlive keep_alive)2058 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
2059 for (std::unique_ptr<TestSocketRequest>& it : requests_) {
2060 if (it->handle()->is_initialized()) {
2061 if (keep_alive == NO_KEEP_ALIVE)
2062 it->handle()->socket()->Disconnect();
2063 it->handle()->Reset();
2064 base::RunLoop().RunUntilIdle();
2065 return true;
2066 }
2067 }
2068 return false;
2069 }
2070
ReleaseAllConnections(KeepAlive keep_alive)2071 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
2072 bool released_one;
2073 do {
2074 released_one = ReleaseOneConnection(keep_alive);
2075 } while (released_one);
2076 }
2077
MockConnectJob(std::unique_ptr<StreamSocket> socket,ClientSocketHandle * handle,const SocketTag & socket_tag,CompletionOnceCallback callback,RequestPriority priority)2078 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
2079 std::unique_ptr<StreamSocket> socket,
2080 ClientSocketHandle* handle,
2081 const SocketTag& socket_tag,
2082 CompletionOnceCallback callback,
2083 RequestPriority priority)
2084 : socket_(std::move(socket)),
2085 handle_(handle),
2086 socket_tag_(socket_tag),
2087 user_callback_(std::move(callback)),
2088 priority_(priority) {}
2089
2090 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() = default;
2091
Connect()2092 int MockTransportClientSocketPool::MockConnectJob::Connect() {
2093 socket_->ApplySocketTag(socket_tag_);
2094 int rv = socket_->Connect(
2095 base::BindOnce(&MockConnectJob::OnConnect, base::Unretained(this)));
2096 if (rv != ERR_IO_PENDING) {
2097 user_callback_.Reset();
2098 OnConnect(rv);
2099 }
2100 return rv;
2101 }
2102
CancelHandle(const ClientSocketHandle * handle)2103 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
2104 const ClientSocketHandle* handle) {
2105 if (handle != handle_)
2106 return false;
2107 socket_.reset();
2108 handle_ = nullptr;
2109 user_callback_.Reset();
2110 return true;
2111 }
2112
OnConnect(int rv)2113 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
2114 if (!socket_.get())
2115 return;
2116 if (rv == OK) {
2117 handle_->SetSocket(std::move(socket_));
2118
2119 // Needed for socket pool tests that layer other sockets on top of mock
2120 // sockets.
2121 LoadTimingInfo::ConnectTiming connect_timing;
2122 base::TimeTicks now = base::TimeTicks::Now();
2123 connect_timing.dns_start = now;
2124 connect_timing.dns_end = now;
2125 connect_timing.connect_start = now;
2126 connect_timing.connect_end = now;
2127 handle_->set_connect_timing(connect_timing);
2128 } else {
2129 socket_.reset();
2130
2131 // Needed to test copying of ConnectionAttempts in SSL ConnectJob.
2132 ConnectionAttempts attempts;
2133 attempts.push_back(ConnectionAttempt(IPEndPoint(), rv));
2134 handle_->set_connection_attempts(attempts);
2135 }
2136
2137 handle_ = nullptr;
2138
2139 if (!user_callback_.is_null()) {
2140 std::move(user_callback_).Run(rv);
2141 }
2142 }
2143
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,const CommonConnectJobParams * common_connect_job_params)2144 MockTransportClientSocketPool::MockTransportClientSocketPool(
2145 int max_sockets,
2146 int max_sockets_per_group,
2147 const CommonConnectJobParams* common_connect_job_params)
2148 : TransportClientSocketPool(
2149 max_sockets,
2150 max_sockets_per_group,
2151 base::TimeDelta::FromSeconds(10) /* unused_idle_socket_timeout */,
2152 ProxyServer::Direct(),
2153 false /* is_for_websockets */,
2154 common_connect_job_params),
2155 client_socket_factory_(common_connect_job_params->client_socket_factory),
2156 last_request_priority_(DEFAULT_PRIORITY),
2157 release_count_(0),
2158 cancel_count_(0) {}
2159
2160 MockTransportClientSocketPool::~MockTransportClientSocketPool() = default;
2161
RequestSocket(const ClientSocketPool::GroupId & group_id,scoped_refptr<ClientSocketPool::SocketParams> socket_params,const base::Optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,const SocketTag & socket_tag,RespectLimits respect_limits,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & on_auth_callback,const NetLogWithSource & net_log)2162 int MockTransportClientSocketPool::RequestSocket(
2163 const ClientSocketPool::GroupId& group_id,
2164 scoped_refptr<ClientSocketPool::SocketParams> socket_params,
2165 const base::Optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
2166 RequestPriority priority,
2167 const SocketTag& socket_tag,
2168 RespectLimits respect_limits,
2169 ClientSocketHandle* handle,
2170 CompletionOnceCallback callback,
2171 const ProxyAuthCallback& on_auth_callback,
2172 const NetLogWithSource& net_log) {
2173 last_request_priority_ = priority;
2174 std::unique_ptr<StreamSocket> socket =
2175 client_socket_factory_->CreateTransportClientSocket(
2176 AddressList(), nullptr, net_log.net_log(), NetLogSource());
2177 MockConnectJob* job = new MockConnectJob(
2178 std::move(socket), handle, socket_tag, std::move(callback), priority);
2179 job_list_.push_back(base::WrapUnique(job));
2180 handle->set_group_generation(1);
2181 return job->Connect();
2182 }
2183
SetPriority(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,RequestPriority priority)2184 void MockTransportClientSocketPool::SetPriority(
2185 const ClientSocketPool::GroupId& group_id,
2186 ClientSocketHandle* handle,
2187 RequestPriority priority) {
2188 for (auto& job : job_list_) {
2189 if (job->handle() == handle) {
2190 job->set_priority(priority);
2191 return;
2192 }
2193 }
2194 NOTREACHED();
2195 }
2196
CancelRequest(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)2197 void MockTransportClientSocketPool::CancelRequest(
2198 const ClientSocketPool::GroupId& group_id,
2199 ClientSocketHandle* handle,
2200 bool cancel_connect_job) {
2201 for (std::unique_ptr<MockConnectJob>& it : job_list_) {
2202 if (it->CancelHandle(handle)) {
2203 cancel_count_++;
2204 break;
2205 }
2206 }
2207 }
2208
ReleaseSocket(const ClientSocketPool::GroupId & group_id,std::unique_ptr<StreamSocket> socket,int64_t generation)2209 void MockTransportClientSocketPool::ReleaseSocket(
2210 const ClientSocketPool::GroupId& group_id,
2211 std::unique_ptr<StreamSocket> socket,
2212 int64_t generation) {
2213 EXPECT_EQ(1, generation);
2214 release_count_++;
2215 }
2216
WrappedStreamSocket(std::unique_ptr<StreamSocket> transport)2217 WrappedStreamSocket::WrappedStreamSocket(
2218 std::unique_ptr<StreamSocket> transport)
2219 : transport_(std::move(transport)) {}
~WrappedStreamSocket()2220 WrappedStreamSocket::~WrappedStreamSocket() {}
2221
Bind(const net::IPEndPoint & local_addr)2222 int WrappedStreamSocket::Bind(const net::IPEndPoint& local_addr) {
2223 NOTREACHED();
2224 return ERR_FAILED;
2225 }
2226
Connect(CompletionOnceCallback callback)2227 int WrappedStreamSocket::Connect(CompletionOnceCallback callback) {
2228 return transport_->Connect(std::move(callback));
2229 }
2230
Disconnect()2231 void WrappedStreamSocket::Disconnect() {
2232 transport_->Disconnect();
2233 }
2234
IsConnected() const2235 bool WrappedStreamSocket::IsConnected() const {
2236 return transport_->IsConnected();
2237 }
2238
IsConnectedAndIdle() const2239 bool WrappedStreamSocket::IsConnectedAndIdle() const {
2240 return transport_->IsConnectedAndIdle();
2241 }
2242
GetPeerAddress(IPEndPoint * address) const2243 int WrappedStreamSocket::GetPeerAddress(IPEndPoint* address) const {
2244 return transport_->GetPeerAddress(address);
2245 }
2246
GetLocalAddress(IPEndPoint * address) const2247 int WrappedStreamSocket::GetLocalAddress(IPEndPoint* address) const {
2248 return transport_->GetLocalAddress(address);
2249 }
2250
NetLog() const2251 const NetLogWithSource& WrappedStreamSocket::NetLog() const {
2252 return transport_->NetLog();
2253 }
2254
WasEverUsed() const2255 bool WrappedStreamSocket::WasEverUsed() const {
2256 return transport_->WasEverUsed();
2257 }
2258
WasAlpnNegotiated() const2259 bool WrappedStreamSocket::WasAlpnNegotiated() const {
2260 return transport_->WasAlpnNegotiated();
2261 }
2262
GetNegotiatedProtocol() const2263 NextProto WrappedStreamSocket::GetNegotiatedProtocol() const {
2264 return transport_->GetNegotiatedProtocol();
2265 }
2266
GetSSLInfo(SSLInfo * ssl_info)2267 bool WrappedStreamSocket::GetSSLInfo(SSLInfo* ssl_info) {
2268 return transport_->GetSSLInfo(ssl_info);
2269 }
2270
GetConnectionAttempts(ConnectionAttempts * out) const2271 void WrappedStreamSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
2272 transport_->GetConnectionAttempts(out);
2273 }
2274
ClearConnectionAttempts()2275 void WrappedStreamSocket::ClearConnectionAttempts() {
2276 transport_->ClearConnectionAttempts();
2277 }
2278
AddConnectionAttempts(const ConnectionAttempts & attempts)2279 void WrappedStreamSocket::AddConnectionAttempts(
2280 const ConnectionAttempts& attempts) {
2281 transport_->AddConnectionAttempts(attempts);
2282 }
2283
GetTotalReceivedBytes() const2284 int64_t WrappedStreamSocket::GetTotalReceivedBytes() const {
2285 return transport_->GetTotalReceivedBytes();
2286 }
2287
ApplySocketTag(const SocketTag & tag)2288 void WrappedStreamSocket::ApplySocketTag(const SocketTag& tag) {
2289 transport_->ApplySocketTag(tag);
2290 }
2291
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2292 int WrappedStreamSocket::Read(IOBuffer* buf,
2293 int buf_len,
2294 CompletionOnceCallback callback) {
2295 return transport_->Read(buf, buf_len, std::move(callback));
2296 }
2297
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2298 int WrappedStreamSocket::ReadIfReady(IOBuffer* buf,
2299 int buf_len,
2300 CompletionOnceCallback callback) {
2301 return transport_->ReadIfReady(buf, buf_len, std::move((callback)));
2302 }
2303
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)2304 int WrappedStreamSocket::Write(
2305 IOBuffer* buf,
2306 int buf_len,
2307 CompletionOnceCallback callback,
2308 const NetworkTrafficAnnotationTag& traffic_annotation) {
2309 return transport_->Write(buf, buf_len, std::move(callback),
2310 TRAFFIC_ANNOTATION_FOR_TESTS);
2311 }
2312
SetReceiveBufferSize(int32_t size)2313 int WrappedStreamSocket::SetReceiveBufferSize(int32_t size) {
2314 return transport_->SetReceiveBufferSize(size);
2315 }
2316
SetSendBufferSize(int32_t size)2317 int WrappedStreamSocket::SetSendBufferSize(int32_t size) {
2318 return transport_->SetSendBufferSize(size);
2319 }
2320
Connect(CompletionOnceCallback callback)2321 int MockTaggingStreamSocket::Connect(CompletionOnceCallback callback) {
2322 connected_ = true;
2323 return WrappedStreamSocket::Connect(std::move(callback));
2324 }
2325
ApplySocketTag(const SocketTag & tag)2326 void MockTaggingStreamSocket::ApplySocketTag(const SocketTag& tag) {
2327 tagged_before_connected_ &= !connected_ || tag == tag_;
2328 tag_ = tag;
2329 transport_->ApplySocketTag(tag);
2330 }
2331
2332 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetLog * net_log,const NetLogSource & source)2333 MockTaggingClientSocketFactory::CreateTransportClientSocket(
2334 const AddressList& addresses,
2335 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
2336 NetLog* net_log,
2337 const NetLogSource& source) {
2338 std::unique_ptr<MockTaggingStreamSocket> socket(new MockTaggingStreamSocket(
2339 MockClientSocketFactory::CreateTransportClientSocket(
2340 addresses, std::move(socket_performance_watcher), net_log, source)));
2341 tcp_socket_ = socket.get();
2342 return std::move(socket);
2343 }
2344
2345 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)2346 MockTaggingClientSocketFactory::CreateDatagramClientSocket(
2347 DatagramSocket::BindType bind_type,
2348 NetLog* net_log,
2349 const NetLogSource& source) {
2350 std::unique_ptr<DatagramClientSocket> socket(
2351 MockClientSocketFactory::CreateDatagramClientSocket(bind_type, net_log,
2352 source));
2353 udp_socket_ = static_cast<MockUDPClientSocket*>(socket.get());
2354 return socket;
2355 }
2356
2357 const char kSOCKS4TestHost[] = "127.0.0.1";
2358 const int kSOCKS4TestPort = 80;
2359
2360 const char kSOCKS4OkRequestLocalHostPort80[] = {0x04, 0x01, 0x00, 0x50, 127,
2361 0, 0, 1, 0};
2362 const int kSOCKS4OkRequestLocalHostPort80Length =
2363 base::size(kSOCKS4OkRequestLocalHostPort80);
2364
2365 const char kSOCKS4OkReply[] = {0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0};
2366 const int kSOCKS4OkReplyLength = base::size(kSOCKS4OkReply);
2367
2368 const char kSOCKS5TestHost[] = "host";
2369 const int kSOCKS5TestPort = 80;
2370
2371 const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
2372 const int kSOCKS5GreetRequestLength = base::size(kSOCKS5GreetRequest);
2373
2374 const char kSOCKS5GreetResponse[] = { 0x05, 0x00 };
2375 const int kSOCKS5GreetResponseLength = base::size(kSOCKS5GreetResponse);
2376
2377 const char kSOCKS5OkRequest[] =
2378 { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 };
2379 const int kSOCKS5OkRequestLength = base::size(kSOCKS5OkRequest);
2380
2381 const char kSOCKS5OkResponse[] =
2382 { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 };
2383 const int kSOCKS5OkResponseLength = base::size(kSOCKS5OkResponse);
2384
CountReadBytes(base::span<const MockRead> reads)2385 int64_t CountReadBytes(base::span<const MockRead> reads) {
2386 int64_t total = 0;
2387 for (const MockRead& read : reads)
2388 total += read.data_len;
2389 return total;
2390 }
2391
CountWriteBytes(base::span<const MockWrite> writes)2392 int64_t CountWriteBytes(base::span<const MockWrite> writes) {
2393 int64_t total = 0;
2394 for (const MockWrite& write : writes)
2395 total += write.data_len;
2396 return total;
2397 }
2398
2399 #if defined(OS_ANDROID)
CanGetTaggedBytes()2400 bool CanGetTaggedBytes() {
2401 // In Android P, /proc/net/xt_qtaguid/stats is no longer guaranteed to be
2402 // present, and has been replaced with eBPF Traffic Monitoring in netd. See:
2403 // https://source.android.com/devices/tech/datausage/ebpf-traffic-monitor
2404 //
2405 // To read traffic statistics from netd, apps should use the API
2406 // NetworkStatsManager.queryDetailsForUidTag(). But this API does not provide
2407 // statistics for local traffic, only mobile and WiFi traffic, so it would not
2408 // work in tests that spin up a local server. So for now, GetTaggedBytes is
2409 // only supported on Android releases older than P.
2410 return base::android::BuildInfo::GetInstance()->sdk_int() <
2411 base::android::SDK_VERSION_P;
2412 }
2413
GetTaggedBytes(int32_t expected_tag)2414 uint64_t GetTaggedBytes(int32_t expected_tag) {
2415 EXPECT_TRUE(CanGetTaggedBytes());
2416
2417 // To determine how many bytes the system saw with a particular tag read
2418 // the /proc/net/xt_qtaguid/stats file which contains the kernel's
2419 // dump of all the UIDs and their tags sent and received bytes.
2420 uint64_t bytes = 0;
2421 std::string contents;
2422 EXPECT_TRUE(base::ReadFileToString(
2423 base::FilePath::FromUTF8Unsafe("/proc/net/xt_qtaguid/stats"), &contents));
2424 for (size_t i = contents.find('\n'); // Skip first line which is headers.
2425 i != std::string::npos && i < contents.length();) {
2426 uint64_t tag, rx_bytes;
2427 uid_t uid;
2428 int n;
2429 // Parse out the numbers we care about. For reference here's the column
2430 // headers:
2431 // idx iface acct_tag_hex uid_tag_int cnt_set rx_bytes rx_packets tx_bytes
2432 // tx_packets rx_tcp_bytes rx_tcp_packets rx_udp_bytes rx_udp_packets
2433 // rx_other_bytes rx_other_packets tx_tcp_bytes tx_tcp_packets tx_udp_bytes
2434 // tx_udp_packets tx_other_bytes tx_other_packets
2435 EXPECT_EQ(sscanf(contents.c_str() + i,
2436 "%*d %*s 0x%" SCNx64 " %d %*d %" SCNu64
2437 " %*d %*d %*d %*d %*d %*d %*d %*d "
2438 "%*d %*d %*d %*d %*d %*d %*d%n",
2439 &tag, &uid, &rx_bytes, &n),
2440 3);
2441 // If this line matches our UID and |expected_tag| then add it to the total.
2442 if (uid == getuid() && (int32_t)(tag >> 32) == expected_tag) {
2443 bytes += rx_bytes;
2444 }
2445 // Move |i| to the next line.
2446 i += n + 1;
2447 }
2448 return bytes;
2449 }
2450 #endif
2451
2452 } // namespace net
2453