1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6
7 #include <string>
8 #include <utility>
9 #include <vector>
10
11 #include "base/macros.h"
12 #include "base/memory/scoped_refptr.h"
13 #include "base/optional.h"
14 #include "net/base/completion_once_callback.h"
15 #include "net/base/host_port_pair.h"
16 #include "net/base/ip_endpoint.h"
17 #include "net/base/load_flags.h"
18 #include "net/base/net_errors.h"
19 #include "net/base/privacy_mode.h"
20 #include "net/base/proxy_server.h"
21 #include "net/http/http_network_session.h"
22 #include "net/http/http_request_headers.h"
23 #include "net/http/http_request_info.h"
24 #include "net/http/http_response_headers.h"
25 #include "net/http/http_response_info.h"
26 #include "net/log/net_log_with_source.h"
27 #include "net/socket/client_socket_handle.h"
28 #include "net/socket/connect_job.h"
29 #include "net/socket/socket_tag.h"
30 #include "net/socket/socket_test_util.h"
31 #include "net/socket/ssl_client_socket.h"
32 #include "net/socket/websocket_endpoint_lock_manager.h"
33 #include "net/spdy/spdy_session.h"
34 #include "net/spdy/spdy_session_key.h"
35 #include "net/spdy/spdy_test_util_common.h"
36 #include "net/ssl/ssl_config.h"
37 #include "net/ssl/ssl_info.h"
38 #include "net/test/cert_test_util.h"
39 #include "net/test/gtest_util.h"
40 #include "net/test/test_data_directory.h"
41 #include "net/test/test_with_task_environment.h"
42 #include "net/traffic_annotation/network_traffic_annotation.h"
43 #include "net/websockets/websocket_basic_handshake_stream.h"
44 #include "net/websockets/websocket_stream.h"
45 #include "net/websockets/websocket_test_util.h"
46 #include "testing/gmock/include/gmock/gmock.h"
47 #include "testing/gtest/include/gtest/gtest.h"
48 #include "url/gurl.h"
49 #include "url/origin.h"
50
51 using ::net::test::IsError;
52 using ::net::test::IsOk;
53 using ::testing::StrictMock;
54 using ::testing::TestWithParam;
55 using ::testing::Values;
56 using ::testing::_;
57
58 namespace net {
59 namespace {
60
// Selects which handshake stream implementation a parameterized test
// instance exercises.
enum HandshakeStreamType {
  BASIC_HANDSHAKE_STREAM,
  HTTP2_HANDSHAKE_STREAM,
};
62
63 // This class encapsulates the details of creating a mock ClientSocketHandle.
64 class MockClientSocketHandleFactory {
65 public:
MockClientSocketHandleFactory()66 MockClientSocketHandleFactory()
67 : common_connect_job_params_(
68 socket_factory_maker_.factory(),
69 nullptr /* host_resolver */,
70 nullptr /* http_auth_cache */,
71 nullptr /* http_auth_handler_factory */,
72 nullptr /* spdy_session_pool */,
73 nullptr /* quic_supported_versions */,
74 nullptr /* quic_stream_factory */,
75 nullptr /* proxy_delegate */,
76 nullptr /* http_user_agent_settings */,
77 nullptr /* ssl_client_context */,
78 nullptr /* socket_performance_watcher_factory */,
79 nullptr /* network_quality_estimator */,
80 nullptr /* net_log */,
81 nullptr /* websocket_endpoint_lock_manager */),
82 pool_(1, 1, &common_connect_job_params_) {}
83
84 // The created socket expects |expect_written| to be written to the socket,
85 // and will respond with |return_to_read|. The test will fail if the expected
86 // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)87 std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
88 const std::string& expect_written,
89 const std::string& return_to_read) {
90 socket_factory_maker_.SetExpectations(expect_written, return_to_read);
91 auto socket_handle = std::make_unique<ClientSocketHandle>();
92 socket_handle->Init(
93 ClientSocketPool::GroupId(
94 HostPortPair("a", 80), ClientSocketPool::SocketType::kHttp,
95 PrivacyMode::PRIVACY_MODE_DISABLED, NetworkIsolationKey(),
96 false /* disable_secure_dns */),
97 scoped_refptr<ClientSocketPool::SocketParams>(),
98 base::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
99 ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
100 ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
101 return socket_handle;
102 }
103
104 private:
105 WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
106 const CommonConnectJobParams common_connect_job_params_;
107 MockTransportClientSocketPool pool_;
108
109 DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory);
110 };
111
112 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
113 public:
114 ~TestConnectDelegate() override = default;
115
OnCreateRequest(URLRequest * request)116 void OnCreateRequest(URLRequest* request) override {}
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)117 void OnSuccess(
118 std::unique_ptr<WebSocketStream> stream,
119 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
OnFailure(const std::string & failure_message)120 void OnFailure(const std::string& failure_message) override {}
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)121 void OnStartOpeningHandshake(
122 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)123 void OnSSLCertificateError(
124 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
125 ssl_error_callbacks,
126 int net_error,
127 const SSLInfo& ssl_info,
128 bool fatal) override {}
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,base::Optional<AuthCredentials> * credentials)129 int OnAuthRequired(const AuthChallengeInfo& auth_info,
130 scoped_refptr<HttpResponseHeaders> response_headers,
131 const IPEndPoint& host_port_pair,
132 base::OnceCallback<void(const AuthCredentials*)> callback,
133 base::Optional<AuthCredentials>* credentials) override {
134 *credentials = base::nullopt;
135 return OK;
136 }
137 };
138
139 class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
140 public:
141 ~MockWebSocketStreamRequestAPI() override = default;
142
143 MOCK_METHOD1(OnBasicHandshakeStreamCreated,
144 void(WebSocketBasicHandshakeStream* handshake_stream));
145 MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
146 void(WebSocketHttp2HandshakeStream* handshake_stream));
147 MOCK_METHOD1(OnFailure, void(const std::string& message));
148 };
149
150 class WebSocketHandshakeStreamCreateHelperTest
151 : public TestWithParam<HandshakeStreamType>,
152 public WithTaskEnvironment {
153 protected:
CreateAndInitializeStream(const std::vector<std::string> & sub_protocols,const WebSocketExtraHeaders & extra_request_headers,const WebSocketExtraHeaders & extra_response_headers)154 std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
155 const std::vector<std::string>& sub_protocols,
156 const WebSocketExtraHeaders& extra_request_headers,
157 const WebSocketExtraHeaders& extra_response_headers) {
158 const char kPath[] = "/";
159 const char kOrigin[] = "http://origin.example.org";
160 const GURL url("wss://www.example.org/");
161 NetLogWithSource net_log;
162
163 WebSocketHandshakeStreamCreateHelper create_helper(
164 &connect_delegate_, sub_protocols, &stream_request_);
165
166 switch (GetParam()) {
167 case BASIC_HANDSHAKE_STREAM:
168 EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
169 break;
170
171 case HTTP2_HANDSHAKE_STREAM:
172 EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
173 break;
174
175 default:
176 NOTREACHED();
177 }
178
179 EXPECT_CALL(stream_request_, OnFailure(_)).Times(0);
180
181 HttpRequestInfo request_info;
182 request_info.url = url;
183 request_info.method = "GET";
184 request_info.load_flags = LOAD_DISABLE_CACHE;
185 request_info.traffic_annotation =
186 MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
187
188 auto headers = WebSocketCommonTestHeaders();
189
190 switch (GetParam()) {
191 case BASIC_HANDSHAKE_STREAM: {
192 std::unique_ptr<ClientSocketHandle> socket_handle =
193 socket_handle_factory_.CreateClientSocketHandle(
194 WebSocketStandardRequest(
195 kPath, "www.example.org",
196 url::Origin::Create(GURL(kOrigin)), "",
197 WebSocketExtraHeadersToString(extra_request_headers)),
198 WebSocketStandardResponse(
199 WebSocketExtraHeadersToString(extra_response_headers)));
200
201 std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
202 create_helper.CreateBasicStream(std::move(socket_handle), false,
203 &websocket_endpoint_lock_manager_);
204
205 // If in future the implementation type returned by CreateBasicStream()
206 // changes, this static_cast will be wrong. However, in that case the
207 // test will fail and AddressSanitizer should identify the issue.
208 static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
209 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
210
211 int rv =
212 handshake->InitializeStream(&request_info, true, DEFAULT_PRIORITY,
213 net_log, CompletionOnceCallback());
214 EXPECT_THAT(rv, IsOk());
215
216 HttpResponseInfo response;
217 TestCompletionCallback request_callback;
218 rv = handshake->SendRequest(headers, &response,
219 request_callback.callback());
220 EXPECT_THAT(rv, IsOk());
221
222 TestCompletionCallback response_callback;
223 rv = handshake->ReadResponseHeaders(response_callback.callback());
224 EXPECT_THAT(rv, IsOk());
225 EXPECT_EQ(101, response.headers->response_code());
226 EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
227 EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
228 return handshake->Upgrade();
229 }
230 case HTTP2_HANDSHAKE_STREAM: {
231 SpdyTestUtil spdy_util;
232 spdy::SpdyHeaderBlock request_header_block = WebSocketHttp2Request(
233 kPath, "www.example.org", kOrigin, extra_request_headers);
234 spdy::SpdySerializedFrame request_headers(
235 spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
236 DEFAULT_PRIORITY, false));
237 MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
238
239 spdy::SpdyHeaderBlock response_header_block =
240 WebSocketHttp2Response(extra_response_headers);
241 spdy::SpdySerializedFrame response_headers(
242 spdy_util.ConstructSpdyResponseHeaders(
243 1, std::move(response_header_block), false));
244 MockRead reads[] = {CreateMockRead(response_headers, 1),
245 MockRead(ASYNC, 0, 2)};
246
247 SequencedSocketData data(reads, writes);
248
249 SSLSocketDataProvider ssl(ASYNC, OK);
250 ssl.ssl_info.cert =
251 ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
252
253 SpdySessionDependencies session_deps;
254 session_deps.socket_factory->AddSocketDataProvider(&data);
255 session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
256
257 std::unique_ptr<HttpNetworkSession> http_network_session =
258 SpdySessionDependencies::SpdyCreateSession(&session_deps);
259 const SpdySessionKey key(
260 HostPortPair::FromURL(url), ProxyServer::Direct(),
261 PRIVACY_MODE_DISABLED, SpdySessionKey::IsProxySession::kFalse,
262 SocketTag(), NetworkIsolationKey(), false /* disable_secure_dns */);
263 base::WeakPtr<SpdySession> spdy_session =
264 CreateSpdySession(http_network_session.get(), key, net_log);
265 std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
266 create_helper.CreateHttp2Stream(spdy_session);
267
268 int rv = handshake->InitializeStream(
269 &request_info, true, DEFAULT_PRIORITY, NetLogWithSource(),
270 CompletionOnceCallback());
271 EXPECT_THAT(rv, IsOk());
272
273 HttpResponseInfo response;
274 TestCompletionCallback request_callback;
275 rv = handshake->SendRequest(headers, &response,
276 request_callback.callback());
277 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
278 rv = request_callback.WaitForResult();
279 EXPECT_THAT(rv, IsOk());
280
281 TestCompletionCallback response_callback;
282 rv = handshake->ReadResponseHeaders(response_callback.callback());
283 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
284 rv = response_callback.WaitForResult();
285 EXPECT_THAT(rv, IsOk());
286
287 EXPECT_EQ(200, response.headers->response_code());
288 return handshake->Upgrade();
289 }
290 default:
291 NOTREACHED();
292 return nullptr;
293 }
294 }
295
296 private:
297 MockClientSocketHandleFactory socket_handle_factory_;
298 TestConnectDelegate connect_delegate_;
299 StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
300 WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
301 };
302
303 INSTANTIATE_TEST_SUITE_P(All,
304 WebSocketHandshakeStreamCreateHelperTest,
305 Values(BASIC_HANDSHAKE_STREAM,
306 HTTP2_HANDSHAKE_STREAM));
307
308 // Confirm that the basic case works as expected.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,BasicStream)309 TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
310 std::unique_ptr<WebSocketStream> stream =
311 CreateAndInitializeStream({}, {}, {});
312 EXPECT_EQ("", stream->GetExtensions());
313 EXPECT_EQ("", stream->GetSubProtocol());
314 }
315
316 // Verify that the sub-protocols are passed through.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)317 TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
318 std::vector<std::string> sub_protocols;
319 sub_protocols.push_back("chat");
320 sub_protocols.push_back("superchat");
321 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
322 sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
323 {{"Sec-WebSocket-Protocol", "superchat"}});
324 EXPECT_EQ("superchat", stream->GetSubProtocol());
325 }
326
327 // Verify that extension name is available. Bad extension names are tested in
328 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,Extensions)329 TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
330 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
331 {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
332 EXPECT_EQ("permessage-deflate", stream->GetExtensions());
333 }
334
335 // Verify that extension parameters are available. Bad parameters are tested in
336 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)337 TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
338 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
339 {}, {},
340 {{"Sec-WebSocket-Extensions",
341 "permessage-deflate;"
342 " client_max_window_bits=14; server_max_window_bits=14;"
343 " server_no_context_takeover; client_no_context_takeover"}});
344
345 EXPECT_EQ(
346 "permessage-deflate;"
347 " client_max_window_bits=14; server_max_window_bits=14;"
348 " server_no_context_takeover; client_no_context_takeover",
349 stream->GetExtensions());
350 }
351
352 } // namespace
353 } // namespace net
354