1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/safe_browsing/client_side_detection_host.h"
6
7 #include "base/run_loop.h"
8 #include "chrome/browser/profiles/profile.h"
9 #include "chrome/browser/safe_browsing/client_side_detection_service.h"
10 #include "chrome/browser/ui/browser.h"
11 #include "chrome/browser/ui/tabs/tab_strip_model.h"
12 #include "chrome/test/base/in_process_browser_test.h"
13 #include "chrome/test/base/ui_test_utils.h"
14 #include "components/prefs/pref_service.h"
15 #include "components/safe_browsing/buildflags.h"
16 #include "components/safe_browsing/core/proto/client_model.pb.h"
17 #include "content/public/test/browser_test.h"
18 #include "testing/gmock/include/gmock/gmock.h"
19 #include "testing/gtest/include/gtest/gtest.h"
20 #include "url/gurl.h"
21
22 namespace safe_browsing {
23 namespace {
24
25 using ::testing::_;
26 using ::testing::StrictMock;
27
28 class FakeClientSideDetectionService : public ClientSideDetectionService {
29 public:
FakeClientSideDetectionService()30 FakeClientSideDetectionService() : ClientSideDetectionService(nullptr) {}
31
SendClientReportPhishingRequest(std::unique_ptr<ClientPhishingRequest> verdict,bool is_extended_reporting,bool is_enhanced_protection,ClientReportPhishingRequestCallback callback)32 void SendClientReportPhishingRequest(
33 std::unique_ptr<ClientPhishingRequest> verdict,
34 bool is_extended_reporting,
35 bool is_enhanced_protection,
36 ClientReportPhishingRequestCallback callback) override {
37 saved_request_ = *verdict;
38 saved_callback_ = std::move(callback);
39 request_callback_.Run();
40 }
41
saved_request()42 const ClientPhishingRequest& saved_request() { return saved_request_; }
43
saved_callback_is_null()44 bool saved_callback_is_null() { return saved_callback_.is_null(); }
45
saved_callback()46 ClientReportPhishingRequestCallback saved_callback() {
47 return std::move(saved_callback_);
48 }
49
SetModel(const ClientSideModel & model)50 void SetModel(const ClientSideModel& model) { model_ = model; }
51
GetModelStr()52 std::string GetModelStr() override { return model_.SerializeAsString(); }
53
SetRequestCallback(const base::RepeatingClosure & closure)54 void SetRequestCallback(const base::RepeatingClosure& closure) {
55 request_callback_ = closure;
56 }
57
58 private:
59 ClientPhishingRequest saved_request_;
60 ClientReportPhishingRequestCallback saved_callback_;
61 ClientSideModel model_;
62 base::RepeatingClosure request_callback_;
63 };
64
65 class MockSafeBrowsingUIManager : public SafeBrowsingUIManager {
66 public:
MockSafeBrowsingUIManager()67 MockSafeBrowsingUIManager() : SafeBrowsingUIManager(nullptr) {}
68
69 MOCK_METHOD1(DisplayBlockingPage, void(const UnsafeResource& resource));
70
71 protected:
72 ~MockSafeBrowsingUIManager() override = default;
73
74 private:
75 DISALLOW_COPY_AND_ASSIGN(MockSafeBrowsingUIManager);
76 };
77
78 } // namespace
79
80 class ClientSideDetectionHostBrowserTest : public InProcessBrowserTest {
81 public:
82 ClientSideDetectionHostBrowserTest() = default;
83 ~ClientSideDetectionHostBrowserTest() override = default;
84 };
85
86 #if BUILDFLAG(FULL_SAFE_BROWSING)
// Verifies that visual features collected from a rendered page are matched
// against the model's vision targets, reported in the phishing request, and
// that a positive verdict triggers a blocking page.
IN_PROC_BROWSER_TEST_F(ClientSideDetectionHostBrowserTest,
                       VerifyVisualFeatureCollection) {
  FakeClientSideDetectionService fake_csd_service;

  // Build a model with a single visual target. The checks below expect exactly
  // one vision match, against this target's digest.
  ClientSideModel model;
  model.set_version(123);
  model.set_max_words_per_term(1);
  VisualTarget* target = model.mutable_vision_model()->add_targets();

  target->set_digest("target1_digest");
  // Create a hash corresponding to a blank screen: a 0x30 prefix byte followed
  // by 288 bytes of 0xff.
  const std::string hash = "\x30" + std::string(288, '\xff');
  target->set_hash(hash);
  target->set_dimension_size(48);
  MatchRule* match_rule = target->mutable_match_config()->add_match_rule();
  // The actual hash distance is 76, so set the distance to 200 for safety. A
  // completely random bitstring would expect a Hamming distance of 1152.
  match_rule->set_hash_distance(200);

  fake_csd_service.SetModel(model);

  scoped_refptr<StrictMock<MockSafeBrowsingUIManager>> mock_ui_manager =
      new StrictMock<MockSafeBrowsingUIManager>();

  ASSERT_TRUE(embedded_test_server()->Start());
  std::unique_ptr<ClientSideDetectionHost> csd_host =
      ClientSideDetectionHost::Create(
          browser()->tab_strip_model()->GetActiveWebContents());
  csd_host->set_client_side_detection_service(&fake_csd_service);
  csd_host->SendModelToRenderFrame();
  csd_host->set_ui_manager(mock_ui_manager.get());

  GURL page_url(embedded_test_server()->GetURL("/safe_browsing/malware.html"));
  ui_test_utils::NavigateToURL(browser(), page_url);

  // Wait until the fake service records the phishing request.
  base::RunLoop run_loop;
  fake_csd_service.SetRequestCallback(run_loop.QuitClosure());

  // Bypass the pre-classification checks
  csd_host->OnPhishingPreClassificationDone(/*should_classify=*/true);

  run_loop.Run();

  ASSERT_FALSE(fake_csd_service.saved_callback_is_null());

  EXPECT_EQ(123, fake_csd_service.saved_request().model_version());
  ASSERT_EQ(1, fake_csd_service.saved_request().vision_match_size());
  EXPECT_EQ(
      "target1_digest",
      fake_csd_service.saved_request().vision_match(0).matched_target_digest());

  // Expect an interstitial to be shown once the verdict is reported as
  // phishing. saved_callback() returns a temporary, so it can be run directly.
  EXPECT_CALL(*mock_ui_manager, DisplayBlockingPage(_));
  fake_csd_service.saved_callback().Run(page_url, true);
}
144 #endif
145
146 } // namespace safe_browsing
147