1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "components/safe_browsing/content/renderer/websocket_sb_handshake_throttle.h"
6
7 #include <utility>
8
9 #include "base/callback.h"
10 #include "base/run_loop.h"
11 #include "base/test/task_environment.h"
12 #include "components/safe_browsing/content/common/safe_browsing.mojom.h"
13 #include "components/safe_browsing/core/common/safe_browsing_url_checker.mojom.h"
14 #include "ipc/ipc_message.h"
15 #include "mojo/public/cpp/bindings/pending_receiver.h"
16 #include "mojo/public/cpp/bindings/receiver.h"
17 #include "mojo/public/cpp/bindings/remote.h"
18 #include "net/http/http_request_headers.h"
19 #include "testing/gmock/include/gmock/gmock.h"
20 #include "testing/gtest/include/gtest/gtest.h"
21 #include "third_party/blink/public/mojom/loader/resource_load_info.mojom-shared.h"
22 #include "third_party/blink/public/platform/web_string.h"
23 #include "third_party/blink/public/platform/web_url.h"
24
25 namespace safe_browsing {
26
27 namespace {
28
// URL handed to ThrottleHandshake() by the tests in this file. The Unsafe test
// spells the same URL out again inside its expected error message, so the two
// must be kept in sync.
constexpr char kTestUrl[] = "wss://test/";
30
31 class FakeSafeBrowsing : public mojom::SafeBrowsing {
32 public:
FakeSafeBrowsing()33 FakeSafeBrowsing()
34 : render_frame_id_(),
35 load_flags_(-1),
36 resource_type_(),
37 has_user_gesture_(false),
38 originated_from_service_worker_(false) {}
39
CreateCheckerAndCheck(int32_t render_frame_id,mojo::PendingReceiver<mojom::SafeBrowsingUrlChecker> receiver,const GURL & url,const std::string & method,const net::HttpRequestHeaders & headers,int32_t load_flags,blink::mojom::ResourceType resource_type,bool has_user_gesture,bool originated_from_service_worker,CreateCheckerAndCheckCallback callback)40 void CreateCheckerAndCheck(
41 int32_t render_frame_id,
42 mojo::PendingReceiver<mojom::SafeBrowsingUrlChecker> receiver,
43 const GURL& url,
44 const std::string& method,
45 const net::HttpRequestHeaders& headers,
46 int32_t load_flags,
47 blink::mojom::ResourceType resource_type,
48 bool has_user_gesture,
49 bool originated_from_service_worker,
50 CreateCheckerAndCheckCallback callback) override {
51 render_frame_id_ = render_frame_id;
52 receiver_ = std::move(receiver);
53 url_ = url;
54 method_ = method;
55 headers_ = headers;
56 load_flags_ = load_flags;
57 resource_type_ = resource_type;
58 has_user_gesture_ = has_user_gesture;
59 originated_from_service_worker_ = originated_from_service_worker;
60 callback_ = std::move(callback);
61 run_loop_.Quit();
62 }
63
Clone(mojo::PendingReceiver<mojom::SafeBrowsing> receiver)64 void Clone(mojo::PendingReceiver<mojom::SafeBrowsing> receiver) override {
65 NOTREACHED();
66 }
67
RunUntilCalled()68 void RunUntilCalled() { run_loop_.Run(); }
69
70 int32_t render_frame_id_;
71 mojo::PendingReceiver<mojom::SafeBrowsingUrlChecker> receiver_;
72 GURL url_;
73 std::string method_;
74 net::HttpRequestHeaders headers_;
75 int32_t load_flags_;
76 blink::mojom::ResourceType resource_type_;
77 bool has_user_gesture_;
78 bool originated_from_service_worker_;
79 CreateCheckerAndCheckCallback callback_;
80 base::RunLoop run_loop_;
81 };
82
83 class FakeCallback {
84 public:
85 enum Result { RESULT_NOT_CALLED, RESULT_SUCCESS, RESULT_ERROR };
86
FakeCallback()87 FakeCallback() : result_(RESULT_NOT_CALLED) {}
88
OnCompletion(const base::Optional<blink::WebString> & message)89 void OnCompletion(const base::Optional<blink::WebString>& message) {
90 if (message) {
91 result_ = RESULT_ERROR;
92 message_ = *message;
93 run_loop_.Quit();
94 return;
95 }
96
97 result_ = RESULT_SUCCESS;
98 run_loop_.Quit();
99 }
100
RunUntilCalled()101 void RunUntilCalled() { run_loop_.Run(); }
102
RunUntilIdle()103 void RunUntilIdle() { base::RunLoop().RunUntilIdle(); }
104
105 Result result_;
106 blink::WebString message_;
107 base::RunLoop run_loop_;
108 };
109
110 class WebSocketSBHandshakeThrottleTest : public ::testing::Test {
111 protected:
WebSocketSBHandshakeThrottleTest()112 WebSocketSBHandshakeThrottleTest() : mojo_receiver_(&safe_browsing_) {
113 mojo_receiver_.Bind(safe_browsing_remote_.BindNewPipeAndPassReceiver());
114 throttle_ = std::make_unique<WebSocketSBHandshakeThrottle>(
115 safe_browsing_remote_.get(), MSG_ROUTING_NONE);
116 }
117
118 base::test::TaskEnvironment message_loop_;
119 FakeSafeBrowsing safe_browsing_;
120 mojo::Receiver<mojom::SafeBrowsing> mojo_receiver_;
121 mojo::Remote<mojom::SafeBrowsing> safe_browsing_remote_;
122 std::unique_ptr<WebSocketSBHandshakeThrottle> throttle_;
123 FakeCallback fake_callback_;
124 };
125
// Intentionally empty: only the fixture's constructor and destructor run.
TEST_F(WebSocketSBHandshakeThrottleTest, Construction) {}
127
// Verifies the exact arguments the throttle passes to CreateCheckerAndCheck().
TEST_F(WebSocketSBHandshakeThrottleTest, CheckArguments) {
  const GURL test_url(kTestUrl);
  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(test_url, std::move(on_completion));

  // Wait until the fake has captured the request.
  safe_browsing_.RunUntilCalled();

  EXPECT_EQ(MSG_ROUTING_NONE, safe_browsing_.render_frame_id_);
  EXPECT_EQ(test_url, safe_browsing_.url_);
  EXPECT_EQ("GET", safe_browsing_.method_);
  EXPECT_TRUE(safe_browsing_.headers_.GetHeaderVector().empty());
  EXPECT_EQ(0, safe_browsing_.load_flags_);
  EXPECT_EQ(blink::mojom::ResourceType::kSubResource,
            safe_browsing_.resource_type_);
  EXPECT_FALSE(safe_browsing_.has_user_gesture_);
  EXPECT_FALSE(safe_browsing_.originated_from_service_worker_);
  EXPECT_TRUE(safe_browsing_.callback_);
}
144
// A reply of (null receiver, true, false) must complete the handshake without
// an error message.
TEST_F(WebSocketSBHandshakeThrottleTest, Safe) {
  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(GURL(kTestUrl), std::move(on_completion));
  safe_browsing_.RunUntilCalled();

  // Answer the captured request.
  std::move(safe_browsing_.callback_).Run(mojo::NullReceiver(), true, false);

  fake_callback_.RunUntilCalled();
  EXPECT_EQ(FakeCallback::RESULT_SUCCESS, fake_callback_.result_);
}
154
// A reply of (null receiver, false, false) must complete the handshake with
// the safe-browsing failure message for the requested URL.
TEST_F(WebSocketSBHandshakeThrottleTest, Unsafe) {
  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(GURL(kTestUrl), std::move(on_completion));
  safe_browsing_.RunUntilCalled();

  // Answer the captured request.
  std::move(safe_browsing_.callback_).Run(mojo::NullReceiver(), false, false);

  fake_callback_.RunUntilCalled();
  EXPECT_EQ(FakeCallback::RESULT_ERROR, fake_callback_.result_);
  EXPECT_EQ(
      blink::WebString(
          "WebSocket connection to wss://test/ failed safe browsing check"),
      fake_callback_.message_);
}
168
// When the reply carries a UrlCheckNotifier receiver, completion must be
// deferred until the notifier reports the outcome of the check.
TEST_F(WebSocketSBHandshakeThrottleTest, SlowCheckNotifier) {
  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(GURL(kTestUrl), std::move(on_completion));
  safe_browsing_.RunUntilCalled();

  // Reply with a live notifier pipe instead of a null receiver.
  mojo::Remote<mojom::UrlCheckNotifier> notifier;
  std::move(safe_browsing_.callback_)
      .Run(notifier.BindNewPipeAndPassReceiver(), false, false);

  // Nothing may have been reported to the completion callback yet.
  fake_callback_.RunUntilIdle();
  EXPECT_EQ(FakeCallback::RESULT_NOT_CALLED, fake_callback_.result_);

  // The notifier now delivers the result of the check.
  notifier->OnCompleteCheck(true, false);
  fake_callback_.RunUntilCalled();
  EXPECT_EQ(FakeCallback::RESULT_SUCCESS, fake_callback_.result_);
}
185
// With the receiving end of the pipe reset before the request is made, the
// handshake is still expected to complete successfully.
TEST_F(WebSocketSBHandshakeThrottleTest, MojoServiceNotThere) {
  mojo_receiver_.reset();

  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(GURL(kTestUrl), std::move(on_completion));

  fake_callback_.RunUntilCalled();
  EXPECT_EQ(FakeCallback::RESULT_SUCCESS, fake_callback_.result_);
}
194
195 } // namespace
196
197 } // namespace safe_browsing
198