1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6 
7 #include <string>
8 #include <utility>
9 #include <vector>
10 
11 #include "base/macros.h"
12 #include "base/memory/scoped_refptr.h"
13 #include "base/optional.h"
14 #include "net/base/completion_once_callback.h"
15 #include "net/base/host_port_pair.h"
16 #include "net/base/ip_endpoint.h"
17 #include "net/base/load_flags.h"
18 #include "net/base/net_errors.h"
19 #include "net/base/privacy_mode.h"
20 #include "net/base/proxy_server.h"
21 #include "net/http/http_network_session.h"
22 #include "net/http/http_request_headers.h"
23 #include "net/http/http_request_info.h"
24 #include "net/http/http_response_headers.h"
25 #include "net/http/http_response_info.h"
26 #include "net/log/net_log_with_source.h"
27 #include "net/socket/client_socket_handle.h"
28 #include "net/socket/connect_job.h"
29 #include "net/socket/socket_tag.h"
30 #include "net/socket/socket_test_util.h"
31 #include "net/socket/ssl_client_socket.h"
32 #include "net/socket/websocket_endpoint_lock_manager.h"
33 #include "net/spdy/spdy_session.h"
34 #include "net/spdy/spdy_session_key.h"
35 #include "net/spdy/spdy_test_util_common.h"
36 #include "net/ssl/ssl_config.h"
37 #include "net/ssl/ssl_info.h"
38 #include "net/test/cert_test_util.h"
39 #include "net/test/gtest_util.h"
40 #include "net/test/test_data_directory.h"
41 #include "net/test/test_with_task_environment.h"
42 #include "net/traffic_annotation/network_traffic_annotation.h"
43 #include "net/websockets/websocket_basic_handshake_stream.h"
44 #include "net/websockets/websocket_stream.h"
45 #include "net/websockets/websocket_test_util.h"
46 #include "testing/gmock/include/gmock/gmock.h"
47 #include "testing/gtest/include/gtest/gtest.h"
48 #include "url/gurl.h"
49 #include "url/origin.h"
50 
51 using ::net::test::IsError;
52 using ::net::test::IsOk;
53 using ::testing::StrictMock;
54 using ::testing::TestWithParam;
55 using ::testing::Values;
56 using ::testing::_;
57 
58 namespace net {
59 namespace {
60 
// Selects which kind of handshake stream a parameterized test instance drives.
enum HandshakeStreamType {
  // A WebSocketBasicHandshakeStream built on a mocked ClientSocketHandle.
  BASIC_HANDSHAKE_STREAM,
  // A WebSocketHttp2HandshakeStream built on a mocked SpdySession.
  HTTP2_HANDSHAKE_STREAM
};
62 
63 // This class encapsulates the details of creating a mock ClientSocketHandle.
64 class MockClientSocketHandleFactory {
65  public:
MockClientSocketHandleFactory()66   MockClientSocketHandleFactory()
67       : common_connect_job_params_(
68             socket_factory_maker_.factory(),
69             nullptr /* host_resolver */,
70             nullptr /* http_auth_cache */,
71             nullptr /* http_auth_handler_factory */,
72             nullptr /* spdy_session_pool */,
73             nullptr /* quic_supported_versions */,
74             nullptr /* quic_stream_factory */,
75             nullptr /* proxy_delegate */,
76             nullptr /* http_user_agent_settings */,
77             nullptr /* ssl_client_context */,
78             nullptr /* socket_performance_watcher_factory */,
79             nullptr /* network_quality_estimator */,
80             nullptr /* net_log */,
81             nullptr /* websocket_endpoint_lock_manager */),
82         pool_(1, 1, &common_connect_job_params_) {}
83 
84   // The created socket expects |expect_written| to be written to the socket,
85   // and will respond with |return_to_read|. The test will fail if the expected
86   // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)87   std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
88       const std::string& expect_written,
89       const std::string& return_to_read) {
90     socket_factory_maker_.SetExpectations(expect_written, return_to_read);
91     auto socket_handle = std::make_unique<ClientSocketHandle>();
92     socket_handle->Init(
93         ClientSocketPool::GroupId(
94             HostPortPair("a", 80), ClientSocketPool::SocketType::kHttp,
95             PrivacyMode::PRIVACY_MODE_DISABLED, NetworkIsolationKey(),
96             false /* disable_secure_dns */),
97         scoped_refptr<ClientSocketPool::SocketParams>(),
98         base::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
99         ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
100         ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
101     return socket_handle;
102   }
103 
104  private:
105   WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
106   const CommonConnectJobParams common_connect_job_params_;
107   MockTransportClientSocketPool pool_;
108 
109   DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory);
110 };
111 
112 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
113  public:
114   ~TestConnectDelegate() override = default;
115 
OnCreateRequest(URLRequest * request)116   void OnCreateRequest(URLRequest* request) override {}
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)117   void OnSuccess(
118       std::unique_ptr<WebSocketStream> stream,
119       std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
OnFailure(const std::string & failure_message)120   void OnFailure(const std::string& failure_message) override {}
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)121   void OnStartOpeningHandshake(
122       std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)123   void OnSSLCertificateError(
124       std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
125           ssl_error_callbacks,
126       int net_error,
127       const SSLInfo& ssl_info,
128       bool fatal) override {}
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,base::Optional<AuthCredentials> * credentials)129   int OnAuthRequired(const AuthChallengeInfo& auth_info,
130                      scoped_refptr<HttpResponseHeaders> response_headers,
131                      const IPEndPoint& host_port_pair,
132                      base::OnceCallback<void(const AuthCredentials*)> callback,
133                      base::Optional<AuthCredentials>* credentials) override {
134     *credentials = base::nullopt;
135     return OK;
136   }
137 };
138 
139 class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
140  public:
141   ~MockWebSocketStreamRequestAPI() override = default;
142 
143   MOCK_METHOD1(OnBasicHandshakeStreamCreated,
144                void(WebSocketBasicHandshakeStream* handshake_stream));
145   MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
146                void(WebSocketHttp2HandshakeStream* handshake_stream));
147   MOCK_METHOD1(OnFailure, void(const std::string& message));
148 };
149 
150 class WebSocketHandshakeStreamCreateHelperTest
151     : public TestWithParam<HandshakeStreamType>,
152       public WithTaskEnvironment {
153  protected:
CreateAndInitializeStream(const std::vector<std::string> & sub_protocols,const WebSocketExtraHeaders & extra_request_headers,const WebSocketExtraHeaders & extra_response_headers)154   std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
155       const std::vector<std::string>& sub_protocols,
156       const WebSocketExtraHeaders& extra_request_headers,
157       const WebSocketExtraHeaders& extra_response_headers) {
158     const char kPath[] = "/";
159     const char kOrigin[] = "http://origin.example.org";
160     const GURL url("wss://www.example.org/");
161     NetLogWithSource net_log;
162 
163     WebSocketHandshakeStreamCreateHelper create_helper(
164         &connect_delegate_, sub_protocols, &stream_request_);
165 
166     switch (GetParam()) {
167       case BASIC_HANDSHAKE_STREAM:
168         EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
169         break;
170 
171       case HTTP2_HANDSHAKE_STREAM:
172         EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
173         break;
174 
175       default:
176         NOTREACHED();
177     }
178 
179     EXPECT_CALL(stream_request_, OnFailure(_)).Times(0);
180 
181     HttpRequestInfo request_info;
182     request_info.url = url;
183     request_info.method = "GET";
184     request_info.load_flags = LOAD_DISABLE_CACHE;
185     request_info.traffic_annotation =
186         MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
187 
188     auto headers = WebSocketCommonTestHeaders();
189 
190     switch (GetParam()) {
191       case BASIC_HANDSHAKE_STREAM: {
192         std::unique_ptr<ClientSocketHandle> socket_handle =
193             socket_handle_factory_.CreateClientSocketHandle(
194                 WebSocketStandardRequest(
195                     kPath, "www.example.org",
196                     url::Origin::Create(GURL(kOrigin)), "",
197                     WebSocketExtraHeadersToString(extra_request_headers)),
198                 WebSocketStandardResponse(
199                     WebSocketExtraHeadersToString(extra_response_headers)));
200 
201         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
202             create_helper.CreateBasicStream(std::move(socket_handle), false,
203                                             &websocket_endpoint_lock_manager_);
204 
205         // If in future the implementation type returned by CreateBasicStream()
206         // changes, this static_cast will be wrong. However, in that case the
207         // test will fail and AddressSanitizer should identify the issue.
208         static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
209             ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
210 
211         int rv =
212             handshake->InitializeStream(&request_info, true, DEFAULT_PRIORITY,
213                                         net_log, CompletionOnceCallback());
214         EXPECT_THAT(rv, IsOk());
215 
216         HttpResponseInfo response;
217         TestCompletionCallback request_callback;
218         rv = handshake->SendRequest(headers, &response,
219                                     request_callback.callback());
220         EXPECT_THAT(rv, IsOk());
221 
222         TestCompletionCallback response_callback;
223         rv = handshake->ReadResponseHeaders(response_callback.callback());
224         EXPECT_THAT(rv, IsOk());
225         EXPECT_EQ(101, response.headers->response_code());
226         EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
227         EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
228         return handshake->Upgrade();
229       }
230       case HTTP2_HANDSHAKE_STREAM: {
231         SpdyTestUtil spdy_util;
232         spdy::SpdyHeaderBlock request_header_block = WebSocketHttp2Request(
233             kPath, "www.example.org", kOrigin, extra_request_headers);
234         spdy::SpdySerializedFrame request_headers(
235             spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
236                                            DEFAULT_PRIORITY, false));
237         MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
238 
239         spdy::SpdyHeaderBlock response_header_block =
240             WebSocketHttp2Response(extra_response_headers);
241         spdy::SpdySerializedFrame response_headers(
242             spdy_util.ConstructSpdyResponseHeaders(
243                 1, std::move(response_header_block), false));
244         MockRead reads[] = {CreateMockRead(response_headers, 1),
245                             MockRead(ASYNC, 0, 2)};
246 
247         SequencedSocketData data(reads, writes);
248 
249         SSLSocketDataProvider ssl(ASYNC, OK);
250         ssl.ssl_info.cert =
251             ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
252 
253         SpdySessionDependencies session_deps;
254         session_deps.socket_factory->AddSocketDataProvider(&data);
255         session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
256 
257         std::unique_ptr<HttpNetworkSession> http_network_session =
258             SpdySessionDependencies::SpdyCreateSession(&session_deps);
259         const SpdySessionKey key(
260             HostPortPair::FromURL(url), ProxyServer::Direct(),
261             PRIVACY_MODE_DISABLED, SpdySessionKey::IsProxySession::kFalse,
262             SocketTag(), NetworkIsolationKey(), false /* disable_secure_dns */);
263         base::WeakPtr<SpdySession> spdy_session =
264             CreateSpdySession(http_network_session.get(), key, net_log);
265         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
266             create_helper.CreateHttp2Stream(spdy_session);
267 
268         int rv = handshake->InitializeStream(
269             &request_info, true, DEFAULT_PRIORITY, NetLogWithSource(),
270             CompletionOnceCallback());
271         EXPECT_THAT(rv, IsOk());
272 
273         HttpResponseInfo response;
274         TestCompletionCallback request_callback;
275         rv = handshake->SendRequest(headers, &response,
276                                     request_callback.callback());
277         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
278         rv = request_callback.WaitForResult();
279         EXPECT_THAT(rv, IsOk());
280 
281         TestCompletionCallback response_callback;
282         rv = handshake->ReadResponseHeaders(response_callback.callback());
283         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
284         rv = response_callback.WaitForResult();
285         EXPECT_THAT(rv, IsOk());
286 
287         EXPECT_EQ(200, response.headers->response_code());
288         return handshake->Upgrade();
289       }
290       default:
291         NOTREACHED();
292         return nullptr;
293     }
294   }
295 
296  private:
297   MockClientSocketHandleFactory socket_handle_factory_;
298   TestConnectDelegate connect_delegate_;
299   StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
300   WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
301 };
302 
303 INSTANTIATE_TEST_SUITE_P(All,
304                          WebSocketHandshakeStreamCreateHelperTest,
305                          Values(BASIC_HANDSHAKE_STREAM,
306                                 HTTP2_HANDSHAKE_STREAM));
307 
308 // Confirm that the basic case works as expected.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,BasicStream)309 TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
310   std::unique_ptr<WebSocketStream> stream =
311       CreateAndInitializeStream({}, {}, {});
312   EXPECT_EQ("", stream->GetExtensions());
313   EXPECT_EQ("", stream->GetSubProtocol());
314 }
315 
316 // Verify that the sub-protocols are passed through.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)317 TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
318   std::vector<std::string> sub_protocols;
319   sub_protocols.push_back("chat");
320   sub_protocols.push_back("superchat");
321   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
322       sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
323       {{"Sec-WebSocket-Protocol", "superchat"}});
324   EXPECT_EQ("superchat", stream->GetSubProtocol());
325 }
326 
327 // Verify that extension name is available. Bad extension names are tested in
328 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,Extensions)329 TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
330   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
331       {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
332   EXPECT_EQ("permessage-deflate", stream->GetExtensions());
333 }
334 
335 // Verify that extension parameters are available. Bad parameters are tested in
336 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)337 TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
338   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
339       {}, {},
340       {{"Sec-WebSocket-Extensions",
341         "permessage-deflate;"
342         " client_max_window_bits=14; server_max_window_bits=14;"
343         " server_no_context_takeover; client_no_context_takeover"}});
344 
345   EXPECT_EQ(
346       "permessage-deflate;"
347       " client_max_window_bits=14; server_max_window_bits=14;"
348       " server_no_context_takeover; client_no_context_takeover",
349       stream->GetExtensions());
350 }
351 
352 }  // namespace
353 }  // namespace net
354