1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/third_party/quiche/src/quic/test_tools/quic_test_client.h"
6 
7 #include <memory>
8 #include <utility>
9 #include <vector>
10 
11 #include "absl/strings/match.h"
12 #include "absl/strings/string_view.h"
13 #include "third_party/boringssl/src/include/openssl/x509.h"
14 #include "net/third_party/quiche/src/quic/core/crypto/proof_verifier.h"
15 #include "net/third_party/quiche/src/quic/core/http/quic_spdy_client_stream.h"
16 #include "net/third_party/quiche/src/quic/core/http/spdy_utils.h"
17 #include "net/third_party/quiche/src/quic/core/quic_epoll_connection_helper.h"
18 #include "net/third_party/quiche/src/quic/core/quic_packet_writer_wrapper.h"
19 #include "net/third_party/quiche/src/quic/core/quic_server_id.h"
20 #include "net/third_party/quiche/src/quic/core/quic_utils.h"
21 #include "net/third_party/quiche/src/quic/platform/api/quic_flags.h"
22 #include "net/third_party/quiche/src/quic/platform/api/quic_logging.h"
23 #include "net/third_party/quiche/src/quic/platform/api/quic_ptr_util.h"
24 #include "net/third_party/quiche/src/quic/platform/api/quic_stack_trace.h"
25 #include "net/third_party/quiche/src/quic/test_tools/crypto_test_utils.h"
26 #include "net/third_party/quiche/src/quic/test_tools/quic_client_peer.h"
27 #include "net/third_party/quiche/src/quic/test_tools/quic_connection_peer.h"
28 #include "net/third_party/quiche/src/quic/test_tools/quic_spdy_session_peer.h"
29 #include "net/third_party/quiche/src/quic/test_tools/quic_spdy_stream_peer.h"
30 #include "net/third_party/quiche/src/quic/test_tools/quic_test_utils.h"
31 #include "net/third_party/quiche/src/quic/tools/quic_url.h"
32 #include "net/third_party/quiche/src/common/platform/api/quiche_text_utils.h"
33 
34 namespace quic {
35 namespace test {
36 namespace {
37 
38 // RecordingProofVerifier accepts any certificate chain and records the common
39 // name of the leaf and then delegates the actual verification to an actual
40 // verifier. If no optional verifier is provided, then VerifyProof will return
41 // success.
42 class RecordingProofVerifier : public ProofVerifier {
43  public:
RecordingProofVerifier(std::unique_ptr<ProofVerifier> verifier)44   explicit RecordingProofVerifier(std::unique_ptr<ProofVerifier> verifier)
45       : verifier_(std::move(verifier)) {}
46 
47   // ProofVerifier interface.
VerifyProof(const std::string & hostname,const uint16_t port,const std::string & server_config,QuicTransportVersion transport_version,absl::string_view chlo_hash,const std::vector<std::string> & certs,const std::string & cert_sct,const std::string & signature,const ProofVerifyContext * context,std::string * error_details,std::unique_ptr<ProofVerifyDetails> * details,std::unique_ptr<ProofVerifierCallback> callback)48   QuicAsyncStatus VerifyProof(
49       const std::string& hostname,
50       const uint16_t port,
51       const std::string& server_config,
52       QuicTransportVersion transport_version,
53       absl::string_view chlo_hash,
54       const std::vector<std::string>& certs,
55       const std::string& cert_sct,
56       const std::string& signature,
57       const ProofVerifyContext* context,
58       std::string* error_details,
59       std::unique_ptr<ProofVerifyDetails>* details,
60       std::unique_ptr<ProofVerifierCallback> callback) override {
61     QuicAsyncStatus process_certs_result = ProcessCerts(certs, cert_sct);
62     if (process_certs_result != QUIC_SUCCESS) {
63       return process_certs_result;
64     }
65 
66     if (!verifier_) {
67       return QUIC_SUCCESS;
68     }
69 
70     return verifier_->VerifyProof(hostname, port, server_config,
71                                   transport_version, chlo_hash, certs, cert_sct,
72                                   signature, context, error_details, details,
73                                   std::move(callback));
74   }
75 
VerifyCertChain(const std::string &,const uint16_t,const std::vector<std::string> & certs,const std::string &,const std::string & cert_sct,const ProofVerifyContext *,std::string *,std::unique_ptr<ProofVerifyDetails> *,uint8_t *,std::unique_ptr<ProofVerifierCallback>)76   QuicAsyncStatus VerifyCertChain(
77       const std::string& /*hostname*/,
78       const uint16_t /*port*/,
79       const std::vector<std::string>& certs,
80       const std::string& /*ocsp_response*/,
81       const std::string& cert_sct,
82       const ProofVerifyContext* /*context*/,
83       std::string* /*error_details*/,
84       std::unique_ptr<ProofVerifyDetails>* /*details*/,
85       uint8_t* /*out_alert*/,
86       std::unique_ptr<ProofVerifierCallback> /*callback*/) override {
87     return ProcessCerts(certs, cert_sct);
88   }
89 
CreateDefaultContext()90   std::unique_ptr<ProofVerifyContext> CreateDefaultContext() override {
91     return verifier_ != nullptr ? verifier_->CreateDefaultContext() : nullptr;
92   }
93 
common_name() const94   const std::string& common_name() const { return common_name_; }
95 
cert_sct() const96   const std::string& cert_sct() const { return cert_sct_; }
97 
98  private:
ProcessCerts(const std::vector<std::string> & certs,const std::string & cert_sct)99   QuicAsyncStatus ProcessCerts(const std::vector<std::string>& certs,
100                                const std::string& cert_sct) {
101     common_name_.clear();
102     if (certs.empty()) {
103       return QUIC_FAILURE;
104     }
105 
106     const uint8_t* data;
107     data = reinterpret_cast<const uint8_t*>(certs[0].data());
108     bssl::UniquePtr<X509> cert(d2i_X509(nullptr, &data, certs[0].size()));
109     if (!cert.get()) {
110       return QUIC_FAILURE;
111     }
112 
113     static const unsigned kMaxCommonNameLength = 256;
114     char buf[kMaxCommonNameLength];
115     X509_NAME* subject_name = X509_get_subject_name(cert.get());
116     if (X509_NAME_get_text_by_NID(subject_name, NID_commonName, buf,
117                                   sizeof(buf)) <= 0) {
118       return QUIC_FAILURE;
119     }
120 
121     common_name_ = buf;
122     cert_sct_ = cert_sct;
123     return QUIC_SUCCESS;
124   }
125 
126   std::unique_ptr<ProofVerifier> verifier_;
127   std::string common_name_;
128   std::string cert_sct_;
129 };
130 }  // namespace
131 
132 class MockableQuicClientEpollNetworkHelper
133     : public QuicClientEpollNetworkHelper {
134  public:
135   using QuicClientEpollNetworkHelper::QuicClientEpollNetworkHelper;
136   ~MockableQuicClientEpollNetworkHelper() override = default;
137 
ProcessPacket(const QuicSocketAddress & self_address,const QuicSocketAddress & peer_address,const QuicReceivedPacket & packet)138   void ProcessPacket(const QuicSocketAddress& self_address,
139                      const QuicSocketAddress& peer_address,
140                      const QuicReceivedPacket& packet) override {
141     QuicClientEpollNetworkHelper::ProcessPacket(self_address, peer_address,
142                                                 packet);
143     if (track_last_incoming_packet_) {
144       last_incoming_packet_ = packet.Clone();
145     }
146   }
147 
CreateQuicPacketWriter()148   QuicPacketWriter* CreateQuicPacketWriter() override {
149     QuicPacketWriter* writer =
150         QuicClientEpollNetworkHelper::CreateQuicPacketWriter();
151     if (!test_writer_) {
152       return writer;
153     }
154     test_writer_->set_writer(writer);
155     return test_writer_;
156   }
157 
last_incoming_packet()158   const QuicReceivedPacket* last_incoming_packet() {
159     return last_incoming_packet_.get();
160   }
161 
set_track_last_incoming_packet(bool track)162   void set_track_last_incoming_packet(bool track) {
163     track_last_incoming_packet_ = track;
164   }
165 
UseWriter(QuicPacketWriterWrapper * writer)166   void UseWriter(QuicPacketWriterWrapper* writer) {
167     CHECK(test_writer_ == nullptr);
168     test_writer_ = writer;
169   }
170 
set_peer_address(const QuicSocketAddress & address)171   void set_peer_address(const QuicSocketAddress& address) {
172     CHECK(test_writer_ != nullptr);
173     test_writer_->set_peer_address(address);
174   }
175 
176  private:
177   QuicPacketWriterWrapper* test_writer_ = nullptr;
178   // The last incoming packet, iff |track_last_incoming_packet_| is true.
179   std::unique_ptr<QuicReceivedPacket> last_incoming_packet_;
180   // If true, copy each packet from ProcessPacket into |last_incoming_packet_|
181   bool track_last_incoming_packet_ = false;
182 };
183 
MockableQuicClient(QuicSocketAddress server_address,const QuicServerId & server_id,const ParsedQuicVersionVector & supported_versions,QuicEpollServer * epoll_server)184 MockableQuicClient::MockableQuicClient(
185     QuicSocketAddress server_address,
186     const QuicServerId& server_id,
187     const ParsedQuicVersionVector& supported_versions,
188     QuicEpollServer* epoll_server)
189     : MockableQuicClient(server_address,
190                          server_id,
191                          QuicConfig(),
192                          supported_versions,
193                          epoll_server) {}
194 
MockableQuicClient(QuicSocketAddress server_address,const QuicServerId & server_id,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,QuicEpollServer * epoll_server)195 MockableQuicClient::MockableQuicClient(
196     QuicSocketAddress server_address,
197     const QuicServerId& server_id,
198     const QuicConfig& config,
199     const ParsedQuicVersionVector& supported_versions,
200     QuicEpollServer* epoll_server)
201     : MockableQuicClient(server_address,
202                          server_id,
203                          config,
204                          supported_versions,
205                          epoll_server,
206                          nullptr) {}
207 
MockableQuicClient(QuicSocketAddress server_address,const QuicServerId & server_id,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,QuicEpollServer * epoll_server,std::unique_ptr<ProofVerifier> proof_verifier)208 MockableQuicClient::MockableQuicClient(
209     QuicSocketAddress server_address,
210     const QuicServerId& server_id,
211     const QuicConfig& config,
212     const ParsedQuicVersionVector& supported_versions,
213     QuicEpollServer* epoll_server,
214     std::unique_ptr<ProofVerifier> proof_verifier)
215     : MockableQuicClient(server_address,
216                          server_id,
217                          config,
218                          supported_versions,
219                          epoll_server,
220                          std::move(proof_verifier),
221                          nullptr) {}
222 
MockableQuicClient(QuicSocketAddress server_address,const QuicServerId & server_id,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,QuicEpollServer * epoll_server,std::unique_ptr<ProofVerifier> proof_verifier,std::unique_ptr<SessionCache> session_cache)223 MockableQuicClient::MockableQuicClient(
224     QuicSocketAddress server_address,
225     const QuicServerId& server_id,
226     const QuicConfig& config,
227     const ParsedQuicVersionVector& supported_versions,
228     QuicEpollServer* epoll_server,
229     std::unique_ptr<ProofVerifier> proof_verifier,
230     std::unique_ptr<SessionCache> session_cache)
231     : QuicClient(
232           server_address,
233           server_id,
234           supported_versions,
235           config,
236           epoll_server,
237           std::make_unique<MockableQuicClientEpollNetworkHelper>(epoll_server,
238                                                                  this),
239           QuicWrapUnique(new RecordingProofVerifier(std::move(proof_verifier))),
240           std::move(session_cache)),
241       override_server_connection_id_(EmptyQuicConnectionId()),
242       server_connection_id_overridden_(false),
243       override_client_connection_id_(EmptyQuicConnectionId()),
244       client_connection_id_overridden_(false) {}
245 
~MockableQuicClient()246 MockableQuicClient::~MockableQuicClient() {
247   if (connected()) {
248     Disconnect();
249   }
250 }
251 
252 MockableQuicClientEpollNetworkHelper*
mockable_network_helper()253 MockableQuicClient::mockable_network_helper() {
254   return static_cast<MockableQuicClientEpollNetworkHelper*>(
255       epoll_network_helper());
256 }
257 
258 const MockableQuicClientEpollNetworkHelper*
mockable_network_helper() const259 MockableQuicClient::mockable_network_helper() const {
260   return static_cast<const MockableQuicClientEpollNetworkHelper*>(
261       epoll_network_helper());
262 }
263 
GenerateNewConnectionId()264 QuicConnectionId MockableQuicClient::GenerateNewConnectionId() {
265   if (server_connection_id_overridden_) {
266     return override_server_connection_id_;
267   }
268   if (override_server_connection_id_length_ >= 0) {
269     return QuicUtils::CreateRandomConnectionId(
270         override_server_connection_id_length_);
271   }
272   return QuicClient::GenerateNewConnectionId();
273 }
274 
UseConnectionId(QuicConnectionId server_connection_id)275 void MockableQuicClient::UseConnectionId(
276     QuicConnectionId server_connection_id) {
277   server_connection_id_overridden_ = true;
278   override_server_connection_id_ = server_connection_id;
279 }
280 
UseConnectionIdLength(int server_connection_id_length)281 void MockableQuicClient::UseConnectionIdLength(
282     int server_connection_id_length) {
283   override_server_connection_id_length_ = server_connection_id_length;
284 }
285 
GetClientConnectionId()286 QuicConnectionId MockableQuicClient::GetClientConnectionId() {
287   if (client_connection_id_overridden_) {
288     return override_client_connection_id_;
289   }
290   if (override_client_connection_id_length_ >= 0) {
291     return QuicUtils::CreateRandomConnectionId(
292         override_client_connection_id_length_);
293   }
294   return QuicClient::GetClientConnectionId();
295 }
296 
UseClientConnectionId(QuicConnectionId client_connection_id)297 void MockableQuicClient::UseClientConnectionId(
298     QuicConnectionId client_connection_id) {
299   client_connection_id_overridden_ = true;
300   override_client_connection_id_ = client_connection_id;
301 }
302 
UseClientConnectionIdLength(int client_connection_id_length)303 void MockableQuicClient::UseClientConnectionIdLength(
304     int client_connection_id_length) {
305   override_client_connection_id_length_ = client_connection_id_length;
306 }
307 
UseWriter(QuicPacketWriterWrapper * writer)308 void MockableQuicClient::UseWriter(QuicPacketWriterWrapper* writer) {
309   mockable_network_helper()->UseWriter(writer);
310 }
311 
set_peer_address(const QuicSocketAddress & address)312 void MockableQuicClient::set_peer_address(const QuicSocketAddress& address) {
313   mockable_network_helper()->set_peer_address(address);
314 }
315 
last_incoming_packet()316 const QuicReceivedPacket* MockableQuicClient::last_incoming_packet() {
317   return mockable_network_helper()->last_incoming_packet();
318 }
319 
set_track_last_incoming_packet(bool track)320 void MockableQuicClient::set_track_last_incoming_packet(bool track) {
321   mockable_network_helper()->set_track_last_incoming_packet(track);
322 }
323 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const ParsedQuicVersionVector & supported_versions)324 QuicTestClient::QuicTestClient(
325     QuicSocketAddress server_address,
326     const std::string& server_hostname,
327     const ParsedQuicVersionVector& supported_versions)
328     : QuicTestClient(server_address,
329                      server_hostname,
330                      QuicConfig(),
331                      supported_versions) {}
332 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions)333 QuicTestClient::QuicTestClient(
334     QuicSocketAddress server_address,
335     const std::string& server_hostname,
336     const QuicConfig& config,
337     const ParsedQuicVersionVector& supported_versions)
338     : client_(new MockableQuicClient(
339           server_address,
340           QuicServerId(server_hostname, server_address.port(), false),
341           config,
342           supported_versions,
343           &epoll_server_)) {
344   Initialize();
345 }
346 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,std::unique_ptr<ProofVerifier> proof_verifier)347 QuicTestClient::QuicTestClient(
348     QuicSocketAddress server_address,
349     const std::string& server_hostname,
350     const QuicConfig& config,
351     const ParsedQuicVersionVector& supported_versions,
352     std::unique_ptr<ProofVerifier> proof_verifier)
353     : client_(new MockableQuicClient(
354           server_address,
355           QuicServerId(server_hostname, server_address.port(), false),
356           config,
357           supported_versions,
358           &epoll_server_,
359           std::move(proof_verifier))) {
360   Initialize();
361 }
362 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,std::unique_ptr<ProofVerifier> proof_verifier,std::unique_ptr<SessionCache> session_cache)363 QuicTestClient::QuicTestClient(
364     QuicSocketAddress server_address,
365     const std::string& server_hostname,
366     const QuicConfig& config,
367     const ParsedQuicVersionVector& supported_versions,
368     std::unique_ptr<ProofVerifier> proof_verifier,
369     std::unique_ptr<SessionCache> session_cache)
370     : client_(new MockableQuicClient(
371           server_address,
372           QuicServerId(server_hostname, server_address.port(), false),
373           config,
374           supported_versions,
375           &epoll_server_,
376           std::move(proof_verifier),
377           std::move(session_cache))) {
378   Initialize();
379 }
380 
// Default constructor. Members keep their default/in-class initial values;
// unlike the other constructors it does not create a client or call
// Initialize(). NOTE(review): presumably meant for subclasses that set up the
// client themselves — confirm against callers.
QuicTestClient::QuicTestClient() = default;
382 
~QuicTestClient()383 QuicTestClient::~QuicTestClient() {
384   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
385     stream.second->set_visitor(nullptr);
386   }
387 }
388 
Initialize()389 void QuicTestClient::Initialize() {
390   priority_ = 3;
391   connect_attempted_ = false;
392   auto_reconnect_ = false;
393   buffer_body_ = true;
394   num_requests_ = 0;
395   num_responses_ = 0;
396   ClearPerConnectionState();
397   // As chrome will generally do this, we want it to be the default when it's
398   // not overridden.
399   if (!client_->config()->HasSetBytesForConnectionIdToSend()) {
400     client_->config()->SetBytesForConnectionIdToSend(0);
401   }
402 }
403 
SetUserAgentID(const std::string & user_agent_id)404 void QuicTestClient::SetUserAgentID(const std::string& user_agent_id) {
405   client_->SetUserAgentID(user_agent_id);
406 }
407 
SendRequest(const std::string & uri)408 ssize_t QuicTestClient::SendRequest(const std::string& uri) {
409   spdy::SpdyHeaderBlock headers;
410   if (!PopulateHeaderBlockFromUrl(uri, &headers)) {
411     return 0;
412   }
413   return SendMessage(headers, "");
414 }
415 
SendRequestAndRstTogether(const std::string & uri)416 ssize_t QuicTestClient::SendRequestAndRstTogether(const std::string& uri) {
417   spdy::SpdyHeaderBlock headers;
418   if (!PopulateHeaderBlockFromUrl(uri, &headers)) {
419     return 0;
420   }
421 
422   QuicSpdyClientSession* session = client()->client_session();
423   QuicConnection::ScopedPacketFlusher flusher(session->connection());
424   ssize_t ret = SendMessage(headers, "", /*fin=*/true, /*flush=*/false);
425 
426   QuicStreamId stream_id = GetNthClientInitiatedBidirectionalStreamId(
427       session->transport_version(), 0);
428   session->ResetStream(stream_id, QUIC_STREAM_CANCELLED);
429   return ret;
430 }
431 
SendRequestsAndWaitForResponses(const std::vector<std::string> & url_list)432 void QuicTestClient::SendRequestsAndWaitForResponses(
433     const std::vector<std::string>& url_list) {
434   for (const std::string& url : url_list) {
435     SendRequest(url);
436   }
437   while (client()->WaitForEvents()) {
438   }
439 }
440 
GetOrCreateStreamAndSendRequest(const spdy::SpdyHeaderBlock * headers,absl::string_view body,bool fin,QuicReferenceCountedPointer<QuicAckListenerInterface> ack_listener)441 ssize_t QuicTestClient::GetOrCreateStreamAndSendRequest(
442     const spdy::SpdyHeaderBlock* headers,
443     absl::string_view body,
444     bool fin,
445     QuicReferenceCountedPointer<QuicAckListenerInterface> ack_listener) {
446   if (headers) {
447     QuicClientPushPromiseIndex::TryHandle* handle;
448     QuicAsyncStatus rv =
449         client()->push_promise_index()->Try(*headers, this, &handle);
450     if (rv == QUIC_SUCCESS)
451       return 1;
452     if (rv == QUIC_PENDING) {
453       // May need to retry request if asynchronous rendezvous fails.
454       std::unique_ptr<spdy::SpdyHeaderBlock> new_headers(
455           new spdy::SpdyHeaderBlock(headers->Clone()));
456       push_promise_data_to_resend_ = std::make_unique<TestClientDataToResend>(
457           std::move(new_headers), body, fin, this, std::move(ack_listener));
458       return 1;
459     }
460   }
461 
462   // Maybe it's better just to overload this.  it's just that we need
463   // for the GetOrCreateStream function to call something else...which
464   // is icky and complicated, but maybe not worse than this.
465   QuicSpdyClientStream* stream = GetOrCreateStream();
466   if (stream == nullptr) {
467     return 0;
468   }
469   QuicSpdyStreamPeer::set_ack_listener(stream, ack_listener);
470 
471   ssize_t ret = 0;
472   if (headers != nullptr) {
473     spdy::SpdyHeaderBlock spdy_headers(headers->Clone());
474     if (spdy_headers[":authority"].as_string().empty()) {
475       spdy_headers[":authority"] = client_->server_id().host();
476     }
477     ret = stream->SendRequest(std::move(spdy_headers), body, fin);
478     ++num_requests_;
479   } else {
480     stream->WriteOrBufferBody(std::string(body), fin);
481     ret = body.length();
482   }
483   return ret;
484 }
485 
SendMessage(const spdy::SpdyHeaderBlock & headers,absl::string_view body)486 ssize_t QuicTestClient::SendMessage(const spdy::SpdyHeaderBlock& headers,
487                                     absl::string_view body) {
488   return SendMessage(headers, body, /*fin=*/true);
489 }
490 
SendMessage(const spdy::SpdyHeaderBlock & headers,absl::string_view body,bool fin)491 ssize_t QuicTestClient::SendMessage(const spdy::SpdyHeaderBlock& headers,
492                                     absl::string_view body,
493                                     bool fin) {
494   return SendMessage(headers, body, fin, /*flush=*/true);
495 }
496 
SendMessage(const spdy::SpdyHeaderBlock & headers,absl::string_view body,bool fin,bool flush)497 ssize_t QuicTestClient::SendMessage(const spdy::SpdyHeaderBlock& headers,
498                                     absl::string_view body,
499                                     bool fin,
500                                     bool flush) {
501   // Always force creation of a stream for SendMessage.
502   latest_created_stream_ = nullptr;
503 
504   ssize_t ret = GetOrCreateStreamAndSendRequest(&headers, body, fin, nullptr);
505 
506   if (flush) {
507     WaitForWriteToFlush();
508   }
509   return ret;
510 }
511 
SendData(const std::string & data,bool last_data)512 ssize_t QuicTestClient::SendData(const std::string& data, bool last_data) {
513   return SendData(data, last_data, nullptr);
514 }
515 
SendData(const std::string & data,bool last_data,QuicReferenceCountedPointer<QuicAckListenerInterface> ack_listener)516 ssize_t QuicTestClient::SendData(
517     const std::string& data,
518     bool last_data,
519     QuicReferenceCountedPointer<QuicAckListenerInterface> ack_listener) {
520   return GetOrCreateStreamAndSendRequest(nullptr, absl::string_view(data),
521                                          last_data, std::move(ack_listener));
522 }
523 
response_complete() const524 bool QuicTestClient::response_complete() const {
525   return response_complete_;
526 }
527 
response_body_size() const528 int64_t QuicTestClient::response_body_size() const {
529   return response_body_size_;
530 }
531 
buffer_body() const532 bool QuicTestClient::buffer_body() const {
533   return buffer_body_;
534 }
535 
set_buffer_body(bool buffer_body)536 void QuicTestClient::set_buffer_body(bool buffer_body) {
537   buffer_body_ = buffer_body;
538 }
539 
response_body() const540 const std::string& QuicTestClient::response_body() const {
541   return response_;
542 }
543 
SendCustomSynchronousRequest(const spdy::SpdyHeaderBlock & headers,const std::string & body)544 std::string QuicTestClient::SendCustomSynchronousRequest(
545     const spdy::SpdyHeaderBlock& headers,
546     const std::string& body) {
547   // Clear connection state here and only track this synchronous request.
548   ClearPerConnectionState();
549   if (SendMessage(headers, body) == 0) {
550     QUIC_DLOG(ERROR) << "Failed the request for: " << headers.DebugString();
551     // Set the response_ explicitly.  Otherwise response_ will contain the
552     // response from the previously successful request.
553     response_ = "";
554   } else {
555     WaitForResponse();
556   }
557   return response_;
558 }
559 
SendSynchronousRequest(const std::string & uri)560 std::string QuicTestClient::SendSynchronousRequest(const std::string& uri) {
561   spdy::SpdyHeaderBlock headers;
562   if (!PopulateHeaderBlockFromUrl(uri, &headers)) {
563     return "";
564   }
565   return SendCustomSynchronousRequest(headers, "");
566 }
567 
SendConnectivityProbing()568 void QuicTestClient::SendConnectivityProbing() {
569   QuicConnection* connection = client()->client_session()->connection();
570   connection->SendConnectivityProbingPacket(connection->writer(),
571                                             connection->peer_address());
572 }
573 
SetLatestCreatedStream(QuicSpdyClientStream * stream)574 void QuicTestClient::SetLatestCreatedStream(QuicSpdyClientStream* stream) {
575   latest_created_stream_ = stream;
576   if (latest_created_stream_ != nullptr) {
577     open_streams_[stream->id()] = stream;
578     stream->set_visitor(this);
579   }
580 }
581 
GetOrCreateStream()582 QuicSpdyClientStream* QuicTestClient::GetOrCreateStream() {
583   if (!connect_attempted_ || auto_reconnect_) {
584     if (!connected()) {
585       Connect();
586     }
587     if (!connected()) {
588       return nullptr;
589     }
590   }
591   if (open_streams_.empty()) {
592     ClearPerConnectionState();
593   }
594   if (!latest_created_stream_) {
595     SetLatestCreatedStream(client_->CreateClientStream());
596     if (latest_created_stream_) {
597       latest_created_stream_->SetPriority(
598           spdy::SpdyStreamPrecedence(priority_));
599     }
600   }
601 
602   return latest_created_stream_;
603 }
604 
connection_error()605 QuicErrorCode QuicTestClient::connection_error() {
606   return client()->connection_error();
607 }
608 
cert_common_name() const609 const std::string& QuicTestClient::cert_common_name() const {
610   return reinterpret_cast<RecordingProofVerifier*>(client_->proof_verifier())
611       ->common_name();
612 }
613 
cert_sct() const614 const std::string& QuicTestClient::cert_sct() const {
615   return reinterpret_cast<RecordingProofVerifier*>(client_->proof_verifier())
616       ->cert_sct();
617 }
618 
GetServerConfig() const619 QuicTagValueMap QuicTestClient::GetServerConfig() const {
620   QuicCryptoClientConfig* config = client_->crypto_config();
621   QuicCryptoClientConfig::CachedState* state =
622       config->LookupOrCreate(client_->server_id());
623   const CryptoHandshakeMessage* handshake_msg = state->GetServerConfig();
624   if (handshake_msg != nullptr) {
625     return handshake_msg->tag_value_map();
626   } else {
627     return QuicTagValueMap();
628   }
629 }
630 
connected() const631 bool QuicTestClient::connected() const {
632   return client_->connected();
633 }
634 
Connect()635 void QuicTestClient::Connect() {
636   if (connected()) {
637     QUIC_BUG << "Cannot connect already-connected client";
638     return;
639   }
640   if (!connect_attempted_) {
641     client_->Initialize();
642   }
643 
644   // If we've been asked to override SNI, set it now
645   if (override_sni_set_) {
646     client_->set_server_id(
647         QuicServerId(override_sni_, address().port(), false));
648   }
649 
650   client_->Connect();
651   connect_attempted_ = true;
652 }
653 
ResetConnection()654 void QuicTestClient::ResetConnection() {
655   Disconnect();
656   Connect();
657 }
658 
Disconnect()659 void QuicTestClient::Disconnect() {
660   ClearPerConnectionState();
661   client_->Disconnect();
662   connect_attempted_ = false;
663 }
664 
local_address() const665 QuicSocketAddress QuicTestClient::local_address() const {
666   return client_->network_helper()->GetLatestClientAddress();
667 }
668 
ClearPerRequestState()669 void QuicTestClient::ClearPerRequestState() {
670   stream_error_ = QUIC_STREAM_NO_ERROR;
671   response_ = "";
672   response_complete_ = false;
673   response_headers_complete_ = false;
674   preliminary_headers_.clear();
675   response_headers_.clear();
676   response_trailers_.clear();
677   bytes_read_ = 0;
678   bytes_written_ = 0;
679   response_body_size_ = 0;
680 }
681 
HaveActiveStream()682 bool QuicTestClient::HaveActiveStream() {
683   return push_promise_data_to_resend_.get() || !open_streams_.empty();
684 }
685 
WaitUntil(int timeout_ms,std::function<bool ()> trigger)686 bool QuicTestClient::WaitUntil(int timeout_ms, std::function<bool()> trigger) {
687   int64_t timeout_us = timeout_ms * kNumMicrosPerMilli;
688   int64_t old_timeout_us = epoll_server()->timeout_in_us_for_test();
689   if (timeout_us > 0) {
690     epoll_server()->set_timeout_in_us(timeout_us);
691   }
692   const QuicClock* clock =
693       QuicConnectionPeer::GetHelper(client()->session()->connection())
694           ->GetClock();
695   QuicTime end_waiting_time =
696       clock->Now() + QuicTime::Delta::FromMicroseconds(timeout_us);
697   while (HaveActiveStream() && !(trigger && trigger()) &&
698          (timeout_us < 0 || clock->Now() < end_waiting_time)) {
699     client_->WaitForEvents();
700   }
701   ReadNextResponse();
702   if (timeout_us > 0) {
703     epoll_server()->set_timeout_in_us(old_timeout_us);
704   }
705   if (trigger && !trigger()) {
706     QUIC_VLOG(1) << "Client WaitUntil returning with trigger returning false.";
707     return false;
708   }
709   return true;
710 }
711 
Send(absl::string_view data)712 ssize_t QuicTestClient::Send(absl::string_view data) {
713   return SendData(std::string(data), false);
714 }
715 
response_headers_complete() const716 bool QuicTestClient::response_headers_complete() const {
717   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
718     if (stream.second->headers_decompressed()) {
719       return true;
720     }
721   }
722   return response_headers_complete_;
723 }
724 
response_headers() const725 const spdy::SpdyHeaderBlock* QuicTestClient::response_headers() const {
726   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
727     if (stream.second->headers_decompressed()) {
728       response_headers_ = stream.second->response_headers().Clone();
729       break;
730     }
731   }
732   return &response_headers_;
733 }
734 
preliminary_headers() const735 const spdy::SpdyHeaderBlock* QuicTestClient::preliminary_headers() const {
736   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
737     size_t bytes_read =
738         stream.second->stream_bytes_read() + stream.second->header_bytes_read();
739     if (bytes_read > 0) {
740       preliminary_headers_ = stream.second->preliminary_headers().Clone();
741       break;
742     }
743   }
744   return &preliminary_headers_;
745 }
746 
response_trailers() const747 const spdy::SpdyHeaderBlock& QuicTestClient::response_trailers() const {
748   return response_trailers_;
749 }
750 
response_size() const751 int64_t QuicTestClient::response_size() const {
752   return bytes_read();
753 }
754 
bytes_read() const755 size_t QuicTestClient::bytes_read() const {
756   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
757     size_t bytes_read = stream.second->total_body_bytes_read() +
758                         stream.second->header_bytes_read();
759     if (bytes_read > 0) {
760       return bytes_read;
761     }
762   }
763   return bytes_read_;
764 }
765 
bytes_written() const766 size_t QuicTestClient::bytes_written() const {
767   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
768     size_t bytes_written = stream.second->stream_bytes_written() +
769                            stream.second->header_bytes_written();
770     if (bytes_written > 0) {
771       return bytes_written;
772     }
773   }
774   return bytes_written_;
775 }
776 
// Stream-close notification. Ignores null streams. Otherwise it always
// forwards the close to the underlying client and increments
// |num_responses_|. If |stream| is tracked in |open_streams_|, a snapshot of
// its final state is inserted into |closed_stream_states_| (later consumed by
// ReadNextResponse()) and the stream is removed from |open_streams_|.
void QuicTestClient::OnClose(QuicSpdyStream* stream) {
  if (stream == nullptr) {
    return;
  }
  // Always close the stream, regardless of whether it was the last stream
  // written.
  client()->OnClose(stream);
  ++num_responses_;
  // Streams absent from |open_streams_| have no per-stream state to record.
  if (!QuicContainsKey(open_streams_, stream->id())) {
    return;
  }
  // Clear the cached "latest" pointer if it refers to the closing stream.
  if (latest_created_stream_ == stream) {
    latest_created_stream_ = nullptr;
  }
  // NOTE(review): the downcast assumes every stream tracked in
  // |open_streams_| is a QuicSpdyClientStream (the map's values are used as
  // QuicSpdyClientStream* elsewhere in this file).
  QuicSpdyClientStream* client_stream =
      static_cast<QuicSpdyClientStream*>(stream);
  QuicStreamId id = client_stream->id();
  // Record the stream's results. Arguments, in order: stream error,
  // response_complete (always true here), headers-decompressed flag,
  // response/preliminary headers (cloned by PerStreamState), body text (only
  // kept when buffer_body() is set), trailers, bytes read, bytes written, and
  // the body size (always data().size(), even when the body is not kept).
  closed_stream_states_.insert(std::make_pair(
      id,
      PerStreamState(
          client_stream->stream_error(), true,
          client_stream->headers_decompressed(),
          client_stream->response_headers(),
          client_stream->preliminary_headers(),
          (buffer_body() ? client_stream->data() : ""),
          client_stream->received_trailers(),
          // Bytes read: total body bytes read plus header bytes read.
          client_stream->total_body_bytes_read() +
              client_stream->header_bytes_read(),
          // Bytes written: stream bytes written plus header bytes written.
          client_stream->stream_bytes_written() +
              client_stream->header_bytes_written(),
          client_stream->data().size())));
  open_streams_.erase(id);
}
811 
CheckVary(const spdy::SpdyHeaderBlock &,const spdy::SpdyHeaderBlock &,const spdy::SpdyHeaderBlock &)812 bool QuicTestClient::CheckVary(
813     const spdy::SpdyHeaderBlock& /*client_request*/,
814     const spdy::SpdyHeaderBlock& /*promise_request*/,
815     const spdy::SpdyHeaderBlock& /*promise_response*/) {
816   return true;
817 }
818 
OnRendezvousResult(QuicSpdyStream * stream)819 void QuicTestClient::OnRendezvousResult(QuicSpdyStream* stream) {
820   std::unique_ptr<TestClientDataToResend> data_to_resend =
821       std::move(push_promise_data_to_resend_);
822   SetLatestCreatedStream(static_cast<QuicSpdyClientStream*>(stream));
823   if (stream) {
824     stream->OnBodyAvailable();
825   } else if (data_to_resend) {
826     data_to_resend->Resend();
827   }
828 }
829 
UseWriter(QuicPacketWriterWrapper * writer)830 void QuicTestClient::UseWriter(QuicPacketWriterWrapper* writer) {
831   client_->UseWriter(writer);
832 }
833 
// Forwards |server_connection_id| to the underlying client's
// UseConnectionId(). Must only be called while not connected (DCHECKed).
void QuicTestClient::UseConnectionId(QuicConnectionId server_connection_id) {
  DCHECK(!connected());
  client_->UseConnectionId(server_connection_id);
}
838 
// Forwards |server_connection_id_length| to the underlying client's
// UseConnectionIdLength(). Must only be called while not connected
// (DCHECKed).
void QuicTestClient::UseConnectionIdLength(int server_connection_id_length) {
  DCHECK(!connected());
  client_->UseConnectionIdLength(server_connection_id_length);
}
843 
// Forwards |client_connection_id| to the underlying client's
// UseClientConnectionId(). Must only be called while not connected
// (DCHECKed).
void QuicTestClient::UseClientConnectionId(
    QuicConnectionId client_connection_id) {
  DCHECK(!connected());
  client_->UseClientConnectionId(client_connection_id);
}
849 
// Forwards |client_connection_id_length| to the underlying client's
// UseClientConnectionIdLength(). Must only be called while not connected
// (DCHECKed).
void QuicTestClient::UseClientConnectionIdLength(
    int client_connection_id_length) {
  DCHECK(!connected());
  client_->UseClientConnectionIdLength(client_connection_id_length);
}
855 
MigrateSocket(const QuicIpAddress & new_host)856 bool QuicTestClient::MigrateSocket(const QuicIpAddress& new_host) {
857   return client_->MigrateSocket(new_host);
858 }
859 
MigrateSocketWithSpecifiedPort(const QuicIpAddress & new_host,int port)860 bool QuicTestClient::MigrateSocketWithSpecifiedPort(
861     const QuicIpAddress& new_host,
862     int port) {
863   client_->set_local_port(port);
864   return client_->MigrateSocket(new_host);
865 }
866 
bind_to_address() const867 QuicIpAddress QuicTestClient::bind_to_address() const {
868   return client_->bind_to_address();
869 }
870 
set_bind_to_address(QuicIpAddress address)871 void QuicTestClient::set_bind_to_address(QuicIpAddress address) {
872   client_->set_bind_to_address(address);
873 }
874 
address() const875 const QuicSocketAddress& QuicTestClient::address() const {
876   return client_->server_address();
877 }
878 
WaitForWriteToFlush()879 void QuicTestClient::WaitForWriteToFlush() {
880   while (connected() && client()->session()->HasDataToWrite()) {
881     client_->WaitForEvents();
882   }
883 }
884 
// Captures a request (|headers|, |body|, |fin|) in the QuicDataToResend base,
// plus the test client it should be resent through and the ack listener to
// attach when Resend() is called.
QuicTestClient::TestClientDataToResend::TestClientDataToResend(
    std::unique_ptr<spdy::SpdyHeaderBlock> headers,
    absl::string_view body,
    bool fin,
    QuicTestClient* test_client,
    QuicReferenceCountedPointer<QuicAckListenerInterface> ack_listener)
    : QuicClient::QuicDataToResend(std::move(headers), body, fin),
      test_client_(test_client),
      ack_listener_(std::move(ack_listener)) {}
