1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "components/safe_browsing/content/renderer/websocket_sb_handshake_throttle.h"
6 
7 #include <utility>
8 
9 #include "base/callback.h"
10 #include "base/run_loop.h"
11 #include "base/test/task_environment.h"
12 #include "components/safe_browsing/content/common/safe_browsing.mojom.h"
13 #include "components/safe_browsing/core/common/safe_browsing_url_checker.mojom.h"
14 #include "ipc/ipc_message.h"
15 #include "mojo/public/cpp/bindings/pending_receiver.h"
16 #include "mojo/public/cpp/bindings/receiver.h"
17 #include "mojo/public/cpp/bindings/remote.h"
18 #include "net/http/http_request_headers.h"
19 #include "testing/gmock/include/gmock/gmock.h"
20 #include "testing/gtest/include/gtest/gtest.h"
21 #include "third_party/blink/public/mojom/loader/resource_load_info.mojom-shared.h"
22 #include "third_party/blink/public/platform/web_string.h"
23 #include "third_party/blink/public/platform/web_url.h"
24 
25 namespace safe_browsing {
26 
27 namespace {
28 
// WebSocket URL handed to the throttle in every test below.
constexpr char kTestUrl[]{"wss://test/"};
30 
31 class FakeSafeBrowsing : public mojom::SafeBrowsing {
32  public:
FakeSafeBrowsing()33   FakeSafeBrowsing()
34       : render_frame_id_(),
35         load_flags_(-1),
36         resource_type_(),
37         has_user_gesture_(false),
38         originated_from_service_worker_(false) {}
39 
CreateCheckerAndCheck(int32_t render_frame_id,mojo::PendingReceiver<mojom::SafeBrowsingUrlChecker> receiver,const GURL & url,const std::string & method,const net::HttpRequestHeaders & headers,int32_t load_flags,blink::mojom::ResourceType resource_type,bool has_user_gesture,bool originated_from_service_worker,CreateCheckerAndCheckCallback callback)40   void CreateCheckerAndCheck(
41       int32_t render_frame_id,
42       mojo::PendingReceiver<mojom::SafeBrowsingUrlChecker> receiver,
43       const GURL& url,
44       const std::string& method,
45       const net::HttpRequestHeaders& headers,
46       int32_t load_flags,
47       blink::mojom::ResourceType resource_type,
48       bool has_user_gesture,
49       bool originated_from_service_worker,
50       CreateCheckerAndCheckCallback callback) override {
51     render_frame_id_ = render_frame_id;
52     receiver_ = std::move(receiver);
53     url_ = url;
54     method_ = method;
55     headers_ = headers;
56     load_flags_ = load_flags;
57     resource_type_ = resource_type;
58     has_user_gesture_ = has_user_gesture;
59     originated_from_service_worker_ = originated_from_service_worker;
60     callback_ = std::move(callback);
61     run_loop_.Quit();
62   }
63 
Clone(mojo::PendingReceiver<mojom::SafeBrowsing> receiver)64   void Clone(mojo::PendingReceiver<mojom::SafeBrowsing> receiver) override {
65     NOTREACHED();
66   }
67 
RunUntilCalled()68   void RunUntilCalled() { run_loop_.Run(); }
69 
70   int32_t render_frame_id_;
71   mojo::PendingReceiver<mojom::SafeBrowsingUrlChecker> receiver_;
72   GURL url_;
73   std::string method_;
74   net::HttpRequestHeaders headers_;
75   int32_t load_flags_;
76   blink::mojom::ResourceType resource_type_;
77   bool has_user_gesture_;
78   bool originated_from_service_worker_;
79   CreateCheckerAndCheckCallback callback_;
80   base::RunLoop run_loop_;
81 };
82 
83 class FakeCallback {
84  public:
85   enum Result { RESULT_NOT_CALLED, RESULT_SUCCESS, RESULT_ERROR };
86 
FakeCallback()87   FakeCallback() : result_(RESULT_NOT_CALLED) {}
88 
OnCompletion(const base::Optional<blink::WebString> & message)89   void OnCompletion(const base::Optional<blink::WebString>& message) {
90     if (message) {
91       result_ = RESULT_ERROR;
92       message_ = *message;
93       run_loop_.Quit();
94       return;
95     }
96 
97     result_ = RESULT_SUCCESS;
98     run_loop_.Quit();
99   }
100 
RunUntilCalled()101   void RunUntilCalled() { run_loop_.Run(); }
102 
RunUntilIdle()103   void RunUntilIdle() { base::RunLoop().RunUntilIdle(); }
104 
105   Result result_;
106   blink::WebString message_;
107   base::RunLoop run_loop_;
108 };
109 
110 class WebSocketSBHandshakeThrottleTest : public ::testing::Test {
111  protected:
WebSocketSBHandshakeThrottleTest()112   WebSocketSBHandshakeThrottleTest() : mojo_receiver_(&safe_browsing_) {
113     mojo_receiver_.Bind(safe_browsing_remote_.BindNewPipeAndPassReceiver());
114     throttle_ = std::make_unique<WebSocketSBHandshakeThrottle>(
115         safe_browsing_remote_.get(), MSG_ROUTING_NONE);
116   }
117 
118   base::test::TaskEnvironment message_loop_;
119   FakeSafeBrowsing safe_browsing_;
120   mojo::Receiver<mojom::SafeBrowsing> mojo_receiver_;
121   mojo::Remote<mojom::SafeBrowsing> safe_browsing_remote_;
122   std::unique_ptr<WebSocketSBHandshakeThrottle> throttle_;
123   FakeCallback fake_callback_;
124 };
125 
// Building and tearing down the fixture (throttle plus mojo wiring) must not
// crash, even if no handshake is ever throttled.
TEST_F(WebSocketSBHandshakeThrottleTest, Construction) {
  // Intentionally empty: the fixture constructor/destructor are under test.
}
127 
// Verifies the exact arguments the throttle forwards to
// SafeBrowsing::CreateCheckerAndCheck() for a WebSocket handshake.
TEST_F(WebSocketSBHandshakeThrottleTest, CheckArguments) {
  const GURL url(kTestUrl);
  throttle_->ThrottleHandshake(
      url, base::BindOnce(&FakeCallback::OnCompletion,
                          base::Unretained(&fake_callback_)));
  safe_browsing_.RunUntilCalled();

  const FakeSafeBrowsing& checked = safe_browsing_;
  EXPECT_EQ(MSG_ROUTING_NONE, checked.render_frame_id_);
  EXPECT_EQ(url, checked.url_);
  EXPECT_EQ("GET", checked.method_);
  EXPECT_TRUE(checked.headers_.GetHeaderVector().empty());
  EXPECT_EQ(0, checked.load_flags_);
  EXPECT_EQ(blink::mojom::ResourceType::kSubResource, checked.resource_type_);
  EXPECT_FALSE(checked.has_user_gesture_);
  EXPECT_FALSE(checked.originated_from_service_worker_);
  EXPECT_TRUE(checked.callback_);
}
144 
// A URL reported as safe lets the handshake proceed without an error message.
TEST_F(WebSocketSBHandshakeThrottleTest, Safe) {
  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(GURL(kTestUrl), std::move(on_completion));
  safe_browsing_.RunUntilCalled();

  // Answer synchronously (no slow-check notifier) with "proceed".
  constexpr bool kProceed = true;
  constexpr bool kShowedInterstitial = false;
  std::move(safe_browsing_.callback_)
      .Run(mojo::NullReceiver(), kProceed, kShowedInterstitial);

  fake_callback_.RunUntilCalled();
  EXPECT_EQ(FakeCallback::RESULT_SUCCESS, fake_callback_.result_);
}
154 
// A URL reported as unsafe fails the handshake with a descriptive message.
TEST_F(WebSocketSBHandshakeThrottleTest, Unsafe) {
  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(GURL(kTestUrl), std::move(on_completion));
  safe_browsing_.RunUntilCalled();

  // Answer synchronously (no slow-check notifier) with "do not proceed".
  constexpr bool kProceed = false;
  constexpr bool kShowedInterstitial = false;
  std::move(safe_browsing_.callback_)
      .Run(mojo::NullReceiver(), kProceed, kShowedInterstitial);

  fake_callback_.RunUntilCalled();
  EXPECT_EQ(FakeCallback::RESULT_ERROR, fake_callback_.result_);

  const blink::WebString expected_message(
      "WebSocket connection to wss://test/ failed safe browsing check");
  EXPECT_EQ(expected_message, fake_callback_.message_);
}
168 
// When SafeBrowsing answers with a slow-check notifier, the throttle must
// ignore the provisional verdict and wait for OnCompleteCheck().
TEST_F(WebSocketSBHandshakeThrottleTest, SlowCheckNotifier) {
  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(GURL(kTestUrl), std::move(on_completion));
  safe_browsing_.RunUntilCalled();

  mojo::Remote<mojom::UrlCheckNotifier> slow_check_notifier;
  auto notifier_receiver = slow_check_notifier.BindNewPipeAndPassReceiver();
  std::move(safe_browsing_.callback_)
      .Run(std::move(notifier_receiver), /*proceed=*/false,
           /*showed_interstitial=*/false);

  // Provisional answer only: the completion callback must not fire yet.
  fake_callback_.RunUntilIdle();
  EXPECT_EQ(FakeCallback::RESULT_NOT_CALLED, fake_callback_.result_);

  // The final verdict arrives through the notifier.
  slow_check_notifier->OnCompleteCheck(/*proceed=*/true,
                                       /*showed_interstitial=*/false);
  fake_callback_.RunUntilCalled();
  EXPECT_EQ(FakeCallback::RESULT_SUCCESS, fake_callback_.result_);
}
185 
// If the SafeBrowsing service endpoint is gone, the throttle fails open and
// reports success rather than blocking the handshake forever.
TEST_F(WebSocketSBHandshakeThrottleTest, MojoServiceNotThere) {
  // Drop the service side of the pipe before any check is issued.
  mojo_receiver_.reset();

  auto on_completion = base::BindOnce(&FakeCallback::OnCompletion,
                                      base::Unretained(&fake_callback_));
  throttle_->ThrottleHandshake(GURL(kTestUrl), std::move(on_completion));

  fake_callback_.RunUntilCalled();
  EXPECT_EQ(FakeCallback::RESULT_SUCCESS, fake_callback_.result_);
}
194 
195 }  // namespace
196 
197 }  // namespace safe_browsing
198