1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socket_test_util.h"
6 
7 #include <inttypes.h>  // For SCNx64
8 #include <stdint.h>
9 #include <stdio.h>
10 
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
15 
16 #include "base/bind.h"
17 #include "base/bind_helpers.h"
18 #include "base/callback_helpers.h"
19 #include "base/compiler_specific.h"
20 #include "base/files/file_util.h"
21 #include "base/location.h"
22 #include "base/logging.h"
23 #include "base/rand_util.h"
24 #include "base/run_loop.h"
25 #include "base/single_thread_task_runner.h"
26 #include "base/stl_util.h"
27 #include "base/threading/thread_task_runner_handle.h"
28 #include "base/time/time.h"
29 #include "net/base/address_family.h"
30 #include "net/base/address_list.h"
31 #include "net/base/auth.h"
32 #include "net/base/hex_utils.h"
33 #include "net/base/ip_address.h"
34 #include "net/base/load_timing_info.h"
35 #include "net/base/proxy_server.h"
36 #include "net/http/http_network_session.h"
37 #include "net/http/http_request_headers.h"
38 #include "net/http/http_response_headers.h"
39 #include "net/log/net_log_source.h"
40 #include "net/log/net_log_source_type.h"
41 #include "net/socket/connect_job.h"
42 #include "net/socket/socket.h"
43 #include "net/socket/stream_socket.h"
44 #include "net/socket/websocket_endpoint_lock_manager.h"
45 #include "net/ssl/ssl_cert_request_info.h"
46 #include "net/ssl/ssl_connection_status_flags.h"
47 #include "net/ssl/ssl_info.h"
48 #include "net/traffic_annotation/network_traffic_annotation.h"
49 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
50 #include "testing/gtest/include/gtest/gtest.h"
51 
52 #if defined(OS_ANDROID)
53 #include "base/android/build_info.h"
54 #endif
55 
// Starts a verbose log statement at |level| consisting of the marker |s|, the
// name of the enclosing function and "() ". The expansion is a stream
// expression, so callers append further details with <<.
#define NET_TRACE(level, s) VLOG(level) << s << __FUNCTION__ << "() "
57 
58 namespace net {
59 namespace {
60 
// Returns the uppercase hexadecimal digit ('0'-'9' or 'A'-'F') that represents
// the upper four bits of |x|.
inline char AsciifyHigh(char x) {
  const int high_bits = (x >> 4) & 0x0F;
  if (high_bits < 10)
    return static_cast<char>('0' + high_bits);
  return static_cast<char>('A' + (high_bits - 10));
}
65 
// Returns the uppercase hexadecimal digit ('0'-'9' or 'A'-'F') that represents
// the lower four bits of |x|.
inline char AsciifyLow(char x) {
  static const char kHexDigits[] = "0123456789ABCDEF";
  return kHexDigits[x & 0x0F];
}
70 
// Returns |x| unchanged when it is a non-negative printable character, and '.'
// otherwise. The sign test runs first so isprint() never sees a negative
// value.
inline char Asciify(char x) {
  const bool printable = (x >= 0) && isprint(x);
  return printable ? x : '.';
}
76 
DumpData(const char * data,int data_len)77 void DumpData(const char* data, int data_len) {
78   if (logging::LOG_INFO < logging::GetMinLogLevel())
79     return;
80   DVLOG(1) << "Length:  " << data_len;
81   const char* pfx = "Data:    ";
82   if (!data || (data_len <= 0)) {
83     DVLOG(1) << pfx << "<None>";
84   } else {
85     int i;
86     for (i = 0; i <= (data_len - 4); i += 4) {
87       DVLOG(1) << pfx
88                << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
89                << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
90                << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
91                << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3])
92                << "  '"
93                << Asciify(data[i + 0])
94                << Asciify(data[i + 1])
95                << Asciify(data[i + 2])
96                << Asciify(data[i + 3])
97                << "'";
98       pfx = "         ";
99     }
100     // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
101     switch (data_len - i) {
102       case 3:
103         DVLOG(1) << pfx
104                  << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
105                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
106                  << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
107                  << "    '"
108                  << Asciify(data[i + 0])
109                  << Asciify(data[i + 1])
110                  << Asciify(data[i + 2])
111                  << " '";
112         break;
113       case 2:
114         DVLOG(1) << pfx
115                  << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
116                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
117                  << "      '"
118                  << Asciify(data[i + 0])
119                  << Asciify(data[i + 1])
120                  << "  '";
121         break;
122       case 1:
123         DVLOG(1) << pfx
124                  << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
125                  << "        '"
126                  << Asciify(data[i + 0])
127                  << "   '";
128         break;
129     }
130   }
131 }
132 
133 template <MockReadWriteType type>
DumpMockReadWrite(const MockReadWrite<type> & r)134 void DumpMockReadWrite(const MockReadWrite<type>& r) {
135   if (logging::LOG_INFO < logging::GetMinLogLevel())
136     return;
137   DVLOG(1) << "Async:   " << (r.mode == ASYNC)
138            << "\nResult:  " << r.result;
139   DumpData(r.data, r.data_len);
140   const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
141   DVLOG(1) << "Stage:   " << (r.sequence_number & ~MockRead::STOPLOOP) << stop;
142 }
143 
RunClosureIfNonNull(base::OnceClosure closure)144 void RunClosureIfNonNull(base::OnceClosure closure) {
145   if (!closure.is_null()) {
146     std::move(closure).Run();
147   }
148 }
149 
150 }  // namespace
151 
MockConnect()152 MockConnect::MockConnect() : mode(ASYNC), result(OK) {
153   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
154 }
155 
MockConnect(IoMode io_mode,int r)156 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
157   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
158 }
159 
// Mock connect that completes in |io_mode| with result |r| and reports |addr|
// as the peer address.
MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) :
    mode(io_mode),
    result(r),
    peer_addr(addr) {
}
165 
// Compiler-generated destructor, defined here rather than in the header.
MockConnect::~MockConnect() = default;
167 
MockConfirm()168 MockConfirm::MockConfirm() : mode(SYNCHRONOUS), result(OK) {}
169 
// Mock confirm that completes in |io_mode| with result |r|.
MockConfirm::MockConfirm(IoMode io_mode, int r) : mode(io_mode), result(r) {}
171 
// Compiler-generated destructor, defined here rather than in the header.
MockConfirm::~MockConfirm() = default;
173 
// Base implementation: a provider is always reported as idle.
// SequencedSocketData supplies its own IsIdle() with stricter rules.
bool SocketDataProvider::IsIdle() const {
  return true;
}
177 
// Attaches |socket| to this provider and then calls Reset(). CHECK-fails if a
// socket is already attached or if |socket| is null.
void SocketDataProvider::Initialize(AsyncSocket* socket) {
  CHECK(!socket_);
  CHECK(socket);
  socket_ = socket;
  Reset();
}
184 
// Forgets the attached socket. CHECK-fails if no socket is attached.
void SocketDataProvider::DetachSocket() {
  CHECK(socket_);
  socket_ = nullptr;
}
189 
SocketDataProvider()190 SocketDataProvider::SocketDataProvider() {}
191 
~SocketDataProvider()192 SocketDataProvider::~SocketDataProvider() {
193   if (socket_)
194     socket_->OnDataProviderDestroyed();
195 }
196 
// Stores the |reads| and |writes| spans and starts both cursors at index 0.
StaticSocketDataHelper::StaticSocketDataHelper(
    base::span<const MockRead> reads,
    base::span<const MockWrite> writes)
    : reads_(reads), read_index_(0), writes_(writes), write_index_(0) {}
201 
// Compiler-generated destructor, defined here rather than in the header.
StaticSocketDataHelper::~StaticSocketDataHelper() = default;
203 
// Returns the read at the current read cursor without advancing it.
// CHECK-fails if all reads have already been consumed.
const MockRead& StaticSocketDataHelper::PeekRead() const {
  CHECK(!AllReadDataConsumed());
  return reads_[read_index_];
}
208 
// Returns the write at the current write cursor without advancing it.
// CHECK-fails if all writes have already been consumed.
const MockWrite& StaticSocketDataHelper::PeekWrite() const {
  CHECK(!AllWriteDataConsumed());
  return writes_[write_index_];
}
213 
AdvanceRead()214 const MockRead& StaticSocketDataHelper::AdvanceRead() {
215   CHECK(!AllReadDataConsumed());
216   return reads_[read_index_++];
217 }
218 
AdvanceWrite()219 const MockWrite& StaticSocketDataHelper::AdvanceWrite() {
220   CHECK(!AllWriteDataConsumed());
221   return writes_[write_index_++];
222 }
223 
Reset()224 void StaticSocketDataHelper::Reset() {
225   read_index_ = 0;
226   write_index_ = 0;
227 }
228 
VerifyWriteData(const std::string & data,SocketDataPrinter * printer)229 bool StaticSocketDataHelper::VerifyWriteData(const std::string& data,
230                                              SocketDataPrinter* printer) {
231   CHECK(!AllWriteDataConsumed());
232   // Check that the actual data matches the expectations, skipping over any
233   // pause events.
234   const MockWrite& next_write = PeekRealWrite();
235   if (!next_write.data)
236     return true;
237 
238   // Note: Partial writes are supported here.  If the expected data
239   // is a match, but shorter than the write actually written, that is legal.
240   // Example:
241   //   Application writes "foobarbaz" (9 bytes)
242   //   Expected write was "foo" (3 bytes)
243   //   This is a success, and the function returns true.
244   std::string expected_data(next_write.data, next_write.data_len);
245   std::string actual_data(data.substr(0, next_write.data_len));
246   EXPECT_GE(data.length(), expected_data.length());
247   EXPECT_TRUE(actual_data == expected_data)
248       << "Actual write data:\n" << HexDump(data)
249       << "Expected write data:\n" << HexDump(expected_data);
250   if (printer) {
251     EXPECT_TRUE(actual_data == expected_data)
252         << "Actual write data:\n"
253         << printer->PrintWrite(data) << "Expected write data:\n"
254         << printer->PrintWrite(expected_data);
255   }
256   return expected_data == actual_data;
257 }
258 
PeekRealWrite() const259 const MockWrite& StaticSocketDataHelper::PeekRealWrite() const {
260   for (size_t i = write_index_; i < write_count(); i++) {
261     if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING)
262       return writes_[i];
263   }
264 
265   CHECK(false) << "No write data available.";
266   return writes_[0];  // Avoid warning about unreachable missing return.
267 }
268 
// Creates a provider with no mock reads and no mock writes by delegating to
// the span-taking constructor with two empty spans.
StaticSocketDataProvider::StaticSocketDataProvider()
    : StaticSocketDataProvider(base::span<const MockRead>(),
                               base::span<const MockWrite>()) {}
272 
// Creates a provider that serves |reads| and checks against |writes|; both
// spans are handed to the internal helper.
StaticSocketDataProvider::StaticSocketDataProvider(
    base::span<const MockRead> reads,
    base::span<const MockWrite> writes)
    : helper_(reads, writes) {}
277 
// Compiler-generated destructor, defined here rather than in the header.
StaticSocketDataProvider::~StaticSocketDataProvider() = default;
279 
// Sets the paused flag. While it is set, AllReadDataConsumed() returns true,
// which makes OnRead() answer with a synchronous ERR_IO_PENDING.
void StaticSocketDataProvider::Pause() {
  paused_ = true;
}
283 
// Clears the paused flag set by Pause().
void StaticSocketDataProvider::Resume() {
  paused_ = false;
}
287 
OnRead()288 MockRead StaticSocketDataProvider::OnRead() {
289   if (AllReadDataConsumed()) {
290     const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING);
291     return pending_read;
292   }
293 
294   return helper_.AdvanceRead();
295 }
296 
OnWrite(const std::string & data)297 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
298   if (helper_.write_count() == 0) {
299     // Not using mock writes; succeed synchronously.
300     return MockWriteResult(SYNCHRONOUS, data.length());
301   }
302   EXPECT_FALSE(helper_.AllWriteDataConsumed())
303       << "No more mock data to match write:\n"
304       << HexDump(data);
305   if (helper_.AllWriteDataConsumed()) {
306     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
307   }
308 
309   // Check that what we are writing matches the expectation.
310   // Then give the mocked return value.
311   if (!helper_.VerifyWriteData(data, printer_))
312     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
313 
314   const MockWrite& next_write = helper_.AdvanceWrite();
315   // In the case that the write was successful, return the number of bytes
316   // written. Otherwise return the error code.
317   int result =
318       next_write.result == OK ? next_write.data_len : next_write.result;
319   return MockWriteResult(next_write.mode, result);
320 }
321 
AllReadDataConsumed() const322 bool StaticSocketDataProvider::AllReadDataConsumed() const {
323   return paused_ || helper_.AllReadDataConsumed();
324 }
325 
// Returns whether every mock write has been consumed. Unlike
// AllReadDataConsumed(), this does not look at the paused flag.
bool StaticSocketDataProvider::AllWriteDataConsumed() const {
  return helper_.AllWriteDataConsumed();
}
329 
// Rewinds the helper's read and write cursors. The paused flag is left as is.
void StaticSocketDataProvider::Reset() {
  helper_.Reset();
}
333 
// Builds the |connect| member from |mode| and |result|.
ProxyClientSocketDataProvider::ProxyClientSocketDataProvider(IoMode mode,
                                                             int result)
    : connect(mode, result) {}