894 
// Defaulted destructor, defined out of line in this file.
QuicTestClient::TestClientDataToResend::~TestClientDataToResend() = default;
896 
Resend()897 void QuicTestClient::TestClientDataToResend::Resend() {
898   test_client_->GetOrCreateStreamAndSendRequest(headers_.get(), body_, fin_,
899                                                 ack_listener_);
900   headers_.reset();
901 }
902 
PerStreamState(const PerStreamState & other)903 QuicTestClient::PerStreamState::PerStreamState(const PerStreamState& other)
904     : stream_error(other.stream_error),
905       response_complete(other.response_complete),
906       response_headers_complete(other.response_headers_complete),
907       response_headers(other.response_headers.Clone()),
908       preliminary_headers(other.preliminary_headers.Clone()),
909       response(other.response),
910       response_trailers(other.response_trailers.Clone()),
911       bytes_read(other.bytes_read),
912       bytes_written(other.bytes_written),
913       response_body_size(other.response_body_size) {}
914 
// Builds a snapshot of one stream's results. The three header blocks are
// deep-copied with Clone() and |response| is copied, so the snapshot holds no
// references to the caller's objects. Parameter names intentionally match the
// members they initialize.
QuicTestClient::PerStreamState::PerStreamState(
    QuicRstStreamErrorCode stream_error,
    bool response_complete,
    bool response_headers_complete,
    const spdy::SpdyHeaderBlock& response_headers,
    const spdy::SpdyHeaderBlock& preliminary_headers,
    const std::string& response,
    const spdy::SpdyHeaderBlock& response_trailers,
    uint64_t bytes_read,
    uint64_t bytes_written,
    int64_t response_body_size)
    : stream_error(stream_error),
      response_complete(response_complete),
      response_headers_complete(response_headers_complete),
      response_headers(response_headers.Clone()),
      preliminary_headers(preliminary_headers.Clone()),
      response(response),
      response_trailers(response_trailers.Clone()),
      bytes_read(bytes_read),
      bytes_written(bytes_written),
      response_body_size(response_body_size) {}
