1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_stream_create_test_base.h"
6 
7 #include <utility>
8 
9 #include "base/callback.h"
10 #include "base/macros.h"
11 #include "net/base/ip_endpoint.h"
12 #include "net/http/http_request_headers.h"
13 #include "net/http/http_response_headers.h"
14 #include "net/log/net_log_with_source.h"
15 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
16 #include "net/websockets/websocket_basic_handshake_stream.h"
17 #include "net/websockets/websocket_handshake_request_info.h"
18 #include "net/websockets/websocket_handshake_response_info.h"
19 #include "net/websockets/websocket_stream.h"
20 #include "url/gurl.h"
21 #include "url/origin.h"
22 
23 namespace net {
24 
25 using HeaderKeyValuePair = WebSocketStreamCreateTestBase::HeaderKeyValuePair;
26 
27 class WebSocketStreamCreateTestBase::TestConnectDelegate
28     : public WebSocketStream::ConnectDelegate {
29  public:
TestConnectDelegate(WebSocketStreamCreateTestBase * owner,base::OnceClosure done_callback)30   TestConnectDelegate(WebSocketStreamCreateTestBase* owner,
31                       base::OnceClosure done_callback)
32       : owner_(owner), done_callback_(std::move(done_callback)) {}
33 
OnCreateRequest(URLRequest * request)34   void OnCreateRequest(URLRequest* request) override {
35     owner_->url_request_ = request;
36   }
37 
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)38   void OnSuccess(
39       std::unique_ptr<WebSocketStream> stream,
40       std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
41     if (owner_->response_info_)
42       ADD_FAILURE();
43     owner_->response_info_ = std::move(response);
44     stream.swap(owner_->stream_);
45     std::move(done_callback_).Run();
46   }
47 
OnFailure(const std::string & message,int net_error,base::Optional<int> response_code)48   void OnFailure(const std::string& message,
49                  int net_error,
50                  base::Optional<int> response_code) override {
51     owner_->has_failed_ = true;
52     owner_->failure_message_ = message;
53     owner_->failure_response_code_ = response_code.value_or(-1);
54     std::move(done_callback_).Run();
55   }
56 
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)57   void OnStartOpeningHandshake(
58       std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
59     // Can be called multiple times (in the case of HTTP auth). Last call
60     // wins.
61     owner_->request_info_ = std::move(request);
62   }
63 
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)64   void OnSSLCertificateError(
65       std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
66           ssl_error_callbacks,
67       int net_error,
68       const SSLInfo& ssl_info,
69       bool fatal) override {
70     owner_->ssl_error_callbacks_ = std::move(ssl_error_callbacks);
71     owner_->ssl_info_ = ssl_info;
72     owner_->ssl_fatal_ = fatal;
73   }
74 
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,base::Optional<AuthCredentials> * credentials)75   int OnAuthRequired(const AuthChallengeInfo& auth_info,
76                      scoped_refptr<HttpResponseHeaders> response_headers,
77                      const IPEndPoint& remote_endpoint,
78                      base::OnceCallback<void(const AuthCredentials*)> callback,
79                      base::Optional<AuthCredentials>* credentials) override {
80     owner_->run_loop_waiting_for_on_auth_required_.Quit();
81     owner_->auth_challenge_info_ = auth_info;
82     *credentials = owner_->auth_credentials_;
83     owner_->on_auth_required_callback_ = std::move(callback);
84     return owner_->on_auth_required_rv_;
85   }
86 
87  private:
88   WebSocketStreamCreateTestBase* owner_;
89   base::OnceClosure done_callback_;
90   DISALLOW_COPY_AND_ASSIGN(TestConnectDelegate);
91 };
92 
WebSocketStreamCreateTestBase()93 WebSocketStreamCreateTestBase::WebSocketStreamCreateTestBase()
94     : has_failed_(false), ssl_fatal_(false), url_request_(nullptr) {}
95 
96 WebSocketStreamCreateTestBase::~WebSocketStreamCreateTestBase() = default;
97 
CreateAndConnectStream(const GURL & socket_url,const std::vector<std::string> & sub_protocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,std::unique_ptr<base::OneShotTimer> timer)98 void WebSocketStreamCreateTestBase::CreateAndConnectStream(
99     const GURL& socket_url,
100     const std::vector<std::string>& sub_protocols,
101     const url::Origin& origin,
102     const SiteForCookies& site_for_cookies,
103     const IsolationInfo& isolation_info,
104     const HttpRequestHeaders& additional_headers,
105     std::unique_ptr<base::OneShotTimer> timer) {
106   auto connect_delegate = std::make_unique<TestConnectDelegate>(
107       this, connect_run_loop_.QuitClosure());
108   auto api_delegate = std::make_unique<TestWebSocketStreamRequestAPI>();
109   stream_request_ = WebSocketStream::CreateAndConnectStreamForTesting(
110       socket_url, sub_protocols, origin, site_for_cookies, isolation_info,
111       additional_headers, url_request_context_host_.GetURLRequestContext(),
112       NetLogWithSource(), TRAFFIC_ANNOTATION_FOR_TESTS,
113       std::move(connect_delegate),
114       timer ? std::move(timer) : std::make_unique<base::OneShotTimer>(),
115       std::move(api_delegate));
116 }
117 
118 std::vector<HeaderKeyValuePair>
RequestHeadersToVector(const HttpRequestHeaders & headers)119 WebSocketStreamCreateTestBase::RequestHeadersToVector(
120     const HttpRequestHeaders& headers) {
121   HttpRequestHeaders::Iterator it(headers);
122   std::vector<HeaderKeyValuePair> result;
123   while (it.GetNext())
124     result.push_back(HeaderKeyValuePair(it.name(), it.value()));
125   return result;
126 }
127 
128 std::vector<HeaderKeyValuePair>
ResponseHeadersToVector(const HttpResponseHeaders & headers)129 WebSocketStreamCreateTestBase::ResponseHeadersToVector(
130     const HttpResponseHeaders& headers) {
131   size_t iter = 0;
132   std::string name, value;
133   std::vector<HeaderKeyValuePair> result;
134   while (headers.EnumerateHeaderLines(&iter, &name, &value))
135     result.push_back(HeaderKeyValuePair(name, value));
136   return result;
137 }
138 
WaitUntilConnectDone()139 void WebSocketStreamCreateTestBase::WaitUntilConnectDone() {
140   connect_run_loop_.Run();
141 }
142 
WaitUntilOnAuthRequired()143 void WebSocketStreamCreateTestBase::WaitUntilOnAuthRequired() {
144   run_loop_waiting_for_on_auth_required_.Run();
145 }
146 
NoSubProtocols()147 std::vector<std::string> WebSocketStreamCreateTestBase::NoSubProtocols() {
148   return std::vector<std::string>();
149 }
150 
151 }  // namespace net
152