337 
// Compiler-generated memberwise copy constructor, defined here rather than in
// the header.
ProxyClientSocketDataProvider::ProxyClientSocketDataProvider(
    const ProxyClientSocketDataProvider& other) = default;
340 
// Compiler-generated destructor, defined here rather than in the header.
ProxyClientSocketDataProvider::~ProxyClientSocketDataProvider() = default;
342 
SSLSocketDataProvider(IoMode mode,int result)343 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
344     : connect(mode, result),
345       next_proto(kProtoUnknown),
346       cert_request_info(nullptr),
347       expected_ssl_version_min(kDefaultSSLVersionMin),
348       expected_ssl_version_max(kDefaultSSLVersionMax) {
349   SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
350                                 &ssl_info.connection_status);
351   // Set to TLS_CHACHA20_POLY1305_SHA256
352   SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
353 }
354 
// Compiler-generated memberwise copy constructor, defined here rather than in
// the header.
SSLSocketDataProvider::SSLSocketDataProvider(
    const SSLSocketDataProvider& other) = default;
357 
// Compiler-generated destructor, defined here rather than in the header.
SSLSocketDataProvider::~SSLSocketDataProvider() = default;
359 
// Creates sequenced data with no reads and no writes by delegating to the
// span-taking constructor with two empty spans.
SequencedSocketData::SequencedSocketData()
    : SequencedSocketData(base::span<const MockRead>(),
                          base::span<const MockWrite>()) {}
363 
SequencedSocketData(base::span<const MockRead> reads,base::span<const MockWrite> writes)364 SequencedSocketData::SequencedSocketData(base::span<const MockRead> reads,
365                                          base::span<const MockWrite> writes)
366     : helper_(reads, writes),
367       sequence_number_(0),
368       read_state_(IDLE),
369       write_state_(IDLE),
370       busy_before_sync_reads_(false) {
371   // Check that reads and writes have a contiguous set of sequence numbers
372   // starting from 0 and working their way up, with no repeats and skipping
373   // no values.
374   int next_sequence_number = 0;
375   bool last_event_was_pause = false;
376 
377   auto next_read = reads.begin();
378   auto next_write = writes.begin();
379   while (next_read != reads.end() || next_write != writes.end()) {
380     if (next_read != reads.end() &&
381         next_read->sequence_number == next_sequence_number) {
382       // Check if this is a pause.
383       if (next_read->mode == ASYNC && next_read->result == ERR_IO_PENDING) {
384         CHECK(!last_event_was_pause) << "Two pauses in a row are not allowed: "
385                                      << next_sequence_number;
386         last_event_was_pause = true;
387       } else if (last_event_was_pause) {
388         CHECK_EQ(ASYNC, next_read->mode)
389             << "A sync event after a pause makes no sense: "
390             << next_sequence_number;
391         CHECK_NE(ERR_IO_PENDING, next_read->result)
392             << "A pause event after a pause makes no sense: "
393             << next_sequence_number;
394         last_event_was_pause = false;
395       }
396 
397       ++next_read;
398       ++next_sequence_number;
399       continue;
400     }
401     if (next_write != writes.end() &&
402         next_write->sequence_number == next_sequence_number) {
403       // Check if this is a pause.
404       if (next_write->mode == ASYNC && next_write->result == ERR_IO_PENDING) {
405         CHECK(!last_event_was_pause) << "Two pauses in a row are not allowed: "
406                                      << next_sequence_number;
407         last_event_was_pause = true;
408       } else if (last_event_was_pause) {
409         CHECK_EQ(ASYNC, next_write->mode)
410             << "A sync event after a pause makes no sense: "
411             << next_sequence_number;
412         CHECK_NE(ERR_IO_PENDING, next_write->result)
413             << "A pause event after a pause makes no sense: "
414             << next_sequence_number;
415         last_event_was_pause = false;
416       }
417 
418       ++next_write;
419       ++next_sequence_number;
420       continue;
421     }
422     if (next_write != writes.end()) {
423       CHECK(false) << "Sequence number " << next_write->sequence_number
424                    << " not found where expected: " << next_sequence_number;
425     } else {
426       CHECK(false) << "Too few writes, next expected sequence number: "
427                    << next_sequence_number;
428     }
429     return;
430   }
431 
432   // Last event must not be a pause.  For the final event to indicate the
433   // operation never completes, it should be SYNCHRONOUS and return
434   // ERR_IO_PENDING.
435   CHECK(!last_event_was_pause);
436 
437   CHECK(next_read == reads.end());
438   CHECK(next_write == writes.end());
439 }
440 
// Same as the (reads, writes) constructor, and additionally installs |connect|
// as the connect data.
SequencedSocketData::SequencedSocketData(const MockConnect& connect,
                                         base::span<const MockRead> reads,
                                         base::span<const MockWrite> writes)
    : SequencedSocketData(reads, writes) {
  set_connect_data(connect);
}
447 
OnRead()448 MockRead SequencedSocketData::OnRead() {
449   CHECK_EQ(IDLE, read_state_);
450   CHECK(!helper_.AllReadDataConsumed())
451       << "Application tried to read but there is no read data left";
452 
453   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
454   const MockRead& next_read = helper_.PeekRead();
455   NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number;
456   CHECK_GE(next_read.sequence_number, sequence_number_);
457 
458   if (next_read.sequence_number <= sequence_number_) {
459     if (next_read.mode == SYNCHRONOUS) {
460       NET_TRACE(1, " *** ") << "Returning synchronously";
461       DumpMockReadWrite(next_read);
462       helper_.AdvanceRead();
463       ++sequence_number_;
464       MaybePostWriteCompleteTask();
465       return next_read;
466     }
467 
468     // If the result is ERR_IO_PENDING, then pause.
469     if (next_read.result == ERR_IO_PENDING) {
470       NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
471       read_state_ = PAUSED;
472       if (run_until_paused_run_loop_)
473         run_until_paused_run_loop_->Quit();
474       return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
475     }
476     base::ThreadTaskRunnerHandle::Get()->PostTask(
477         FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
478                                   weak_factory_.GetWeakPtr()));
479     CHECK_NE(COMPLETING, write_state_);
480     read_state_ = COMPLETING;
481   } else if (next_read.mode == SYNCHRONOUS) {
482     ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
483     return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
484   } else {
485     NET_TRACE(1, " *** ") << "Waiting for write to trigger read";
486     read_state_ = PENDING;
487   }
488 
489   return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
490 }
491 
OnWrite(const std::string & data)492 MockWriteResult SequencedSocketData::OnWrite(const std::string& data) {
493   CHECK_EQ(IDLE, write_state_);
494   CHECK(!helper_.AllWriteDataConsumed())
495       << "\nNo more mock data to match write:\n"
496       << HexDump(data);
497 
498   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
499   const MockWrite& next_write = helper_.PeekWrite();
500   NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number;
501   CHECK_GE(next_write.sequence_number, sequence_number_);
502 
503   if (!helper_.VerifyWriteData(data, printer_))
504     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
505 
506   if (next_write.sequence_number <= sequence_number_) {
507     if (next_write.mode == SYNCHRONOUS) {
508       helper_.AdvanceWrite();
509       ++sequence_number_;
510       MaybePostReadCompleteTask();
511       // In the case that the write was successful, return the number of bytes
512       // written. Otherwise return the error code.
513       int rv =
514           next_write.result != OK ? next_write.result : next_write.data_len;
515       NET_TRACE(1, " *** ") << "Returning synchronously";
516       return MockWriteResult(SYNCHRONOUS, rv);
517     }
518 
519     // If the result is ERR_IO_PENDING, then pause.
520     if (next_write.result == ERR_IO_PENDING) {
521       NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
522       write_state_ = PAUSED;
523       if (run_until_paused_run_loop_)
524         run_until_paused_run_loop_->Quit();
525       return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
526     }
527 
528     NET_TRACE(1, " *** ") << "Posting task to complete write";
529     base::ThreadTaskRunnerHandle::Get()->PostTask(
530         FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
531                                   weak_factory_.GetWeakPtr()));
532     CHECK_NE(COMPLETING, read_state_);
533     write_state_ = COMPLETING;
534   } else if (next_write.mode == SYNCHRONOUS) {
535     ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
536     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
537   } else {
538     NET_TRACE(1, " *** ") << "Waiting for read to trigger write";
539     write_state_ = PENDING;
540   }
541 
542   return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
543 }
544 
// Returns whether every mock read has been consumed.
bool SequencedSocketData::AllReadDataConsumed() const {
  return helper_.AllReadDataConsumed();
}
548 
// Returns a PENDING read to IDLE. DCHECKs that a read is actually PENDING,
// i.e. waiting for an earlier event in the sequence.
void SequencedSocketData::CancelPendingRead() {
  DCHECK_EQ(PENDING, read_state_);

  read_state_ = IDLE;
}
554 
// Returns whether every mock write has been consumed.
bool SequencedSocketData::AllWriteDataConsumed() const {
  return helper_.AllWriteDataConsumed();
}
558 
IsIdle() const559 bool SequencedSocketData::IsIdle() const {
560   // If |busy_before_sync_reads_| is not set, always considered idle.  If
561   // no reads left, or the next operation is a write, also consider it idle.
562   if (!busy_before_sync_reads_ || helper_.AllReadDataConsumed() ||
563       helper_.PeekRead().sequence_number != sequence_number_) {
564     return true;
565   }
566 
567   // If the next operation is synchronous read, treat the socket as not idle.
568   if (helper_.PeekRead().mode == SYNCHRONOUS)
569     return false;
570   return true;
571 }
572 
IsPaused() const573 bool SequencedSocketData::IsPaused() const {
574   // Both states should not be paused.
575   DCHECK(read_state_ != PAUSED || write_state_ != PAUSED);
576   return write_state_ == PAUSED || read_state_ == PAUSED;
577 }
578 
Resume()579 void SequencedSocketData::Resume() {
580   if (!IsPaused()) {
581     ADD_FAILURE() << "Unable to Resume when not paused.";
582     return;
583   }
584 
585   sequence_number_++;
586   if (read_state_ == PAUSED) {
587     read_state_ = PENDING;
588     helper_.AdvanceRead();
589   } else {  // write_state_ == PAUSED
590     write_state_ = PENDING;
591     helper_.AdvanceWrite();
592   }
593 
594   if (!helper_.AllWriteDataConsumed() &&
595       helper_.PeekWrite().sequence_number == sequence_number_) {
596     // The next event hasn't even started yet.  Pausing isn't really needed in
597     // that case, but may as well support it.
598     if (write_state_ != PENDING)
599       return;
600     write_state_ = COMPLETING;
601     OnWriteComplete();
602     return;
603   }
604 
605   CHECK(!helper_.AllReadDataConsumed());
606 
607   // The next event hasn't even started yet.  Pausing isn't really needed in
608   // that case, but may as well support it.
609   if (read_state_ != PENDING)
610     return;
611   read_state_ = COMPLETING;
612   OnReadComplete();
613 }
614 
RunUntilPaused()615 void SequencedSocketData::RunUntilPaused() {
616   CHECK(!run_until_paused_run_loop_);
617 
618   if (IsPaused())
619     return;
620 
621   run_until_paused_run_loop_.reset(new base::RunLoop());
622   run_until_paused_run_loop_->Run();
623   run_until_paused_run_loop_.reset();
624   DCHECK(IsPaused());
625 }
626 
MaybePostReadCompleteTask()627 void SequencedSocketData::MaybePostReadCompleteTask() {
628   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
629   // Only trigger the next read to complete if there is already a read pending
630   // which should complete at the current sequence number.
631   if (read_state_ != PENDING ||
632       helper_.PeekRead().sequence_number != sequence_number_) {
633     return;
634   }
635 
636   // If the result is ERR_IO_PENDING, then pause.
637   if (helper_.PeekRead().result == ERR_IO_PENDING) {
638     NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
639     read_state_ = PAUSED;
640     if (run_until_paused_run_loop_)
641       run_until_paused_run_loop_->Quit();
642     return;
643   }
644 
645   NET_TRACE(1, " ****** ") << "Posting task to complete read: "
646                            << sequence_number_;
647   base::ThreadTaskRunnerHandle::Get()->PostTask(
648       FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
649                                 weak_factory_.GetWeakPtr()));
650   CHECK_NE(COMPLETING, write_state_);
651   read_state_ = COMPLETING;
652 }
653 
MaybePostWriteCompleteTask()654 void SequencedSocketData::MaybePostWriteCompleteTask() {
655   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
656   // Only trigger the next write to complete if there is already a write pending
657   // which should complete at the current sequence number.
658   if (write_state_ != PENDING ||
659       helper_.PeekWrite().sequence_number != sequence_number_) {
660     return;
661   }
662 
663   // If the result is ERR_IO_PENDING, then pause.
664   if (helper_.PeekWrite().result == ERR_IO_PENDING) {
665     NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
666     write_state_ = PAUSED;
667     if (run_until_paused_run_loop_)
668       run_until_paused_run_loop_->Quit();
669     return;
670   }
671 
672   NET_TRACE(1, " ****** ") << "Posting task to complete write: "
673                            << sequence_number_;
674   base::ThreadTaskRunnerHandle::Get()->PostTask(
675       FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
676                                 weak_factory_.GetWeakPtr()));
677   CHECK_NE(COMPLETING, read_state_);
678   write_state_ = COMPLETING;
679 }
680 
Reset()681 void SequencedSocketData::Reset() {
682   helper_.Reset();
683   sequence_number_ = 0;
684   read_state_ = IDLE;
685   write_state_ = IDLE;
686   weak_factory_.InvalidateWeakPtrs();
687 }
688 
OnReadComplete()689 void SequencedSocketData::OnReadComplete() {
690   CHECK_EQ(COMPLETING, read_state_);
691   NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_;
692 
693   MockRead data = helper_.AdvanceRead();
694   DCHECK_EQ(sequence_number_, data.sequence_number);
695   sequence_number_++;
696   read_state_ = IDLE;
697 
698   // The result of this read completing might trigger the completion
699   // of a pending write. If so, post a task to complete the write later.
700   // Since the socket may call back into the SequencedSocketData
701   // from socket()->OnReadComplete(), trigger the write task to be posted
702   // before calling that.
703   MaybePostWriteCompleteTask();
704 
705   if (!socket()) {
706     NET_TRACE(1, " *** ") << "No socket available to complete read";
707     return;
708   }
709 
710   NET_TRACE(1, " *** ") << "Completing socket read for: "
711                         << data.sequence_number;
712   DumpMockReadWrite(data);
713   socket()->OnReadComplete(data);
714   NET_TRACE(1, " *** ") << "Done";
715 }
716 
OnWriteComplete()717 void SequencedSocketData::OnWriteComplete() {
718   CHECK_EQ(COMPLETING, write_state_);
719   NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_;
720 
721   const MockWrite& data = helper_.AdvanceWrite();
722   DCHECK_EQ(sequence_number_, data.sequence_number);
723   sequence_number_++;
724   write_state_ = IDLE;
725   int rv = data.result == OK ? data.data_len : data.result;
726 
727   // The result of this write completing might trigger the completion
728   // of a pending read. If so, post a task to complete the read later.
729   // Since the socket may call back into the SequencedSocketData
730   // from socket()->OnWriteComplete(), trigger the write task to be posted
731   // before calling that.
732   MaybePostReadCompleteTask();
733 
734   if (!socket()) {
735     NET_TRACE(1, " *** ") << "No socket available to complete write";
736     return;
737   }
738 
739   NET_TRACE(1, " *** ") << " Completing socket write for: "
740                         << data.sequence_number;
741   socket()->OnWriteComplete(rv);
742   NET_TRACE(1, " *** ") << "Done";
743 }
744 
// Compiler-generated destructor, defined here rather than in the header.
SequencedSocketData::~SequencedSocketData() = default;
746 
// Creates a factory with |enable_read_if_ready_| switched off.
MockClientSocketFactory::MockClientSocketFactory()
    : enable_read_if_ready_(false) {}
