1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_stream_create_test_base.h"
6
7 #include <utility>
8
9 #include "base/callback.h"
10 #include "base/macros.h"
11 #include "net/base/ip_endpoint.h"
12 #include "net/http/http_request_headers.h"
13 #include "net/http/http_response_headers.h"
14 #include "net/log/net_log_with_source.h"
15 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
16 #include "net/websockets/websocket_basic_handshake_stream.h"
17 #include "net/websockets/websocket_handshake_request_info.h"
18 #include "net/websockets/websocket_handshake_response_info.h"
19 #include "net/websockets/websocket_stream.h"
20 #include "url/gurl.h"
21 #include "url/origin.h"
22
23 namespace net {
24
// File-local shorthand for the header pair type declared on
// WebSocketStreamCreateTestBase; each element holds one header's name and
// value as produced by the *HeadersToVector() helpers in this file.
using HeaderKeyValuePair = WebSocketStreamCreateTestBase::HeaderKeyValuePair;
26
27 class WebSocketStreamCreateTestBase::TestConnectDelegate
28 : public WebSocketStream::ConnectDelegate {
29 public:
TestConnectDelegate(WebSocketStreamCreateTestBase * owner,base::OnceClosure done_callback)30 TestConnectDelegate(WebSocketStreamCreateTestBase* owner,
31 base::OnceClosure done_callback)
32 : owner_(owner), done_callback_(std::move(done_callback)) {}
33
OnCreateRequest(URLRequest * request)34 void OnCreateRequest(URLRequest* request) override {
35 owner_->url_request_ = request;
36 }
37
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)38 void OnSuccess(
39 std::unique_ptr<WebSocketStream> stream,
40 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
41 if (owner_->response_info_)
42 ADD_FAILURE();
43 owner_->response_info_ = std::move(response);
44 stream.swap(owner_->stream_);
45 std::move(done_callback_).Run();
46 }
47
OnFailure(const std::string & message,int net_error,base::Optional<int> response_code)48 void OnFailure(const std::string& message,
49 int net_error,
50 base::Optional<int> response_code) override {
51 owner_->has_failed_ = true;
52 owner_->failure_message_ = message;
53 owner_->failure_response_code_ = response_code.value_or(-1);
54 std::move(done_callback_).Run();
55 }
56
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)57 void OnStartOpeningHandshake(
58 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
59 // Can be called multiple times (in the case of HTTP auth). Last call
60 // wins.
61 owner_->request_info_ = std::move(request);
62 }
63
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)64 void OnSSLCertificateError(
65 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
66 ssl_error_callbacks,
67 int net_error,
68 const SSLInfo& ssl_info,
69 bool fatal) override {
70 owner_->ssl_error_callbacks_ = std::move(ssl_error_callbacks);
71 owner_->ssl_info_ = ssl_info;
72 owner_->ssl_fatal_ = fatal;
73 }
74
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,base::Optional<AuthCredentials> * credentials)75 int OnAuthRequired(const AuthChallengeInfo& auth_info,
76 scoped_refptr<HttpResponseHeaders> response_headers,
77 const IPEndPoint& remote_endpoint,
78 base::OnceCallback<void(const AuthCredentials*)> callback,
79 base::Optional<AuthCredentials>* credentials) override {
80 owner_->run_loop_waiting_for_on_auth_required_.Quit();
81 owner_->auth_challenge_info_ = auth_info;
82 *credentials = owner_->auth_credentials_;
83 owner_->on_auth_required_callback_ = std::move(callback);
84 return owner_->on_auth_required_rv_;
85 }
86
87 private:
88 WebSocketStreamCreateTestBase* owner_;
89 base::OnceClosure done_callback_;
90 DISALLOW_COPY_AND_ASSIGN(TestConnectDelegate);
91 };
92
WebSocketStreamCreateTestBase()93 WebSocketStreamCreateTestBase::WebSocketStreamCreateTestBase()
94 : has_failed_(false), ssl_fatal_(false), url_request_(nullptr) {}
95
96 WebSocketStreamCreateTestBase::~WebSocketStreamCreateTestBase() = default;
97
CreateAndConnectStream(const GURL & socket_url,const std::vector<std::string> & sub_protocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,std::unique_ptr<base::OneShotTimer> timer)98 void WebSocketStreamCreateTestBase::CreateAndConnectStream(
99 const GURL& socket_url,
100 const std::vector<std::string>& sub_protocols,
101 const url::Origin& origin,
102 const SiteForCookies& site_for_cookies,
103 const IsolationInfo& isolation_info,
104 const HttpRequestHeaders& additional_headers,
105 std::unique_ptr<base::OneShotTimer> timer) {
106 auto connect_delegate = std::make_unique<TestConnectDelegate>(
107 this, connect_run_loop_.QuitClosure());
108 auto api_delegate = std::make_unique<TestWebSocketStreamRequestAPI>();
109 stream_request_ = WebSocketStream::CreateAndConnectStreamForTesting(
110 socket_url, sub_protocols, origin, site_for_cookies, isolation_info,
111 additional_headers, url_request_context_host_.GetURLRequestContext(),
112 NetLogWithSource(), TRAFFIC_ANNOTATION_FOR_TESTS,
113 std::move(connect_delegate),
114 timer ? std::move(timer) : std::make_unique<base::OneShotTimer>(),
115 std::move(api_delegate));
116 }
117
118 std::vector<HeaderKeyValuePair>
RequestHeadersToVector(const HttpRequestHeaders & headers)119 WebSocketStreamCreateTestBase::RequestHeadersToVector(
120 const HttpRequestHeaders& headers) {
121 HttpRequestHeaders::Iterator it(headers);
122 std::vector<HeaderKeyValuePair> result;
123 while (it.GetNext())
124 result.push_back(HeaderKeyValuePair(it.name(), it.value()));
125 return result;
126 }
127
128 std::vector<HeaderKeyValuePair>
ResponseHeadersToVector(const HttpResponseHeaders & headers)129 WebSocketStreamCreateTestBase::ResponseHeadersToVector(
130 const HttpResponseHeaders& headers) {
131 size_t iter = 0;
132 std::string name, value;
133 std::vector<HeaderKeyValuePair> result;
134 while (headers.EnumerateHeaderLines(&iter, &name, &value))
135 result.push_back(HeaderKeyValuePair(name, value));
136 return result;
137 }
138
WaitUntilConnectDone()139 void WebSocketStreamCreateTestBase::WaitUntilConnectDone() {
140 connect_run_loop_.Run();
141 }
142
WaitUntilOnAuthRequired()143 void WebSocketStreamCreateTestBase::WaitUntilOnAuthRequired() {
144 run_loop_waiting_for_on_auth_required_.Run();
145 }
146
NoSubProtocols()147 std::vector<std::string> WebSocketStreamCreateTestBase::NoSubProtocols() {
148 return std::vector<std::string>();
149 }
150
151 } // namespace net
152