936 
// Defaulted destructor, defined out of line in this file.
QuicTestClient::PerStreamState::~PerStreamState() = default;
938 
PopulateHeaderBlockFromUrl(const std::string & uri,spdy::SpdyHeaderBlock * headers)939 bool QuicTestClient::PopulateHeaderBlockFromUrl(
940     const std::string& uri,
941     spdy::SpdyHeaderBlock* headers) {
942   std::string url;
943   if (absl::StartsWith(uri, "https://") || absl::StartsWith(uri, "http://")) {
944     url = uri;
945   } else if (uri[0] == '/') {
946     url = "https://" + client_->server_id().host() + uri;
947   } else {
948     url = "https://" + uri;
949   }
950   return SpdyUtils::PopulateHeaderBlockFromUrl(url, headers);
951 }
952 
ReadNextResponse()953 void QuicTestClient::ReadNextResponse() {
954   if (closed_stream_states_.empty()) {
955     return;
956   }
957 
958   PerStreamState state(closed_stream_states_.front().second);
959 
960   stream_error_ = state.stream_error;
961   response_ = state.response;
962   response_complete_ = state.response_complete;
963   response_headers_complete_ = state.response_headers_complete;
964   preliminary_headers_ = state.preliminary_headers.Clone();
965   response_headers_ = state.response_headers.Clone();
966   response_trailers_ = state.response_trailers.Clone();
967   bytes_read_ = state.bytes_read;
968   bytes_written_ = state.bytes_written;
969   response_body_size_ = state.response_body_size;
970 
971   closed_stream_states_.pop_front();
972 }
973 
// Resets per-request state via ClearPerRequestState(), then drops all stream
// bookkeeping: the open-stream map, the queued closed-stream snapshots, and
// the latest-created-stream pointer.
// NOTE(review): this only forgets the stream pointers; it does not close or
// reset the streams themselves.
void QuicTestClient::ClearPerConnectionState() {
  ClearPerRequestState();
  open_streams_.clear();
  closed_stream_states_.clear();
  latest_created_stream_ = nullptr;
}
980 
WaitForDelayedAcks()981 void QuicTestClient::WaitForDelayedAcks() {
982   // kWaitDuration is a period of time that is long enough for all delayed
983   // acks to be sent and received on the other end.
984   const QuicTime::Delta kWaitDuration =
985       4 * QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
986 
987   const QuicClock* clock = client()->client_session()->connection()->clock();
988 
989   QuicTime wait_until = clock->ApproximateNow() + kWaitDuration;
990   while (connected() && clock->ApproximateNow() < wait_until) {
991     // This waits for up to 50 ms.
992     client()->WaitForEvents();
993   }
994 }
995 
996 }  // namespace test
997 }  // namespace quic
998