749 
// Compiler-generated destructor, defined here rather than in the header.
MockClientSocketFactory::~MockClientSocketFactory() = default;
751 
// Adds |data| to |mock_data_|, the collection that
// CreateDatagramClientSocket() and CreateTransportClientSocket() draw from via
// GetNext(). Only the pointer is stored here.
void MockClientSocketFactory::AddSocketDataProvider(
    SocketDataProvider* data) {
  mock_data_.Add(data);
}
756 
// Adds |data| to |mock_ssl_data_|, the collection that CreateSSLClientSocket()
// draws from via GetNext(). Only the pointer is stored here.
void MockClientSocketFactory::AddSSLSocketDataProvider(
    SSLSocketDataProvider* data) {
  mock_ssl_data_.Add(data);
}
761 
// Adds |data| to |mock_proxy_data_|. Only the pointer is stored here.
// NOTE(review): the consumer of |mock_proxy_data_| is not visible in this part
// of the file.
void MockClientSocketFactory::AddProxyClientSocketDataProvider(
    ProxyClientSocketDataProvider* data) {
  mock_proxy_data_.Add(data);
}
766 
ResetNextMockIndexes()767 void MockClientSocketFactory::ResetNextMockIndexes() {
768   mock_data_.ResetNextIndex();
769   mock_ssl_data_.ResetNextIndex();
770 }
771 
772 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)773 MockClientSocketFactory::CreateDatagramClientSocket(
774     DatagramSocket::BindType bind_type,
775     NetLog* net_log,
776     const NetLogSource& source) {
777   SocketDataProvider* data_provider = mock_data_.GetNext();
778   std::unique_ptr<MockUDPClientSocket> socket(
779       new MockUDPClientSocket(data_provider, net_log));
780   if (bind_type == DatagramSocket::RANDOM_BIND)
781     socket->set_source_port(static_cast<uint16_t>(base::RandInt(1025, 65535)));
782   udp_client_socket_ports_.push_back(socket->source_port());
783   return std::move(socket);
784 }
785 
786 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetLog * net_log,const NetLogSource & source)787 MockClientSocketFactory::CreateTransportClientSocket(
788     const AddressList& addresses,
789     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
790     NetLog* net_log,
791     const NetLogSource& source) {
792   SocketDataProvider* data_provider = mock_data_.GetNext();
793   std::unique_ptr<MockTCPClientSocket> socket(
794       new MockTCPClientSocket(addresses, net_log, data_provider));
795   if (enable_read_if_ready_)
796     socket->set_enable_read_if_ready(enable_read_if_ready_);
797   return std::move(socket);
798 }
799 
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)800 std::unique_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
801     SSLClientContext* context,
802     std::unique_ptr<StreamSocket> stream_socket,
803     const HostPortPair& host_and_port,
804     const SSLConfig& ssl_config) {
805   SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext();
806   if (next_ssl_data->next_protos_expected_in_ssl_config.has_value()) {
807     EXPECT_EQ(next_ssl_data->next_protos_expected_in_ssl_config.value().size(),
808               ssl_config.alpn_protos.size());
809     EXPECT_TRUE(std::equal(
810         next_ssl_data->next_protos_expected_in_ssl_config.value().begin(),
811         next_ssl_data->next_protos_expected_in_ssl_config.value().end(),
812         ssl_config.alpn_protos.begin()));
813   }
814 
815   // The protocol version used is a combination of the per-socket SSLConfig and
816   // the SSLConfigService.
817   EXPECT_EQ(
818       next_ssl_data->expected_ssl_version_min,
819       ssl_config.version_min_override.value_or(context->config().version_min));
820   EXPECT_EQ(
821       next_ssl_data->expected_ssl_version_max,
822       ssl_config.version_max_override.value_or(context->config().version_max));
823 
824   if (next_ssl_data->expected_send_client_cert) {
825     // Client certificate preferences come from |context|.
826     scoped_refptr<X509Certificate> client_cert;
827     scoped_refptr<SSLPrivateKey> client_private_key;
828     bool send_client_cert = context->GetClientCertificate(
829         host_and_port, &client_cert, &client_private_key);
830 
831     EXPECT_EQ(*next_ssl_data->expected_send_client_cert, send_client_cert);
832     // Note |send_client_cert| may be true while |client_cert| is null if the
833     // socket is configured to continue without a certificate, as opposed to
834     // surfacing the certificate challenge.
835     EXPECT_EQ(!!next_ssl_data->expected_client_cert, !!client_cert);
836     if (next_ssl_data->expected_client_cert && client_cert) {
837       EXPECT_TRUE(next_ssl_data->expected_client_cert->EqualsIncludingChain(
838           client_cert.get()));
839     }
840   }
841   if (next_ssl_data->expected_host_and_port) {
842     EXPECT_EQ(*next_ssl_data->expected_host_and_port, host_and_port);
843   }
844   if (next_ssl_data->expected_network_isolation_key) {
845     EXPECT_EQ(*next_ssl_data->expected_network_isolation_key,
846               ssl_config.network_isolation_key);
847   }
848   if (next_ssl_data->expected_disable_legacy_crypto) {
849     EXPECT_EQ(*next_ssl_data->expected_disable_legacy_crypto,
850               ssl_config.disable_legacy_crypto);
851   }
852   return std::unique_ptr<SSLClientSocket>(new MockSSLClientSocket(
853       std::move(stream_socket), host_and_port, ssl_config, next_ssl_data));
854 }
855 
856 std::unique_ptr<ProxyClientSocket>
CreateProxyClientSocket(std::unique_ptr<StreamSocket> stream_socket,const std::string & user_agent,const HostPortPair & endpoint,const ProxyServer & proxy_server,HttpAuthController * http_auth_controller,bool tunnel,bool using_spdy,NextProto negotiated_protocol,ProxyDelegate * proxy_delegate,const NetworkTrafficAnnotationTag & traffic_annotation)857 MockClientSocketFactory::CreateProxyClientSocket(
858     std::unique_ptr<StreamSocket> stream_socket,
859     const std::string& user_agent,
860     const HostPortPair& endpoint,
861     const ProxyServer& proxy_server,
862     HttpAuthController* http_auth_controller,
863     bool tunnel,
864     bool using_spdy,
865     NextProto negotiated_protocol,
866     ProxyDelegate* proxy_delegate,
867     const NetworkTrafficAnnotationTag& traffic_annotation) {
868   if (use_mock_proxy_client_sockets_) {
869     ProxyClientSocketDataProvider* next_proxy_data = mock_proxy_data_.GetNext();
870     return std::make_unique<MockProxyClientSocket>(
871         std::move(stream_socket), http_auth_controller, next_proxy_data);
872   } else {
873     return GetDefaultFactory()->CreateProxyClientSocket(
874         std::move(stream_socket), user_agent, endpoint, proxy_server,
875         http_auth_controller, tunnel, using_spdy, negotiated_protocol,
876         proxy_delegate, traffic_annotation);
877   }
878 }
879 
MockClientSocket(const NetLogWithSource & net_log)880 MockClientSocket::MockClientSocket(const NetLogWithSource& net_log)
881     : connected_(false), net_log_(net_log) {
882   local_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
883   peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
884 }
885 
SetReceiveBufferSize(int32_t size)886 int MockClientSocket::SetReceiveBufferSize(int32_t size) {
887   return OK;
888 }
889 
SetSendBufferSize(int32_t size)890 int MockClientSocket::SetSendBufferSize(int32_t size) {
891   return OK;
892 }
893 
Bind(const net::IPEndPoint & local_addr)894 int MockClientSocket::Bind(const net::IPEndPoint& local_addr) {
895   local_addr_ = local_addr;
896   return net::OK;
897 }
898 
SetNoDelay(bool no_delay)899 bool MockClientSocket::SetNoDelay(bool no_delay) {
900   return true;
901 }
902 
SetKeepAlive(bool enable,int delay)903 bool MockClientSocket::SetKeepAlive(bool enable, int delay) {
904   return true;
905 }
906 
Disconnect()907 void MockClientSocket::Disconnect() {
908   connected_ = false;
909 }
910 
IsConnected() const911 bool MockClientSocket::IsConnected() const {
912   return connected_;
913 }
914 
IsConnectedAndIdle() const915 bool MockClientSocket::IsConnectedAndIdle() const {
916   return connected_;
917 }
918 
GetPeerAddress(IPEndPoint * address) const919 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
920   if (!IsConnected())
921     return ERR_SOCKET_NOT_CONNECTED;
922   *address = peer_addr_;
923   return OK;
924 }
925 
GetLocalAddress(IPEndPoint * address) const926 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
927   *address = local_addr_;
928   return OK;
929 }
930 
NetLog() const931 const NetLogWithSource& MockClientSocket::NetLog() const {
932   return net_log_;
933 }
934 
WasAlpnNegotiated() const935 bool MockClientSocket::WasAlpnNegotiated() const {
936   return false;
937 }
938 
GetNegotiatedProtocol() const939 NextProto MockClientSocket::GetNegotiatedProtocol() const {
940   return kProtoUnknown;
941 }
942 
GetConnectionAttempts(ConnectionAttempts * out) const943 void MockClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
944   out->clear();
945 }
946 
947 MockClientSocket::~MockClientSocket() = default;
948 
RunCallbackAsync(CompletionOnceCallback callback,int result)949 void MockClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
950                                         int result) {
951   base::ThreadTaskRunnerHandle::Get()->PostTask(
952       FROM_HERE,
953       base::BindOnce(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(),
954                      std::move(callback), result));
955 }
956 
RunCallback(CompletionOnceCallback callback,int result)957 void MockClientSocket::RunCallback(CompletionOnceCallback callback,
958                                    int result) {
959   std::move(callback).Run(result);
960 }
961 
MockTCPClientSocket(const AddressList & addresses,net::NetLog * net_log,SocketDataProvider * data)962 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
963                                          net::NetLog* net_log,
964                                          SocketDataProvider* data)
965     : MockClientSocket(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)),
966       addresses_(addresses),
967       data_(data),
968       read_offset_(0),
969       read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
970       need_read_data_(true),
971       peer_closed_connection_(false),
972       pending_read_buf_(nullptr),
973       pending_read_buf_len_(0),
974       was_used_to_convey_data_(false),
975       enable_read_if_ready_(false) {
976   DCHECK(data_);
977   peer_addr_ = data->connect_data().peer_addr;
978   data_->Initialize(this);
979 }
980 
~MockTCPClientSocket()981 MockTCPClientSocket::~MockTCPClientSocket() {
982   if (data_)
983     data_->DetachSocket();
984 }
985 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)986 int MockTCPClientSocket::Read(IOBuffer* buf,
987                               int buf_len,
988                               CompletionOnceCallback callback) {
989   // If the buffer is already in use, a read is already in progress!
990   DCHECK(!pending_read_buf_);
991   // Use base::Unretained() is safe because MockClientSocket::RunCallbackAsync()
992   // takes a weak ptr of the base class, MockClientSocket.
993   int rv = ReadIfReadyImpl(
994       buf, buf_len,
995       base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this)));
996   if (rv == ERR_IO_PENDING) {
997     DCHECK(callback);
998 
999     pending_read_buf_ = buf;
1000     pending_read_buf_len_ = buf_len;
1001     pending_read_callback_ = std::move(callback);
1002   }
1003   return rv;
1004 }
1005 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1006 int MockTCPClientSocket::ReadIfReady(IOBuffer* buf,
1007                                      int buf_len,
1008                                      CompletionOnceCallback callback) {
1009   DCHECK(!pending_read_if_ready_callback_);
1010 
1011   if (!enable_read_if_ready_)
1012     return ERR_READ_IF_READY_NOT_IMPLEMENTED;
1013   return ReadIfReadyImpl(buf, buf_len, std::move(callback));
1014 }
1015 
CancelReadIfReady()1016 int MockTCPClientSocket::CancelReadIfReady() {
1017   DCHECK(pending_read_if_ready_callback_);
1018 
1019   pending_read_if_ready_callback_.Reset();
1020   data_->CancelPendingRead();
1021   return OK;
1022 }
1023 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1024 int MockTCPClientSocket::Write(
1025     IOBuffer* buf,
1026     int buf_len,
1027     CompletionOnceCallback callback,
1028     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1029   DCHECK(buf);
1030   DCHECK_GT(buf_len, 0);
1031 
1032   if (!connected_ || !data_)
1033     return ERR_UNEXPECTED;
1034 
1035   std::string data(buf->data(), buf_len);
1036   MockWriteResult write_result = data_->OnWrite(data);
1037 
1038   was_used_to_convey_data_ = true;
1039 
1040   if (write_result.result == ERR_CONNECTION_CLOSED) {
1041     // This MockWrite is just a marker to instruct us to set
1042     // peer_closed_connection_.
1043     peer_closed_connection_ = true;
1044   }
1045   // ERR_IO_PENDING is a signal that the socket data will call back
1046   // asynchronously later.
1047   if (write_result.result == ERR_IO_PENDING) {
1048     pending_write_callback_ = std::move(callback);
1049     return ERR_IO_PENDING;
1050   }
1051 
1052   if (write_result.mode == ASYNC) {
1053     RunCallbackAsync(std::move(callback), write_result.result);
1054     return ERR_IO_PENDING;
1055   }
1056 
1057   return write_result.result;
1058 }
1059 
SetReceiveBufferSize(int32_t size)1060 int MockTCPClientSocket::SetReceiveBufferSize(int32_t size) {
1061   if (!connected_)
1062     return net::ERR_UNEXPECTED;
1063   data_->set_receive_buffer_size(size);
1064   return data_->set_receive_buffer_size_result();
1065 }
1066 
SetSendBufferSize(int32_t size)1067 int MockTCPClientSocket::SetSendBufferSize(int32_t size) {
1068   if (!connected_)
1069     return net::ERR_UNEXPECTED;
1070   data_->set_send_buffer_size(size);
1071   return data_->set_send_buffer_size_result();
1072 }
1073 
SetNoDelay(bool no_delay)1074 bool MockTCPClientSocket::SetNoDelay(bool no_delay) {
1075   if (!connected_)
1076     return false;
1077   data_->set_no_delay(no_delay);
1078   return data_->set_no_delay_result();
1079 }
1080 
SetKeepAlive(bool enable,int delay)1081 bool MockTCPClientSocket::SetKeepAlive(bool enable, int delay) {
1082   if (!connected_)
1083     return false;
1084   data_->set_keep_alive(enable, delay);
1085   return data_->set_keep_alive_result();
1086 }
1087 
GetConnectionAttempts(ConnectionAttempts * out) const1088 void MockTCPClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
1089   *out = connection_attempts_;
1090 }
1091 
ClearConnectionAttempts()1092 void MockTCPClientSocket::ClearConnectionAttempts() {
1093   connection_attempts_.clear();
1094 }
1095 
AddConnectionAttempts(const ConnectionAttempts & attempts)1096 void MockTCPClientSocket::AddConnectionAttempts(
1097     const ConnectionAttempts& attempts) {
1098   connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(),
1099                               attempts.end());
1100 }
1101 
SetBeforeConnectCallback(const BeforeConnectCallback & before_connect_callback)1102 void MockTCPClientSocket::SetBeforeConnectCallback(
1103     const BeforeConnectCallback& before_connect_callback) {
1104   DCHECK(!before_connect_callback_);
1105   DCHECK(!connected_);
1106 
1107   before_connect_callback_ = before_connect_callback;
1108 }
1109 
Connect(CompletionOnceCallback callback)1110 int MockTCPClientSocket::Connect(CompletionOnceCallback callback) {
1111   if (!data_)
1112     return ERR_UNEXPECTED;
1113 
1114   if (connected_)
1115     return OK;
1116 
1117   // Setting socket options fails if not connected, so need to set this before
1118   // calling |before_connect_callback_|.
1119   connected_ = true;
1120 
1121   if (before_connect_callback_) {
1122     int result = before_connect_callback_.Run();
1123     DCHECK_NE(result, ERR_IO_PENDING);
1124     if (result != net::OK) {
1125       connected_ = false;
1126       return result;
1127     }
1128   }
1129 
1130   peer_closed_connection_ = false;
1131 
1132   int result = data_->connect_data().result;
1133   IoMode mode = data_->connect_data().mode;
1134 
1135   if (result != OK && result != ERR_IO_PENDING) {
1136     IPEndPoint address;
1137     if (GetPeerAddress(&address) == OK)
1138       connection_attempts_.push_back(ConnectionAttempt(address, result));
1139   }
1140 
1141   if (mode == SYNCHRONOUS)
1142     return result;
1143 
1144   DCHECK(callback);
1145 
1146   if (result == ERR_IO_PENDING)
1147     pending_connect_callback_ = std::move(callback);
1148   else
1149     RunCallbackAsync(std::move(callback), result);
1150   return ERR_IO_PENDING;
1151 }
1152 
Disconnect()1153 void MockTCPClientSocket::Disconnect() {
1154   MockClientSocket::Disconnect();
1155   pending_connect_callback_.Reset();
1156   pending_read_callback_.Reset();
1157 }
1158 
IsConnected() const1159 bool MockTCPClientSocket::IsConnected() const {
1160   if (!data_)
1161     return false;
1162   return connected_ && !peer_closed_connection_;
1163 }
1164 
IsConnectedAndIdle() const1165 bool MockTCPClientSocket::IsConnectedAndIdle() const {
1166   if (!data_)
1167     return false;
1168   return IsConnected() && data_->IsIdle();
1169 }
1170 
GetPeerAddress(IPEndPoint * address) const1171 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1172   if (addresses_.empty())
1173     return MockClientSocket::GetPeerAddress(address);
1174 
1175   *address = addresses_[0];
1176   return OK;
1177 }
1178 
WasEverUsed() const1179 bool MockTCPClientSocket::WasEverUsed() const {
1180   return was_used_to_convey_data_;
1181 }
1182 
GetSSLInfo(SSLInfo * ssl_info)1183 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1184   return false;
1185 }
1186 
OnReadComplete(const MockRead & data)1187 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
1188   // If |data_| has been destroyed, safest to just do nothing.
1189   if (!data_)
1190     return;
1191 
1192   // There must be a read pending.
1193   DCHECK(pending_read_if_ready_callback_);
1194   // You can't complete a read with another ERR_IO_PENDING status code.
1195   DCHECK_NE(ERR_IO_PENDING, data.result);
1196   // Since we've been waiting for data, need_read_data_ should be true.
1197   DCHECK(need_read_data_);
1198 
1199   read_data_ = data;
1200   need_read_data_ = false;
1201 
1202   // The caller is simulating that this IO completes right now.  Don't
1203   // let CompleteRead() schedule a callback.
1204   read_data_.mode = SYNCHRONOUS;
1205   RunCallback(std::move(pending_read_if_ready_callback_),
1206               read_data_.result > 0 ? OK : read_data_.result);
1207 }
1208 
OnWriteComplete(int rv)1209 void MockTCPClientSocket::OnWriteComplete(int rv) {
1210   // If |data_| has been destroyed, safest to just do nothing.
1211   if (!data_)
1212     return;
1213 
1214   // There must be a read pending.
1215   DCHECK(!pending_write_callback_.is_null());
1216   RunCallback(std::move(pending_write_callback_), rv);
1217 }
1218 
OnConnectComplete(const MockConnect & data)1219 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
1220   // If |data_| has been destroyed, safest to just do nothing.
1221   if (!data_)
1222     return;
1223 
1224   RunCallback(std::move(pending_connect_callback_), data.result);
1225 }
1226 
OnDataProviderDestroyed()1227 void MockTCPClientSocket::OnDataProviderDestroyed() {
1228   data_ = nullptr;
1229 }
1230 
RetryRead(int rv)1231 void MockTCPClientSocket::RetryRead(int rv) {
1232   DCHECK(pending_read_callback_);
1233   DCHECK(pending_read_buf_.get());
1234   DCHECK_LT(0, pending_read_buf_len_);
1235 
1236   if (rv == OK) {
1237     rv = ReadIfReadyImpl(pending_read_buf_.get(), pending_read_buf_len_,
1238                          base::BindOnce(&MockTCPClientSocket::RetryRead,
1239                                         base::Unretained(this)));
1240     if (rv == ERR_IO_PENDING)
1241       return;
1242   }
1243   pending_read_buf_ = nullptr;
1244   pending_read_buf_len_ = 0;
1245   RunCallback(std::move(pending_read_callback_), rv);
1246 }
1247 
ReadIfReadyImpl(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1248 int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
1249                                          int buf_len,
1250                                          CompletionOnceCallback callback) {
1251   if (!connected_ || !data_)
1252     return ERR_UNEXPECTED;
1253 
1254   DCHECK(!pending_read_if_ready_callback_);
1255 
1256   if (need_read_data_) {
1257     read_data_ = data_->OnRead();
1258     if (read_data_.result == ERR_CONNECTION_CLOSED) {
1259       // This MockRead is just a marker to instruct us to set
1260       // peer_closed_connection_.
1261       peer_closed_connection_ = true;
1262     }
1263     if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
1264       // This MockRead is just a marker to instruct us to set
1265       // peer_closed_connection_.  Skip it and get the next one.
1266       read_data_ = data_->OnRead();
1267       peer_closed_connection_ = true;
1268     }
1269     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1270     // to complete the async IO manually later (via OnReadComplete).
1271     if (read_data_.result == ERR_IO_PENDING) {
1272       // We need to be using async IO in this case.
1273       DCHECK(!callback.is_null());
1274       pending_read_if_ready_callback_ = std::move(callback);
1275       return ERR_IO_PENDING;
1276     }
1277     need_read_data_ = false;
1278   }
1279 
1280   int result = read_data_.result;
1281   DCHECK_NE(ERR_IO_PENDING, result);
1282   if (read_data_.mode == ASYNC) {
1283     DCHECK(!callback.is_null());
1284     read_data_.mode = SYNCHRONOUS;
1285     pending_read_if_ready_callback_ = std::move(callback);
1286     // base::Unretained() is safe here because RunCallbackAsync will wrap it
1287     // with a callback associated with a weak ptr.
1288     RunCallbackAsync(
1289         base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
1290                        base::Unretained(this)),
1291         result);
1292     return ERR_IO_PENDING;
1293   }
1294 
1295   was_used_to_convey_data_ = true;
1296   if (read_data_.data) {
1297     if (read_data_.data_len - read_offset_ > 0) {
1298       result = std::min(buf_len, read_data_.data_len - read_offset_);
1299       memcpy(buf->data(), read_data_.data + read_offset_, result);
1300       read_offset_ += result;
1301       if (read_offset_ == read_data_.data_len) {
1302         need_read_data_ = true;
1303         read_offset_ = 0;
1304       }
1305     } else {
1306       result = 0;  // EOF
1307     }
1308   }
1309   return result;
1310 }
1311 
RunReadIfReadyCallback(int result)1312 void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
1313   // If ReadIfReady is already canceled, do nothing.
1314   if (!pending_read_if_ready_callback_)
1315     return;
1316   std::move(pending_read_if_ready_callback_).Run(result);
1317 }
1318 
MockProxyClientSocket(std::unique_ptr<StreamSocket> socket,HttpAuthController * auth_controller,ProxyClientSocketDataProvider * data)1319 MockProxyClientSocket::MockProxyClientSocket(
1320     std::unique_ptr<StreamSocket> socket,
1321     HttpAuthController* auth_controller,
1322     ProxyClientSocketDataProvider* data)
1323     : net_log_(socket->NetLog()),
1324       socket_(std::move(socket)),
1325       data_(data),
1326       auth_controller_(auth_controller) {
1327   DCHECK(data_);
1328 }
1329 
~MockProxyClientSocket()1330 MockProxyClientSocket::~MockProxyClientSocket() {
1331   Disconnect();
1332 }
1333 
GetConnectResponseInfo() const1334 const HttpResponseInfo* MockProxyClientSocket::GetConnectResponseInfo() const {
1335   return nullptr;
1336 }
1337 
1338 const scoped_refptr<HttpAuthController>&
GetAuthController() const1339 MockProxyClientSocket::GetAuthController() const {
1340   return auth_controller_;
1341 }
1342 
RestartWithAuth(CompletionOnceCallback callback)1343 int MockProxyClientSocket::RestartWithAuth(CompletionOnceCallback callback) {
1344   return net::ERR_NOT_IMPLEMENTED;
1345 }
IsUsingSpdy() const1346 bool MockProxyClientSocket::IsUsingSpdy() const {
1347   return false;
1348 }
1349 
GetProxyNegotiatedProtocol() const1350 NextProto MockProxyClientSocket::GetProxyNegotiatedProtocol() const {
1351   return kProtoUnknown;
1352 }
1353 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1354 int MockProxyClientSocket::Read(IOBuffer* buf,
1355                                 int buf_len,
1356                                 CompletionOnceCallback callback) {
1357   return socket_->Read(buf, buf_len, std::move(callback));
1358 }
1359 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1360 int MockProxyClientSocket::ReadIfReady(IOBuffer* buf,
1361                                        int buf_len,
1362                                        CompletionOnceCallback callback) {
1363   return socket_->ReadIfReady(buf, buf_len, std::move(callback));
1364 }
1365 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1366 int MockProxyClientSocket::Write(
1367     IOBuffer* buf,
1368     int buf_len,
1369     CompletionOnceCallback callback,
1370     const NetworkTrafficAnnotationTag& traffic_annotation) {
1371   return socket_->Write(buf, buf_len, std::move(callback), traffic_annotation);
1372 }
1373 
Connect(CompletionOnceCallback callback)1374 int MockProxyClientSocket::Connect(CompletionOnceCallback callback) {
1375   DCHECK(socket_->IsConnected());
1376   if (data_->connect.mode == ASYNC) {
1377     RunCallbackAsync(std::move(callback), data_->connect.result);
1378     return ERR_IO_PENDING;
1379   }
1380   return data_->connect.result;
1381 }
1382 
Disconnect()1383 void MockProxyClientSocket::Disconnect() {
1384   if (socket_)
1385     socket_->Disconnect();
1386 }
1387 
IsConnected() const1388 bool MockProxyClientSocket::IsConnected() const {
1389   return socket_->IsConnected();
1390 }
1391 
IsConnectedAndIdle() const1392 bool MockProxyClientSocket::IsConnectedAndIdle() const {
1393   return socket_->IsConnectedAndIdle();
1394 }
1395 
WasEverUsed() const1396 bool MockProxyClientSocket::WasEverUsed() const {
1397   return socket_->WasEverUsed();
1398 }
1399 
GetLocalAddress(IPEndPoint * address) const1400 int MockProxyClientSocket::GetLocalAddress(IPEndPoint* address) const {
1401   *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
1402   return OK;
1403 }
1404 
GetPeerAddress(IPEndPoint * address) const1405 int MockProxyClientSocket::GetPeerAddress(IPEndPoint* address) const {
1406   return socket_->GetPeerAddress(address);
1407 }
1408 
WasAlpnNegotiated() const1409 bool MockProxyClientSocket::WasAlpnNegotiated() const {
1410   return false;
1411 }
1412 
GetNegotiatedProtocol() const1413 NextProto MockProxyClientSocket::GetNegotiatedProtocol() const {
1414   NOTIMPLEMENTED();
1415   return kProtoUnknown;
1416 }
1417 
GetSSLInfo(SSLInfo * requested_ssl_info)1418 bool MockProxyClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
1419   NOTIMPLEMENTED();
1420   return false;
1421 }
1422 
ApplySocketTag(const SocketTag & tag)1423 void MockProxyClientSocket::ApplySocketTag(const SocketTag& tag) {
1424   return socket_->ApplySocketTag(tag);
1425 }
1426 
NetLog() const1427 const NetLogWithSource& MockProxyClientSocket::NetLog() const {
1428   return net_log_;
1429 }
1430 
GetConnectionAttempts(ConnectionAttempts * out) const1431 void MockProxyClientSocket::GetConnectionAttempts(
1432     ConnectionAttempts* out) const {
1433   NOTIMPLEMENTED();
1434   out->clear();
1435 }
1436 
GetTotalReceivedBytes() const1437 int64_t MockProxyClientSocket::GetTotalReceivedBytes() const {
1438   NOTIMPLEMENTED();
1439   return 0;
1440 }
1441 
SetReceiveBufferSize(int32_t size)1442 int MockProxyClientSocket::SetReceiveBufferSize(int32_t size) {
1443   return OK;
1444 }
1445 
SetSendBufferSize(int32_t size)1446 int MockProxyClientSocket::SetSendBufferSize(int32_t size) {
1447   return OK;
1448 }
1449 
OnReadComplete(const MockRead & data)1450 void MockProxyClientSocket::OnReadComplete(const MockRead& data) {
1451   NOTIMPLEMENTED();
1452 }
1453 
OnWriteComplete(int rv)1454 void MockProxyClientSocket::OnWriteComplete(int rv) {
1455   NOTIMPLEMENTED();
1456 }
1457 
OnConnectComplete(const MockConnect & data)1458 void MockProxyClientSocket::OnConnectComplete(const MockConnect& data) {
1459   NOTIMPLEMENTED();
1460 }
1461 
RunCallbackAsync(CompletionOnceCallback callback,int result)1462 void MockProxyClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1463                                              int result) {
1464   base::ThreadTaskRunnerHandle::Get()->PostTask(
1465       FROM_HERE,
1466       base::BindOnce(&MockProxyClientSocket::RunCallback,
1467                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1468 }
1469 
RunCallback(CompletionOnceCallback callback,int result)1470 void MockProxyClientSocket::RunCallback(CompletionOnceCallback callback,
1471                                         int result) {
1472   std::move(callback).Run(result);
1473 }
1474 
1475 // static
ConnectCallback(MockSSLClientSocket * ssl_client_socket,CompletionOnceCallback callback,int rv)1476 void MockSSLClientSocket::ConnectCallback(
1477     MockSSLClientSocket* ssl_client_socket,
1478     CompletionOnceCallback callback,
1479     int rv) {
1480   if (rv == OK)
1481     ssl_client_socket->connected_ = true;
1482   std::move(callback).Run(rv);
1483 }
1484 
MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLSocketDataProvider * data)1485 MockSSLClientSocket::MockSSLClientSocket(
1486     std::unique_ptr<StreamSocket> stream_socket,
1487     const HostPortPair& host_and_port,
1488     const SSLConfig& ssl_config,
1489     SSLSocketDataProvider* data)
1490     : net_log_(stream_socket->NetLog()),
1491       stream_socket_(std::move(stream_socket)),
1492       data_(data) {
1493   DCHECK(data_);
1494   peer_addr_ = data->connect.peer_addr;
1495 }
1496 
~MockSSLClientSocket()1497 MockSSLClientSocket::~MockSSLClientSocket() {
1498   Disconnect();
1499 }
1500 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1501 int MockSSLClientSocket::Read(IOBuffer* buf,
1502                               int buf_len,
1503                               CompletionOnceCallback callback) {
1504   return stream_socket_->Read(buf, buf_len, std::move(callback));
1505 }
1506 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1507 int MockSSLClientSocket::ReadIfReady(IOBuffer* buf,
1508                                      int buf_len,
1509                                      CompletionOnceCallback callback) {
1510   return stream_socket_->ReadIfReady(buf, buf_len, std::move(callback));
1511 }
1512 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1513 int MockSSLClientSocket::Write(
1514     IOBuffer* buf,
1515     int buf_len,
1516     CompletionOnceCallback callback,
1517     const NetworkTrafficAnnotationTag& traffic_annotation) {
1518   if (!data_->is_confirm_data_consumed)
1519     data_->write_called_before_confirm = true;
1520   return stream_socket_->Write(buf, buf_len, std::move(callback),
1521                                traffic_annotation);
1522 }
1523 
CancelReadIfReady()1524 int MockSSLClientSocket::CancelReadIfReady() {
1525   return stream_socket_->CancelReadIfReady();
1526 }
1527 
Connect(CompletionOnceCallback callback)1528 int MockSSLClientSocket::Connect(CompletionOnceCallback callback) {
1529   DCHECK(stream_socket_->IsConnected());
1530   data_->is_connect_data_consumed = true;
1531   if (data_->connect.result == OK)
1532     connected_ = true;
1533   RunClosureIfNonNull(std::move(data_->connect_callback));
1534   if (data_->connect.mode == ASYNC) {
1535     RunCallbackAsync(std::move(callback), data_->connect.result);
1536     return ERR_IO_PENDING;
1537   }
1538   return data_->connect.result;
1539 }
1540 
Disconnect()1541 void MockSSLClientSocket::Disconnect() {
1542   if (stream_socket_ != nullptr)
1543     stream_socket_->Disconnect();
1544 }
1545 
RunConfirmHandshakeCallback(CompletionOnceCallback callback,int result)1546 void MockSSLClientSocket::RunConfirmHandshakeCallback(
1547     CompletionOnceCallback callback,
1548     int result) {
1549   data_->is_confirm_data_consumed = true;
1550   std::move(callback).Run(result);
1551 }
1552 
ConfirmHandshake(CompletionOnceCallback callback)1553 int MockSSLClientSocket::ConfirmHandshake(CompletionOnceCallback callback) {
1554   DCHECK(stream_socket_->IsConnected());
1555   if (data_->is_confirm_data_consumed)
1556     return data_->confirm.result;
1557   RunClosureIfNonNull(std::move(data_->confirm_callback));
1558   if (data_->confirm.mode == ASYNC) {
1559     RunCallbackAsync(
1560         base::BindOnce(&MockSSLClientSocket::RunConfirmHandshakeCallback,
1561                        base::Unretained(this), std::move(callback)),
1562         data_->confirm.result);
1563     return ERR_IO_PENDING;
1564   }
1565   data_->is_confirm_data_consumed = true;
1566   return data_->confirm.result;
1567 }
1568 
IsConnected() const1569 bool MockSSLClientSocket::IsConnected() const {
1570   return stream_socket_->IsConnected();
1571 }
1572 
IsConnectedAndIdle() const1573 bool MockSSLClientSocket::IsConnectedAndIdle() const {
1574   return stream_socket_->IsConnectedAndIdle();
1575 }
1576 
WasEverUsed() const1577 bool MockSSLClientSocket::WasEverUsed() const {
1578   return stream_socket_->WasEverUsed();
1579 }
1580 
GetLocalAddress(IPEndPoint * address) const1581 int MockSSLClientSocket::GetLocalAddress(IPEndPoint* address) const {
1582   *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
1583   return OK;
1584 }
1585 
GetPeerAddress(IPEndPoint * address) const1586 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
1587   return stream_socket_->GetPeerAddress(address);
1588 }
1589 
WasAlpnNegotiated() const1590 bool MockSSLClientSocket::WasAlpnNegotiated() const {
1591   return data_->next_proto != kProtoUnknown;
1592 }
1593 
GetNegotiatedProtocol() const1594 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
1595   return data_->next_proto;
1596 }
1597 
GetSSLInfo(SSLInfo * requested_ssl_info)1598 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
1599   requested_ssl_info->Reset();
1600   *requested_ssl_info = data_->ssl_info;
1601   return true;
1602 }
1603 
ApplySocketTag(const SocketTag & tag)1604 void MockSSLClientSocket::ApplySocketTag(const SocketTag& tag) {
1605   return stream_socket_->ApplySocketTag(tag);
1606 }
1607 
NetLog() const1608 const NetLogWithSource& MockSSLClientSocket::NetLog() const {
1609   return net_log_;
1610 }
1611 
GetConnectionAttempts(ConnectionAttempts * out) const1612 void MockSSLClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
1613   out->clear();
1614 }
1615 
GetTotalReceivedBytes() const1616 int64_t MockSSLClientSocket::GetTotalReceivedBytes() const {
1617   NOTIMPLEMENTED();
1618   return 0;
1619 }
1620 
GetTotalReceivedBytes() const1621 int64_t MockClientSocket::GetTotalReceivedBytes() const {
1622   NOTIMPLEMENTED();
1623   return 0;
1624 }
1625 
SetReceiveBufferSize(int32_t size)1626 int MockSSLClientSocket::SetReceiveBufferSize(int32_t size) {
1627   return OK;
1628 }
1629 
SetSendBufferSize(int32_t size)1630 int MockSSLClientSocket::SetSendBufferSize(int32_t size) {
1631   return OK;
1632 }
1633 
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info) const1634 void MockSSLClientSocket::GetSSLCertRequestInfo(
1635     SSLCertRequestInfo* cert_request_info) const {
1636   DCHECK(cert_request_info);
1637   if (data_->cert_request_info) {
1638     cert_request_info->host_and_port =
1639         data_->cert_request_info->host_and_port;
1640     cert_request_info->is_proxy = data_->cert_request_info->is_proxy;
1641     cert_request_info->cert_authorities =
1642         data_->cert_request_info->cert_authorities;
1643     cert_request_info->cert_key_types =
1644         data_->cert_request_info->cert_key_types;
1645   } else {
1646     cert_request_info->Reset();
1647   }
1648 }
1649 
ExportKeyingMaterial(const base::StringPiece & label,bool has_context,const base::StringPiece & context,unsigned char * out,unsigned int outlen)1650 int MockSSLClientSocket::ExportKeyingMaterial(const base::StringPiece& label,
1651                                               bool has_context,
1652                                               const base::StringPiece& context,
1653                                               unsigned char* out,
1654                                               unsigned int outlen) {
1655   memset(out, 'A', outlen);
1656   return OK;
1657 }
1658 
RunCallbackAsync(CompletionOnceCallback callback,int result)1659 void MockSSLClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1660                                            int result) {
1661   base::ThreadTaskRunnerHandle::Get()->PostTask(
1662       FROM_HERE,
1663       base::BindOnce(&MockSSLClientSocket::RunCallback,
1664                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1665 }
1666 
RunCallback(CompletionOnceCallback callback,int result)1667 void MockSSLClientSocket::RunCallback(CompletionOnceCallback callback,
1668                                       int result) {
1669   std::move(callback).Run(result);
1670 }
1671 
OnReadComplete(const MockRead & data)1672 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1673   NOTIMPLEMENTED();
1674 }
1675 
OnWriteComplete(int rv)1676 void MockSSLClientSocket::OnWriteComplete(int rv) {
1677   NOTIMPLEMENTED();
1678 }
1679 
OnConnectComplete(const MockConnect & data)1680 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
1681   NOTIMPLEMENTED();
1682 }
1683 
MockUDPClientSocket(SocketDataProvider * data,net::NetLog * net_log)1684 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
1685                                          net::NetLog* net_log)
1686     : connected_(false),
1687       data_(data),
1688       read_offset_(0),
1689       read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
1690       need_read_data_(true),
1691       source_port_(123),
1692       network_(NetworkChangeNotifier::kInvalidNetworkHandle),
1693       pending_read_buf_(nullptr),
1694       pending_read_buf_len_(0),
1695       net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)) {
1696   if (data_) {
1697     data_->Initialize(this);
1698     peer_addr_ = data->connect_data().peer_addr;
1699   }
1700 }
1701 
~MockUDPClientSocket()1702 MockUDPClientSocket::~MockUDPClientSocket() {
1703   if (data_)
1704     data_->DetachSocket();
1705 }
1706 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1707 int MockUDPClientSocket::Read(IOBuffer* buf,
1708                               int buf_len,
1709                               CompletionOnceCallback callback) {
1710   DCHECK(callback);
1711 
1712   if (!connected_ || !data_)
1713     return ERR_UNEXPECTED;
1714   data_transferred_ = true;
1715 
1716   // If the buffer is already in use, a read is already in progress!
1717   DCHECK(!pending_read_buf_);
1718 
1719   // Store our async IO data.
1720   pending_read_buf_ = buf;
1721   pending_read_buf_len_ = buf_len;
1722   pending_read_callback_ = std::move(callback);
1723 
1724   if (need_read_data_) {
1725     read_data_ = data_->OnRead();
1726     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1727     // to complete the async IO manually later (via OnReadComplete).
1728     if (read_data_.result == ERR_IO_PENDING) {
1729       // We need to be using async IO in this case.
1730       DCHECK(!pending_read_callback_.is_null());
1731       return ERR_IO_PENDING;
1732     }
1733     need_read_data_ = false;
1734   }
1735 
1736   return CompleteRead();
1737 }
1738 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1739 int MockUDPClientSocket::Write(
1740     IOBuffer* buf,
1741     int buf_len,
1742     CompletionOnceCallback callback,
1743     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1744   DCHECK(buf);
1745   DCHECK_GT(buf_len, 0);
1746   DCHECK(callback);
1747 
1748   if (!connected_ || !data_)
1749     return ERR_UNEXPECTED;
1750   data_transferred_ = true;
1751 
1752   std::string data(buf->data(), buf_len);
1753   MockWriteResult write_result = data_->OnWrite(data);
1754 
1755   // ERR_IO_PENDING is a signal that the socket data will call back
1756   // asynchronously.
1757   if (write_result.result == ERR_IO_PENDING) {
1758     pending_write_callback_ = std::move(callback);
1759     return ERR_IO_PENDING;
1760   }
1761   if (write_result.mode == ASYNC) {
1762     RunCallbackAsync(std::move(callback), write_result.result);
1763     return ERR_IO_PENDING;
1764   }
1765   return write_result.result;
1766 }
1767 
WriteAsync(const char * buffer,size_t buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1768 int MockUDPClientSocket::WriteAsync(
1769     const char* buffer,
1770     size_t buf_len,
1771     CompletionOnceCallback callback,
1772     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1773   DCHECK(buffer);
1774   DCHECK_GT(buf_len, 0u);
1775   DCHECK(callback);
1776 
1777   if (!connected_ || !data_)
1778     return ERR_UNEXPECTED;
1779   data_transferred_ = true;
1780 
1781   std::string data(buffer, buf_len);
1782   MockWriteResult write_result = data_->OnWrite(data);
1783 
1784   // ERR_IO_PENDING is a signal that the socket data will call back
1785   // asynchronously.
1786   if (write_result.result == ERR_IO_PENDING) {
1787     pending_write_callback_ = std::move(callback);
1788     return ERR_IO_PENDING;
1789   }
1790   if (write_result.mode == ASYNC) {
1791     RunCallbackAsync(std::move(callback), write_result.result);
1792     return ERR_IO_PENDING;
1793   }
1794   return write_result.result;
1795 }
1796 
WriteAsync(DatagramBuffers buffers,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1797 int MockUDPClientSocket::WriteAsync(
1798     DatagramBuffers buffers,
1799     CompletionOnceCallback callback,
1800     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1801   DCHECK(!buffers.empty());
1802   DCHECK(callback);
1803 
1804   if (!connected_ || !data_)
1805     return ERR_UNEXPECTED;
1806 
1807   unwritten_buffers_ = std::move(buffers);
1808 
1809   int rv = 0;
1810   size_t buf_len = 0;
1811   do {
1812     auto& buf = unwritten_buffers_.front();
1813 
1814     buf_len = buf->length();
1815     std::string data(buf->data(), buf_len);
1816     MockWriteResult write_result = data_->OnWrite(data);
1817     rv = write_result.result;
1818 
1819     // ERR_IO_PENDING is a signal that the socket data will call back
1820     // asynchronously.
1821     if (write_result.result == ERR_IO_PENDING) {
1822       pending_write_callback_ = std::move(callback);
1823       return ERR_IO_PENDING;
1824     }
1825     if (write_result.mode == ASYNC) {
1826       RunCallbackAsync(std::move(callback), write_result.result);
1827       return ERR_IO_PENDING;
1828     }
1829 
1830     if (rv < 0) {
1831       return rv;
1832     }
1833 
1834     unwritten_buffers_.pop_front();
1835   } while (!unwritten_buffers_.empty());
1836 
1837   return buf_len;
1838 }
1839 
GetUnwrittenBuffers()1840 DatagramBuffers MockUDPClientSocket::GetUnwrittenBuffers() {
1841   return std::move(unwritten_buffers_);
1842 }
1843 
SetReceiveBufferSize(int32_t size)1844 int MockUDPClientSocket::SetReceiveBufferSize(int32_t size) {
1845   return OK;
1846 }
1847 
SetSendBufferSize(int32_t size)1848 int MockUDPClientSocket::SetSendBufferSize(int32_t size) {
1849   return OK;
1850 }
1851 
SetDoNotFragment()1852 int MockUDPClientSocket::SetDoNotFragment() {
1853   return OK;
1854 }
1855 
Close()1856 void MockUDPClientSocket::Close() {
1857   connected_ = false;
1858 }
1859 
GetPeerAddress(IPEndPoint * address) const1860 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1861   if (!data_)
1862     return ERR_UNEXPECTED;
1863 
1864   *address = peer_addr_;
1865   return OK;
1866 }
1867 
GetLocalAddress(IPEndPoint * address) const1868 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1869   *address = IPEndPoint(IPAddress(192, 0, 2, 33), source_port_);
1870   return OK;
1871 }
1872 
UseNonBlockingIO()1873 void MockUDPClientSocket::UseNonBlockingIO() {}
1874 
SetWriteAsyncEnabled(bool enabled)1875 void MockUDPClientSocket::SetWriteAsyncEnabled(bool enabled) {}
WriteAsyncEnabled()1876 bool MockUDPClientSocket::WriteAsyncEnabled() {
1877   return false;
1878 }
SetMaxPacketSize(size_t max_packet_size)1879 void MockUDPClientSocket::SetMaxPacketSize(size_t max_packet_size) {}
SetWriteMultiCoreEnabled(bool enabled)1880 void MockUDPClientSocket::SetWriteMultiCoreEnabled(bool enabled) {}
SetSendmmsgEnabled(bool enabled)1881 void MockUDPClientSocket::SetSendmmsgEnabled(bool enabled) {}
SetWriteBatchingActive(bool active)1882 void MockUDPClientSocket::SetWriteBatchingActive(bool active) {}
SetMulticastInterface(uint32_t interface_index)1883 int MockUDPClientSocket::SetMulticastInterface(uint32_t interface_index) {
1884   return OK;
1885 }
1886 
NetLog() const1887 const NetLogWithSource& MockUDPClientSocket::NetLog() const {
1888   return net_log_;
1889 }
1890 
Connect(const IPEndPoint & address)1891 int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1892   if (!data_)
1893     return ERR_UNEXPECTED;
1894   connected_ = true;
1895   peer_addr_ = address;
1896   return data_->connect_data().result;
1897 }
1898 
ConnectUsingNetwork(NetworkChangeNotifier::NetworkHandle network,const IPEndPoint & address)1899 int MockUDPClientSocket::ConnectUsingNetwork(
1900     NetworkChangeNotifier::NetworkHandle network,
1901     const IPEndPoint& address) {
1902   DCHECK(!connected_);
1903   if (!data_)
1904     return ERR_UNEXPECTED;
1905   network_ = network;
1906   connected_ = true;
1907   peer_addr_ = address;
1908   return data_->connect_data().result;
1909 }
1910 
ConnectUsingDefaultNetwork(const IPEndPoint & address)1911 int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) {
1912   DCHECK(!connected_);
1913   if (!data_)
1914     return ERR_UNEXPECTED;
1915   network_ = kDefaultNetworkForTests;
1916   connected_ = true;
1917   peer_addr_ = address;
1918   return data_->connect_data().result;
1919 }
1920 
GetBoundNetwork() const1921 NetworkChangeNotifier::NetworkHandle MockUDPClientSocket::GetBoundNetwork()
1922     const {
1923   return network_;
1924 }
1925 
ApplySocketTag(const SocketTag & tag)1926 void MockUDPClientSocket::ApplySocketTag(const SocketTag& tag) {
1927   tagged_before_data_transferred_ &= !data_transferred_ || tag == tag_;
1928   tag_ = tag;
1929 }
1930 
OnReadComplete(const MockRead & data)1931 void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1932   if (!data_)
1933     return;
1934 
1935   // There must be a read pending.
1936   DCHECK(pending_read_buf_.get());
1937   DCHECK(pending_read_callback_);
1938   // You can't complete a read with another ERR_IO_PENDING status code.
1939   DCHECK_NE(ERR_IO_PENDING, data.result);
1940   // Since we've been waiting for data, need_read_data_ should be true.
1941   DCHECK(need_read_data_);
1942 
1943   read_data_ = data;
1944   need_read_data_ = false;
1945 
1946   // The caller is simulating that this IO completes right now.  Don't
1947   // let CompleteRead() schedule a callback.
1948   read_data_.mode = SYNCHRONOUS;
1949 
1950   CompletionOnceCallback callback = std::move(pending_read_callback_);
1951   int rv = CompleteRead();
1952   RunCallback(std::move(callback), rv);
1953 }
1954 
OnWriteComplete(int rv)1955 void MockUDPClientSocket::OnWriteComplete(int rv) {
1956   if (!data_)
1957     return;
1958 
1959   // There must be a read pending.
1960   DCHECK(!pending_write_callback_.is_null());
1961   RunCallback(std::move(pending_write_callback_), rv);
1962 }
1963 
OnConnectComplete(const MockConnect & data)1964 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1965   NOTIMPLEMENTED();
1966 }
1967 
OnDataProviderDestroyed()1968 void MockUDPClientSocket::OnDataProviderDestroyed() {
1969   data_ = nullptr;
1970 }
1971 
CompleteRead()1972 int MockUDPClientSocket::CompleteRead() {
1973   DCHECK(pending_read_buf_.get());
1974   DCHECK(pending_read_buf_len_ > 0);
1975 
1976   // Save the pending async IO data and reset our |pending_| state.
1977   scoped_refptr<IOBuffer> buf = pending_read_buf_;
1978   int buf_len = pending_read_buf_len_;
1979   CompletionOnceCallback callback = std::move(pending_read_callback_);
1980   pending_read_buf_ = nullptr;
1981   pending_read_buf_len_ = 0;
1982 
1983   int result = read_data_.result;
1984   DCHECK(result != ERR_IO_PENDING);
1985 
1986   if (read_data_.data) {
1987     if (read_data_.data_len - read_offset_ > 0) {
1988       result = std::min(buf_len, read_data_.data_len - read_offset_);
1989       memcpy(buf->data(), read_data_.data + read_offset_, result);
1990       read_offset_ += result;
1991       if (read_offset_ == read_data_.data_len) {
1992         need_read_data_ = true;
1993         read_offset_ = 0;
1994       }
1995     } else {
1996       result = 0;  // EOF
1997     }
1998   }
1999 
2000   if (read_data_.mode == ASYNC) {
2001     DCHECK(!callback.is_null());
2002     RunCallbackAsync(std::move(callback), result);
2003     return ERR_IO_PENDING;
2004   }
2005   return result;
2006 }
2007 
RunCallbackAsync(CompletionOnceCallback callback,int result)2008 void MockUDPClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
2009                                            int result) {
2010   base::ThreadTaskRunnerHandle::Get()->PostTask(
2011       FROM_HERE,
2012       base::BindOnce(&MockUDPClientSocket::RunCallback,
2013                      weak_factory_.GetWeakPtr(), std::move(callback), result));
2014 }
2015 
RunCallback(CompletionOnceCallback callback,int result)2016 void MockUDPClientSocket::RunCallback(CompletionOnceCallback callback,
2017                                       int result) {
2018   std::move(callback).Run(result);
2019 }
2020 
TestSocketRequest(std::vector<TestSocketRequest * > * request_order,size_t * completion_count)2021 TestSocketRequest::TestSocketRequest(
2022     std::vector<TestSocketRequest*>* request_order,
2023     size_t* completion_count)
2024     : request_order_(request_order), completion_count_(completion_count) {
2025   DCHECK(request_order);
2026   DCHECK(completion_count);
2027 }
2028 
2029 TestSocketRequest::~TestSocketRequest() = default;
2030 
OnComplete(int result)2031 void TestSocketRequest::OnComplete(int result) {
2032   SetResult(result);
2033   (*completion_count_)++;
2034   request_order_->push_back(this);
2035 }
2036 
2037 // static
2038 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
2039 
2040 // static
2041 const int ClientSocketPoolTest::kRequestNotFound = -2;
2042 
ClientSocketPoolTest()2043 ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {}
2044 ClientSocketPoolTest::~ClientSocketPoolTest() = default;
2045 
GetOrderOfRequest(size_t index) const2046 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
2047   index--;
2048   if (index >= requests_.size())
2049     return kIndexOutOfBounds;
2050 
2051   for (size_t i = 0; i < request_order_.size(); i++)
2052     if (requests_[index].get() == request_order_[i])
2053       return i + 1;
2054 
2055   return kRequestNotFound;
2056 }
2057 
ReleaseOneConnection(KeepAlive keep_alive)2058 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
2059   for (std::unique_ptr<TestSocketRequest>& it : requests_) {
2060     if (it->handle()->is_initialized()) {
2061       if (keep_alive == NO_KEEP_ALIVE)
2062         it->handle()->socket()->Disconnect();
2063       it->handle()->Reset();
2064       base::RunLoop().RunUntilIdle();
2065       return true;
2066     }
2067   }
2068   return false;
2069 }
2070 
ReleaseAllConnections(KeepAlive keep_alive)2071 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
2072   bool released_one;
2073   do {
2074     released_one = ReleaseOneConnection(keep_alive);
2075   } while (released_one);
2076 }
2077 
MockConnectJob(std::unique_ptr<StreamSocket> socket,ClientSocketHandle * handle,const SocketTag & socket_tag,CompletionOnceCallback callback,RequestPriority priority)2078 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
2079     std::unique_ptr<StreamSocket> socket,
2080     ClientSocketHandle* handle,
2081     const SocketTag& socket_tag,
2082     CompletionOnceCallback callback,
2083     RequestPriority priority)
2084     : socket_(std::move(socket)),
2085       handle_(handle),
2086       socket_tag_(socket_tag),
2087       user_callback_(std::move(callback)),
2088       priority_(priority) {}
2089 
2090 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() = default;
2091 
Connect()2092 int MockTransportClientSocketPool::MockConnectJob::Connect() {
2093   socket_->ApplySocketTag(socket_tag_);
2094   int rv = socket_->Connect(
2095       base::BindOnce(&MockConnectJob::OnConnect, base::Unretained(this)));
2096   if (rv != ERR_IO_PENDING) {
2097     user_callback_.Reset();
2098     OnConnect(rv);
2099   }
2100   return rv;
2101 }
2102 
CancelHandle(const ClientSocketHandle * handle)2103 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
2104     const ClientSocketHandle* handle) {
2105   if (handle != handle_)
2106     return false;
2107   socket_.reset();
2108   handle_ = nullptr;
2109   user_callback_.Reset();
2110   return true;
2111 }
2112 
OnConnect(int rv)2113 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
2114   if (!socket_.get())
2115     return;
2116   if (rv == OK) {
2117     handle_->SetSocket(std::move(socket_));
2118 
2119     // Needed for socket pool tests that layer other sockets on top of mock
2120     // sockets.
2121     LoadTimingInfo::ConnectTiming connect_timing;
2122     base::TimeTicks now = base::TimeTicks::Now();
2123     connect_timing.dns_start = now;
2124     connect_timing.dns_end = now;
2125     connect_timing.connect_start = now;
2126     connect_timing.connect_end = now;
2127     handle_->set_connect_timing(connect_timing);
2128   } else {
2129     socket_.reset();
2130 
2131     // Needed to test copying of ConnectionAttempts in SSL ConnectJob.
2132     ConnectionAttempts attempts;
2133     attempts.push_back(ConnectionAttempt(IPEndPoint(), rv));
2134     handle_->set_connection_attempts(attempts);
2135   }
2136 
2137   handle_ = nullptr;
2138 
2139   if (!user_callback_.is_null()) {
2140     std::move(user_callback_).Run(rv);
2141   }
2142 }
2143 
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,const CommonConnectJobParams * common_connect_job_params)2144 MockTransportClientSocketPool::MockTransportClientSocketPool(
2145     int max_sockets,
2146     int max_sockets_per_group,
2147     const CommonConnectJobParams* common_connect_job_params)
2148     : TransportClientSocketPool(
2149           max_sockets,
2150           max_sockets_per_group,
2151           base::TimeDelta::FromSeconds(10) /* unused_idle_socket_timeout */,
2152           ProxyServer::Direct(),
2153           false /* is_for_websockets */,
2154           common_connect_job_params),
2155       client_socket_factory_(common_connect_job_params->client_socket_factory),
2156       last_request_priority_(DEFAULT_PRIORITY),
2157       release_count_(0),
2158       cancel_count_(0) {}
2159 
2160 MockTransportClientSocketPool::~MockTransportClientSocketPool() = default;
2161 
RequestSocket(const ClientSocketPool::GroupId & group_id,scoped_refptr<ClientSocketPool::SocketParams> socket_params,const base::Optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,const SocketTag & socket_tag,RespectLimits respect_limits,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & on_auth_callback,const NetLogWithSource & net_log)2162 int MockTransportClientSocketPool::RequestSocket(
2163     const ClientSocketPool::GroupId& group_id,
2164     scoped_refptr<ClientSocketPool::SocketParams> socket_params,
2165     const base::Optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
2166     RequestPriority priority,
2167     const SocketTag& socket_tag,
2168     RespectLimits respect_limits,
2169     ClientSocketHandle* handle,
2170     CompletionOnceCallback callback,
2171     const ProxyAuthCallback& on_auth_callback,
2172     const NetLogWithSource& net_log) {
2173   last_request_priority_ = priority;
2174   std::unique_ptr<StreamSocket> socket =
2175       client_socket_factory_->CreateTransportClientSocket(
2176           AddressList(), nullptr, net_log.net_log(), NetLogSource());
2177   MockConnectJob* job = new MockConnectJob(
2178       std::move(socket), handle, socket_tag, std::move(callback), priority);
2179   job_list_.push_back(base::WrapUnique(job));
2180   handle->set_group_generation(1);
2181   return job->Connect();
2182 }
2183 
SetPriority(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,RequestPriority priority)2184 void MockTransportClientSocketPool::SetPriority(
2185     const ClientSocketPool::GroupId& group_id,
2186     ClientSocketHandle* handle,
2187     RequestPriority priority) {
2188   for (auto& job : job_list_) {
2189     if (job->handle() == handle) {
2190       job->set_priority(priority);
2191       return;
2192     }
2193   }
2194   NOTREACHED();
2195 }
2196 
CancelRequest(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)2197 void MockTransportClientSocketPool::CancelRequest(
2198     const ClientSocketPool::GroupId& group_id,
2199     ClientSocketHandle* handle,
2200     bool cancel_connect_job) {
2201   for (std::unique_ptr<MockConnectJob>& it : job_list_) {
2202     if (it->CancelHandle(handle)) {
2203       cancel_count_++;
2204       break;
2205     }
2206   }
2207 }
2208 
ReleaseSocket(const ClientSocketPool::GroupId & group_id,std::unique_ptr<StreamSocket> socket,int64_t generation)2209 void MockTransportClientSocketPool::ReleaseSocket(
2210     const ClientSocketPool::GroupId& group_id,
2211     std::unique_ptr<StreamSocket> socket,
2212     int64_t generation) {
2213   EXPECT_EQ(1, generation);
2214   release_count_++;
2215 }
2216 
WrappedStreamSocket(std::unique_ptr<StreamSocket> transport)2217 WrappedStreamSocket::WrappedStreamSocket(
2218     std::unique_ptr<StreamSocket> transport)
2219     : transport_(std::move(transport)) {}
~WrappedStreamSocket()2220 WrappedStreamSocket::~WrappedStreamSocket() {}
2221 
Bind(const net::IPEndPoint & local_addr)2222 int WrappedStreamSocket::Bind(const net::IPEndPoint& local_addr) {
2223   NOTREACHED();
2224   return ERR_FAILED;
2225 }
2226 
Connect(CompletionOnceCallback callback)2227 int WrappedStreamSocket::Connect(CompletionOnceCallback callback) {
2228   return transport_->Connect(std::move(callback));
2229 }
2230 
Disconnect()2231 void WrappedStreamSocket::Disconnect() {
2232   transport_->Disconnect();
2233 }
2234 
IsConnected() const2235 bool WrappedStreamSocket::IsConnected() const {
2236   return transport_->IsConnected();
2237 }
2238 
IsConnectedAndIdle() const2239 bool WrappedStreamSocket::IsConnectedAndIdle() const {
2240   return transport_->IsConnectedAndIdle();
2241 }
2242 
GetPeerAddress(IPEndPoint * address) const2243 int WrappedStreamSocket::GetPeerAddress(IPEndPoint* address) const {
2244   return transport_->GetPeerAddress(address);
2245 }
2246 
GetLocalAddress(IPEndPoint * address) const2247 int WrappedStreamSocket::GetLocalAddress(IPEndPoint* address) const {
2248   return transport_->GetLocalAddress(address);
2249 }
2250 
NetLog() const2251 const NetLogWithSource& WrappedStreamSocket::NetLog() const {
2252   return transport_->NetLog();
2253 }
2254 
WasEverUsed() const2255 bool WrappedStreamSocket::WasEverUsed() const {
2256   return transport_->WasEverUsed();
2257 }
2258 
WasAlpnNegotiated() const2259 bool WrappedStreamSocket::WasAlpnNegotiated() const {
2260   return transport_->WasAlpnNegotiated();
2261 }
2262 
GetNegotiatedProtocol() const2263 NextProto WrappedStreamSocket::GetNegotiatedProtocol() const {
2264   return transport_->GetNegotiatedProtocol();
2265 }
2266 
GetSSLInfo(SSLInfo * ssl_info)2267 bool WrappedStreamSocket::GetSSLInfo(SSLInfo* ssl_info) {
2268   return transport_->GetSSLInfo(ssl_info);
2269 }
2270 
GetConnectionAttempts(ConnectionAttempts * out) const2271 void WrappedStreamSocket::GetConnectionAttempts(ConnectionAttempts* out) const {
2272   transport_->GetConnectionAttempts(out);
2273 }
2274 
ClearConnectionAttempts()2275 void WrappedStreamSocket::ClearConnectionAttempts() {
2276   transport_->ClearConnectionAttempts();
2277 }
2278 
AddConnectionAttempts(const ConnectionAttempts & attempts)2279 void WrappedStreamSocket::AddConnectionAttempts(
2280     const ConnectionAttempts& attempts) {
2281   transport_->AddConnectionAttempts(attempts);
2282 }
2283 
GetTotalReceivedBytes() const2284 int64_t WrappedStreamSocket::GetTotalReceivedBytes() const {
2285   return transport_->GetTotalReceivedBytes();
2286 }
2287 
ApplySocketTag(const SocketTag & tag)2288 void WrappedStreamSocket::ApplySocketTag(const SocketTag& tag) {
2289   transport_->ApplySocketTag(tag);
2290 }
2291 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2292 int WrappedStreamSocket::Read(IOBuffer* buf,
2293                               int buf_len,
2294                               CompletionOnceCallback callback) {
2295   return transport_->Read(buf, buf_len, std::move(callback));
2296 }
2297 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2298 int WrappedStreamSocket::ReadIfReady(IOBuffer* buf,
2299                                      int buf_len,
2300                                      CompletionOnceCallback callback) {
2301   return transport_->ReadIfReady(buf, buf_len, std::move((callback)));
2302 }
2303 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)2304 int WrappedStreamSocket::Write(
2305     IOBuffer* buf,
2306     int buf_len,
2307     CompletionOnceCallback callback,
2308     const NetworkTrafficAnnotationTag& traffic_annotation) {
2309   return transport_->Write(buf, buf_len, std::move(callback),
2310                            TRAFFIC_ANNOTATION_FOR_TESTS);
2311 }
2312 
SetReceiveBufferSize(int32_t size)2313 int WrappedStreamSocket::SetReceiveBufferSize(int32_t size) {
2314   return transport_->SetReceiveBufferSize(size);
2315 }
2316 
SetSendBufferSize(int32_t size)2317 int WrappedStreamSocket::SetSendBufferSize(int32_t size) {
2318   return transport_->SetSendBufferSize(size);
2319 }
2320 
Connect(CompletionOnceCallback callback)2321 int MockTaggingStreamSocket::Connect(CompletionOnceCallback callback) {
2322   connected_ = true;
2323   return WrappedStreamSocket::Connect(std::move(callback));
2324 }
2325 
ApplySocketTag(const SocketTag & tag)2326 void MockTaggingStreamSocket::ApplySocketTag(const SocketTag& tag) {
2327   tagged_before_connected_ &= !connected_ || tag == tag_;
2328   tag_ = tag;
2329   transport_->ApplySocketTag(tag);
2330 }
2331 
2332 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetLog * net_log,const NetLogSource & source)2333 MockTaggingClientSocketFactory::CreateTransportClientSocket(
2334     const AddressList& addresses,
2335     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
2336     NetLog* net_log,
2337     const NetLogSource& source) {
2338   std::unique_ptr<MockTaggingStreamSocket> socket(new MockTaggingStreamSocket(
2339       MockClientSocketFactory::CreateTransportClientSocket(
2340           addresses, std::move(socket_performance_watcher), net_log, source)));
2341   tcp_socket_ = socket.get();
2342   return std::move(socket);
2343 }
2344 
2345 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)2346 MockTaggingClientSocketFactory::CreateDatagramClientSocket(
2347     DatagramSocket::BindType bind_type,
2348     NetLog* net_log,
2349     const NetLogSource& source) {
2350   std::unique_ptr<DatagramClientSocket> socket(
2351       MockClientSocketFactory::CreateDatagramClientSocket(bind_type, net_log,
2352                                                           source));
2353   udp_socket_ = static_cast<MockUDPClientSocket*>(socket.get());
2354   return socket;
2355 }
2356 
2357 const char kSOCKS4TestHost[] = "127.0.0.1";
2358 const int kSOCKS4TestPort = 80;
2359 
2360 const char kSOCKS4OkRequestLocalHostPort80[] = {0x04, 0x01, 0x00, 0x50, 127,
2361                                                 0,    0,    1,    0};
2362 const int kSOCKS4OkRequestLocalHostPort80Length =
2363     base::size(kSOCKS4OkRequestLocalHostPort80);
2364 
2365 const char kSOCKS4OkReply[] = {0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0};
2366 const int kSOCKS4OkReplyLength = base::size(kSOCKS4OkReply);
2367 
2368 const char kSOCKS5TestHost[] = "host";
2369 const int kSOCKS5TestPort = 80;
2370 
2371 const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
2372 const int kSOCKS5GreetRequestLength = base::size(kSOCKS5GreetRequest);
2373 
2374 const char kSOCKS5GreetResponse[] = { 0x05, 0x00 };
2375 const int kSOCKS5GreetResponseLength = base::size(kSOCKS5GreetResponse);
2376 
2377 const char kSOCKS5OkRequest[] =
2378     { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 };
2379 const int kSOCKS5OkRequestLength = base::size(kSOCKS5OkRequest);
2380 
2381 const char kSOCKS5OkResponse[] =
2382     { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 };
2383 const int kSOCKS5OkResponseLength = base::size(kSOCKS5OkResponse);
2384 
CountReadBytes(base::span<const MockRead> reads)2385 int64_t CountReadBytes(base::span<const MockRead> reads) {
2386   int64_t total = 0;
2387   for (const MockRead& read : reads)
2388     total += read.data_len;
2389   return total;
2390 }
2391 
CountWriteBytes(base::span<const MockWrite> writes)2392 int64_t CountWriteBytes(base::span<const MockWrite> writes) {
2393   int64_t total = 0;
2394   for (const MockWrite& write : writes)
2395     total += write.data_len;
2396   return total;
2397 }
2398 
2399 #if defined(OS_ANDROID)
CanGetTaggedBytes()2400 bool CanGetTaggedBytes() {
2401   // In Android P, /proc/net/xt_qtaguid/stats is no longer guaranteed to be
2402   // present, and has been replaced with eBPF Traffic Monitoring in netd. See:
2403   // https://source.android.com/devices/tech/datausage/ebpf-traffic-monitor
2404   //
2405   // To read traffic statistics from netd, apps should use the API
2406   // NetworkStatsManager.queryDetailsForUidTag(). But this API does not provide
2407   // statistics for local traffic, only mobile and WiFi traffic, so it would not
2408   // work in tests that spin up a local server. So for now, GetTaggedBytes is
2409   // only supported on Android releases older than P.
2410   return base::android::BuildInfo::GetInstance()->sdk_int() <
2411          base::android::SDK_VERSION_P;
2412 }
2413 
GetTaggedBytes(int32_t expected_tag)2414 uint64_t GetTaggedBytes(int32_t expected_tag) {
2415   EXPECT_TRUE(CanGetTaggedBytes());
2416 
2417   // To determine how many bytes the system saw with a particular tag read
2418   // the /proc/net/xt_qtaguid/stats file which contains the kernel's
2419   // dump of all the UIDs and their tags sent and received bytes.
2420   uint64_t bytes = 0;
2421   std::string contents;
2422   EXPECT_TRUE(base::ReadFileToString(
2423       base::FilePath::FromUTF8Unsafe("/proc/net/xt_qtaguid/stats"), &contents));
2424   for (size_t i = contents.find('\n');  // Skip first line which is headers.
2425        i != std::string::npos && i < contents.length();) {
2426     uint64_t tag, rx_bytes;
2427     uid_t uid;
2428     int n;
2429     // Parse out the numbers we care about. For reference here's the column
2430     // headers:
2431     // idx iface acct_tag_hex uid_tag_int cnt_set rx_bytes rx_packets tx_bytes
2432     // tx_packets rx_tcp_bytes rx_tcp_packets rx_udp_bytes rx_udp_packets
2433     // rx_other_bytes rx_other_packets tx_tcp_bytes tx_tcp_packets tx_udp_bytes
2434     // tx_udp_packets tx_other_bytes tx_other_packets
2435     EXPECT_EQ(sscanf(contents.c_str() + i,
2436                      "%*d %*s 0x%" SCNx64 " %d %*d %" SCNu64
2437                      " %*d %*d %*d %*d %*d %*d %*d %*d "
2438                      "%*d %*d %*d %*d %*d %*d %*d%n",
2439                      &tag, &uid, &rx_bytes, &n),
2440               3);
2441     // If this line matches our UID and |expected_tag| then add it to the total.
2442     if (uid == getuid() && (int32_t)(tag >> 32) == expected_tag) {
2443       bytes += rx_bytes;
2444     }
2445     // Move |i| to the next line.
2446     i += n + 1;
2447   }
2448   return bytes;
2449 }
2450 #endif
2451 
2452 }  // namespace net
2453