1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/safe_browsing/client_side_detection_host.h"
6 
7 #include "base/run_loop.h"
8 #include "chrome/browser/profiles/profile.h"
9 #include "chrome/browser/safe_browsing/client_side_detection_service.h"
10 #include "chrome/browser/ui/browser.h"
11 #include "chrome/browser/ui/tabs/tab_strip_model.h"
12 #include "chrome/test/base/in_process_browser_test.h"
13 #include "chrome/test/base/ui_test_utils.h"
14 #include "components/prefs/pref_service.h"
15 #include "components/safe_browsing/buildflags.h"
16 #include "components/safe_browsing/core/proto/client_model.pb.h"
17 #include "content/public/test/browser_test.h"
18 #include "testing/gmock/include/gmock/gmock.h"
19 #include "testing/gtest/include/gtest/gtest.h"
20 #include "url/gurl.h"
21 
22 namespace safe_browsing {
23 namespace {
24 
25 using ::testing::_;
26 using ::testing::StrictMock;
27 
28 class FakeClientSideDetectionService : public ClientSideDetectionService {
29  public:
FakeClientSideDetectionService()30   FakeClientSideDetectionService() : ClientSideDetectionService(nullptr) {}
31 
SendClientReportPhishingRequest(std::unique_ptr<ClientPhishingRequest> verdict,bool is_extended_reporting,bool is_enhanced_protection,ClientReportPhishingRequestCallback callback)32   void SendClientReportPhishingRequest(
33       std::unique_ptr<ClientPhishingRequest> verdict,
34       bool is_extended_reporting,
35       bool is_enhanced_protection,
36       ClientReportPhishingRequestCallback callback) override {
37     saved_request_ = *verdict;
38     saved_callback_ = std::move(callback);
39     request_callback_.Run();
40   }
41 
saved_request()42   const ClientPhishingRequest& saved_request() { return saved_request_; }
43 
saved_callback_is_null()44   bool saved_callback_is_null() { return saved_callback_.is_null(); }
45 
saved_callback()46   ClientReportPhishingRequestCallback saved_callback() {
47     return std::move(saved_callback_);
48   }
49 
SetModel(const ClientSideModel & model)50   void SetModel(const ClientSideModel& model) { model_ = model; }
51 
GetModelStr()52   std::string GetModelStr() override { return model_.SerializeAsString(); }
53 
SetRequestCallback(const base::RepeatingClosure & closure)54   void SetRequestCallback(const base::RepeatingClosure& closure) {
55     request_callback_ = closure;
56   }
57 
58  private:
59   ClientPhishingRequest saved_request_;
60   ClientReportPhishingRequestCallback saved_callback_;
61   ClientSideModel model_;
62   base::RepeatingClosure request_callback_;
63 };
64 
65 class MockSafeBrowsingUIManager : public SafeBrowsingUIManager {
66  public:
MockSafeBrowsingUIManager()67   MockSafeBrowsingUIManager() : SafeBrowsingUIManager(nullptr) {}
68 
69   MOCK_METHOD1(DisplayBlockingPage, void(const UnsafeResource& resource));
70 
71  protected:
72   ~MockSafeBrowsingUIManager() override = default;
73 
74  private:
75   DISALLOW_COPY_AND_ASSIGN(MockSafeBrowsingUIManager);
76 };
77 
78 }  // namespace
79 
80 class ClientSideDetectionHostBrowserTest : public InProcessBrowserTest {
81  public:
82   ClientSideDetectionHostBrowserTest() = default;
83   ~ClientSideDetectionHostBrowserTest() override = default;
84 };
85 
86 #if BUILDFLAG(FULL_SAFE_BROWSING)
IN_PROC_BROWSER_TEST_F(ClientSideDetectionHostBrowserTest,VerifyVisualFeatureCollection)87 IN_PROC_BROWSER_TEST_F(ClientSideDetectionHostBrowserTest,
88                        VerifyVisualFeatureCollection) {
89   FakeClientSideDetectionService fake_csd_service;
90 
91   ClientSideModel model;
92   model.set_version(123);
93   model.set_max_words_per_term(1);
94   VisualTarget* target = model.mutable_vision_model()->add_targets();
95 
96   target->set_digest("target1_digest");
97   // Create a hash corresponding to a blank screen.
98   std::string hash = "\x30";
99   for (int i = 0; i < 288; i++)
100     hash += "\xff";
101   target->set_hash(hash);
102   target->set_dimension_size(48);
103   MatchRule* match_rule = target->mutable_match_config()->add_match_rule();
104   // The actual hash distance is 76, so set the distance to 200 for safety. A
105   // completely random bitstring would expect a Hamming distance of 1152.
106   match_rule->set_hash_distance(200);
107 
108   fake_csd_service.SetModel(model);
109 
110   scoped_refptr<StrictMock<MockSafeBrowsingUIManager>> mock_ui_manager =
111       new StrictMock<MockSafeBrowsingUIManager>();
112 
113   ASSERT_TRUE(embedded_test_server()->Start());
114   std::unique_ptr<ClientSideDetectionHost> csd_host =
115       ClientSideDetectionHost::Create(
116           browser()->tab_strip_model()->GetActiveWebContents());
117   csd_host->set_client_side_detection_service(&fake_csd_service);
118   csd_host->SendModelToRenderFrame();
119   csd_host->set_ui_manager(mock_ui_manager.get());
120 
121   GURL page_url(embedded_test_server()->GetURL("/safe_browsing/malware.html"));
122   ui_test_utils::NavigateToURL(browser(), page_url);
123 
124   base::RunLoop run_loop;
125   fake_csd_service.SetRequestCallback(run_loop.QuitClosure());
126 
127   // Bypass the pre-classification checks
128   csd_host->OnPhishingPreClassificationDone(/*should_classify=*/true);
129 
130   run_loop.Run();
131 
132   ASSERT_FALSE(fake_csd_service.saved_callback_is_null());
133 
134   EXPECT_EQ(fake_csd_service.saved_request().model_version(), 123);
135   ASSERT_EQ(fake_csd_service.saved_request().vision_match_size(), 1);
136   EXPECT_EQ(
137       fake_csd_service.saved_request().vision_match(0).matched_target_digest(),
138       "target1_digest");
139 
140   // Expect an interstitial to be shown
141   EXPECT_CALL(*mock_ui_manager, DisplayBlockingPage(_));
142   std::move(fake_csd_service.saved_callback()).Run(page_url, true);
143 }
144 #endif
145 
146 }  // namespace safe_browsing
147