1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/safe_browsing/client_side_detection_service.h"
6 
7 #include <algorithm>
8 #include <memory>
9 
10 #include "base/bind.h"
11 #include "base/containers/queue.h"
12 #include "base/location.h"
13 #include "base/memory/ptr_util.h"
14 #include "base/memory/scoped_refptr.h"
15 #include "base/metrics/histogram_functions.h"
16 #include "base/metrics/histogram_macros.h"
17 #include "base/single_thread_task_runner.h"
18 #include "base/stl_util.h"
19 #include "base/threading/thread_task_runner_handle.h"
20 #include "base/time/time.h"
21 #include "chrome/browser/browser_process.h"
22 #include "chrome/browser/policy/chrome_browser_policy_connector.h"
23 #include "chrome/browser/profiles/profile.h"
24 #include "chrome/browser/safe_browsing/client_side_detection_host.h"
25 #include "chrome/common/pref_names.h"
26 #include "components/prefs/pref_service.h"
27 #include "components/safe_browsing/content/common/safe_browsing.mojom.h"
28 #include "components/safe_browsing/core/common/safe_browsing_prefs.h"
29 #include "components/safe_browsing/core/common/utils.h"
30 #include "components/safe_browsing/core/proto/client_model.pb.h"
31 #include "components/safe_browsing/core/proto/csd.pb.h"
32 #include "content/public/browser/browser_thread.h"
33 #include "content/public/browser/notification_service.h"
34 #include "content/public/browser/notification_types.h"
35 #include "content/public/browser/render_process_host.h"
36 #include "crypto/sha2.h"
37 #include "google_apis/google_api_keys.h"
38 #include "mojo/public/cpp/bindings/remote.h"
39 #include "net/base/escape.h"
40 #include "net/base/ip_address.h"
41 #include "net/base/load_flags.h"
42 #include "net/http/http_response_headers.h"
43 #include "net/http/http_status_code.h"
44 #include "net/traffic_annotation/network_traffic_annotation.h"
45 #include "services/network/public/cpp/shared_url_loader_factory.h"
46 #include "services/network/public/cpp/simple_url_loader.h"
47 #include "url/gurl.h"
48 
49 using content::BrowserThread;
50 
51 namespace safe_browsing {
52 
// Delay, in milliseconds, handed to ModelLoader::ScheduleFetch() when the
// service becomes enabled (see OnPrefsUpdated()).
const int ClientSideDetectionService::kInitialClientModelFetchDelayMs = 10000;
// Length, in days, of the window used when counting sent reports
// (GetNumReports()) and when computing cache retention (UpdateCache()).
const int ClientSideDetectionService::kReportsIntervalDays = 1;
// Report count that OverPhishingReportLimit() compares against.
const int ClientSideDetectionService::kMaxReportsPerInterval = 3;
// Age, in days, up to which GetValidCachedResult() returns a cached
// non-phishing verdict.
const int ClientSideDetectionService::kNegativeCacheIntervalDays = 1;
// Age, in minutes, up to which GetValidCachedResult() returns a cached
// phishing verdict.
const int ClientSideDetectionService::kPositiveCacheIntervalMinutes = 30;

// Base URL that phishing report requests are POSTed to; GetClientReportUrl()
// appends the API key to it when one is available.
const char ClientSideDetectionService::kClientReportPhishingUrl[] =
    "https://sb-ssl.google.com/safebrowsing/clientreport/phishing";
61 
// Bookkeeping for one in-flight phishing report request. Instances are stored
// in |client_phishing_reports_|, keyed by the raw pointer of |loader|.
struct ClientSideDetectionService::ClientPhishingReportInfo {
  // The loader performing the request; owned here.
  std::unique_ptr<network::SimpleURLLoader> loader;
  // Invoked with |phishing_url| and the verdict; may be null.
  ClientReportPhishingRequestCallback callback;
  // URL taken from the ClientPhishingRequest that triggered the report.
  GURL phishing_url;
};
67 
CacheState(bool phish,base::Time time)68 ClientSideDetectionService::CacheState::CacheState(bool phish, base::Time time)
69     : is_phishing(phish), timestamp(time) {}
70 
ClientSideDetectionService(Profile * profile)71 ClientSideDetectionService::ClientSideDetectionService(Profile* profile)
72     : profile_(profile),
73       enabled_(false),
74       extended_reporting_(false),
75       url_loader_factory_(nullptr) {
76   // |profile_| can be null in unit tests
77   if (!profile_)
78     return;
79 
80   if (g_browser_process->safe_browsing_service()) {
81     url_loader_factory_ =
82         g_browser_process->safe_browsing_service()->GetURLLoaderFactory(
83             profile);
84   }
85 
86   pref_change_registrar_.Init(profile_->GetPrefs());
87   pref_change_registrar_.Add(
88       prefs::kSafeBrowsingEnabled,
89       base::BindRepeating(&ClientSideDetectionService::OnPrefsUpdated,
90                           base::Unretained(this)));
91   pref_change_registrar_.Add(
92       prefs::kSafeBrowsingEnhanced,
93       base::BindRepeating(&ClientSideDetectionService::OnPrefsUpdated,
94                           base::Unretained(this)));
95   pref_change_registrar_.Add(
96       prefs::kSafeBrowsingScoutReportingEnabled,
97       base::BindRepeating(&ClientSideDetectionService::OnPrefsUpdated,
98                           base::Unretained(this)));
99 
100   // Do an initial check of the prefs.
101   OnPrefsUpdated();
102 }
103 
// Invalidates the weak pointers handed out by |weak_factory_| (used for the
// task posted in SendClientReportPhishingRequest()) before members are torn
// down.
ClientSideDetectionService::~ClientSideDetectionService() {
  weak_factory_.InvalidateWeakPtrs();
}
107 
Shutdown()108 void ClientSideDetectionService::Shutdown() {
109   url_loader_factory_.reset();
110 }
111 
OnPrefsUpdated()112 void ClientSideDetectionService::OnPrefsUpdated() {
113   DCHECK_CURRENTLY_ON(BrowserThread::UI);
114   bool enabled = IsSafeBrowsingEnabled(*profile_->GetPrefs());
115   bool extended_reporting =
116       IsEnhancedProtectionEnabled(*profile_->GetPrefs()) ||
117       IsExtendedReportingEnabled(*profile_->GetPrefs());
118   if (enabled == enabled_ && extended_reporting_ == extended_reporting)
119     return;
120 
121   enabled_ = enabled;
122   extended_reporting_ = extended_reporting;
123 
124   if (enabled_) {
125     if (!model_factory_.is_null()) {
126       model_loader_ = model_factory_.Run();
127     } else {
128       model_loader_ = std::make_unique<ModelLoader>(
129           base::BindRepeating(&ClientSideDetectionService::SendModelToRenderers,
130                               base::Unretained(this)),
131           profile_->GetURLLoaderFactory(), extended_reporting_);
132     }
133     // Refresh the models when the service is enabled.  This can happen when
134     // either of the preferences are toggled, or early during startup if
135     // safe browsing is already enabled. In a lot of cases the model will be
136     // in the cache so it  won't actually be fetched from the network.
137     // We delay the first model fetches to avoid slowing down browser startup.
138     model_loader_->ScheduleFetch(kInitialClientModelFetchDelayMs);
139   } else {
140     if (model_loader_) {
141       // Cancel model loads in progress.
142       model_loader_->CancelFetcher();
143     }
144     // Invoke pending callbacks with a false verdict.
145     for (auto it = client_phishing_reports_.begin();
146          it != client_phishing_reports_.end(); ++it) {
147       ClientPhishingReportInfo* info = it->second.get();
148       if (!info->callback.is_null())
149         std::move(info->callback).Run(info->phishing_url, false);
150     }
151     client_phishing_reports_.clear();
152     cache_.clear();
153   }
154 
155   SendModelToRenderers();  // always refresh the renderer state
156 }
157 
SendClientReportPhishingRequest(std::unique_ptr<ClientPhishingRequest> verdict,bool is_extended_reporting,bool is_enhanced_reporting,ClientReportPhishingRequestCallback callback)158 void ClientSideDetectionService::SendClientReportPhishingRequest(
159     std::unique_ptr<ClientPhishingRequest> verdict,
160     bool is_extended_reporting,
161     bool is_enhanced_reporting,
162     ClientReportPhishingRequestCallback callback) {
163   DCHECK_CURRENTLY_ON(BrowserThread::UI);
164   base::ThreadTaskRunnerHandle::Get()->PostTask(
165       FROM_HERE,
166       base::BindOnce(
167           &ClientSideDetectionService::StartClientReportPhishingRequest,
168           weak_factory_.GetWeakPtr(), std::move(verdict), is_extended_reporting,
169           is_enhanced_reporting, std::move(callback)));
170 }
171 
IsPrivateIPAddress(const std::string & ip_address) const172 bool ClientSideDetectionService::IsPrivateIPAddress(
173     const std::string& ip_address) const {
174   net::IPAddress address;
175   if (!address.AssignFromIPLiteral(ip_address)) {
176     // Err on the side of privacy and assume this might be private.
177     return true;
178   }
179 
180   return !address.IsPubliclyRoutable();
181 }
182 
AddClientSideDetectionHost(ClientSideDetectionHost * host)183 void ClientSideDetectionService::AddClientSideDetectionHost(
184     ClientSideDetectionHost* host) {
185   csd_hosts_.push_back(host);
186 }
187 
RemoveClientSideDetectionHost(ClientSideDetectionHost * host)188 void ClientSideDetectionService::RemoveClientSideDetectionHost(
189     ClientSideDetectionHost* host) {
190   std::vector<ClientSideDetectionHost*>::iterator position =
191       std::find(csd_hosts_.begin(), csd_hosts_.end(), host);
192   if (position != csd_hosts_.end())
193     csd_hosts_.erase(position);
194 }
195 
OnURLLoaderComplete(network::SimpleURLLoader * url_loader,std::unique_ptr<std::string> response_body)196 void ClientSideDetectionService::OnURLLoaderComplete(
197     network::SimpleURLLoader* url_loader,
198     std::unique_ptr<std::string> response_body) {
199   std::string data;
200   if (response_body)
201     data = std::move(*response_body.get());
202   int response_code = 0;
203   if (url_loader->ResponseInfo() && url_loader->ResponseInfo()->headers)
204     response_code = url_loader->ResponseInfo()->headers->response_code();
205 
206   DCHECK(base::Contains(client_phishing_reports_, url_loader));
207   HandlePhishingVerdict(url_loader, url_loader->GetFinalURL(),
208                         url_loader->NetError(), response_code, data);
209 }
210 
SendModelToRenderers()211 void ClientSideDetectionService::SendModelToRenderers() {
212   for (ClientSideDetectionHost* host : csd_hosts_) {
213     host->SendModelToRenderFrame();
214   }
215 }
216 
StartClientReportPhishingRequest(std::unique_ptr<ClientPhishingRequest> request,bool is_extended_reporting,bool is_enhanced_reporting,ClientReportPhishingRequestCallback callback)217 void ClientSideDetectionService::StartClientReportPhishingRequest(
218     std::unique_ptr<ClientPhishingRequest> request,
219     bool is_extended_reporting,
220     bool is_enhanced_reporting,
221     ClientReportPhishingRequestCallback callback) {
222   DCHECK_CURRENTLY_ON(BrowserThread::UI);
223 
224   if (!enabled_) {
225     if (!callback.is_null())
226       std::move(callback).Run(GURL(request->url()), false);
227     return;
228   }
229 
230   // Fill in metadata about which model we used.
231   request->set_model_filename(model_loader_->name());
232   if (is_extended_reporting || is_enhanced_reporting) {
233     if (is_enhanced_reporting) {
234       request->mutable_population()->set_user_population(
235           ChromeUserPopulation::ENHANCED_PROTECTION);
236     } else {
237       request->mutable_population()->set_user_population(
238           ChromeUserPopulation::EXTENDED_REPORTING);
239     }
240   } else {
241     request->mutable_population()->set_user_population(
242         ChromeUserPopulation::SAFE_BROWSING);
243   }
244   request->mutable_population()->set_profile_management_status(
245       GetProfileManagementStatus(
246           g_browser_process->browser_policy_connector()));
247 
248   std::string request_data;
249   request->SerializeToString(&request_data);
250 
251   net::NetworkTrafficAnnotationTag traffic_annotation =
252       net::DefineNetworkTrafficAnnotation(
253           "safe_browsing_client_side_phishing_detector", R"(
254           semantics {
255             sender: "Safe Browsing Client-Side Phishing Detector"
256             description:
257               "If the client-side phishing detector determines that the "
258               "current page contents are similar to phishing pages, it will "
259               "send a request to Safe Browsing to ask for a final verdict. If "
260               "Safe Browsing agrees the page is dangerous, Chrome will show a "
261               "full-page interstitial warning."
262             trigger:
263               "Whenever the clinet-side detector machine learning model "
264               "computes a phishy-ness score above a threshold, after page-load."
265             data:
266               "Top-level page URL without CGI parameters, boolean and double "
267               "features extracted from DOM, such as the number of resources "
268               "loaded in the page, if certain likely phishing and social "
269               "engineering terms found on the page, etc."
270             destination: GOOGLE_OWNED_SERVICE
271           }
272           policy {
273             cookies_allowed: YES
274             cookies_store: "Safe browsing cookie store"
275             setting:
276               "Users can enable or disable this feature by toggling 'Protect "
277               "you and your device from dangerous sites' in Chrome settings "
278               "under Privacy. This feature is enabled by default."
279             chrome_policy {
280               SafeBrowsingEnabled {
281                 policy_options {mode: MANDATORY}
282                 SafeBrowsingEnabled: false
283               }
284             }
285           })");
286   auto resource_request = std::make_unique<network::ResourceRequest>();
287   resource_request->url = GetClientReportUrl(kClientReportPhishingUrl);
288   resource_request->method = "POST";
289   resource_request->load_flags = net::LOAD_DISABLE_CACHE;
290   auto loader = network::SimpleURLLoader::Create(std::move(resource_request),
291                                                  traffic_annotation);
292   loader->AttachStringForUpload(request_data, "application/octet-stream");
293   loader->DownloadToStringOfUnboundedSizeUntilCrashAndDie(
294       url_loader_factory_.get(),
295       base::BindOnce(&ClientSideDetectionService::OnURLLoaderComplete,
296                      base::Unretained(this), loader.get()));
297 
298   // Remember which callback and URL correspond to the current fetcher object.
299   std::unique_ptr<ClientPhishingReportInfo> info(new ClientPhishingReportInfo);
300   auto* loader_ptr = loader.get();
301   info->loader = std::move(loader);
302   info->callback = std::move(callback);
303   info->phishing_url = GURL(request->url());
304   client_phishing_reports_[loader_ptr] = std::move(info);
305 
306   // Record that we made a request
307   phishing_report_times_.push(base::Time::Now());
308 }
309 
HandlePhishingVerdict(network::SimpleURLLoader * source,const GURL & url,int net_error,int response_code,const std::string & data)310 void ClientSideDetectionService::HandlePhishingVerdict(
311     network::SimpleURLLoader* source,
312     const GURL& url,
313     int net_error,
314     int response_code,
315     const std::string& data) {
316   ClientPhishingResponse response;
317   std::unique_ptr<ClientPhishingReportInfo> info =
318       std::move(client_phishing_reports_[source]);
319   client_phishing_reports_.erase(source);
320 
321   bool is_phishing = false;
322   if (net_error == net::OK && net::HTTP_OK == response_code &&
323       response.ParseFromString(data)) {
324     // Cache response, possibly flushing an old one.
325     cache_[info->phishing_url] =
326         base::WrapUnique(new CacheState(response.phishy(), base::Time::Now()));
327     is_phishing = response.phishy();
328   }
329   if (!info->callback.is_null())
330     std::move(info->callback).Run(info->phishing_url, is_phishing);
331 }
332 
IsInCache(const GURL & url)333 bool ClientSideDetectionService::IsInCache(const GURL& url) {
334   UpdateCache();
335 
336   return cache_.find(url) != cache_.end();
337 }
338 
GetValidCachedResult(const GURL & url,bool * is_phishing)339 bool ClientSideDetectionService::GetValidCachedResult(const GURL& url,
340                                                       bool* is_phishing) {
341   UpdateCache();
342 
343   auto it = cache_.find(url);
344   if (it == cache_.end()) {
345     return false;
346   }
347 
348   // We still need to check if the result is valid.
349   const CacheState& cache_state = *it->second;
350   if (cache_state.is_phishing
351           ? cache_state.timestamp >
352                 base::Time::Now() -
353                     base::TimeDelta::FromMinutes(kPositiveCacheIntervalMinutes)
354           : cache_state.timestamp >
355                 base::Time::Now() -
356                     base::TimeDelta::FromDays(kNegativeCacheIntervalDays)) {
357     *is_phishing = cache_state.is_phishing;
358     return true;
359   }
360   return false;
361 }
362 
UpdateCache()363 void ClientSideDetectionService::UpdateCache() {
364   // Since we limit the number of requests but allow pass-through for cache
365   // refreshes, we don't want to remove elements from the cache if they
366   // could be used for this purpose even if we will not use the entry to
367   // satisfy the request from the cache.
368   base::TimeDelta positive_cache_interval =
369       std::max(base::TimeDelta::FromMinutes(kPositiveCacheIntervalMinutes),
370                base::TimeDelta::FromDays(kReportsIntervalDays));
371   base::TimeDelta negative_cache_interval =
372       std::max(base::TimeDelta::FromDays(kNegativeCacheIntervalDays),
373                base::TimeDelta::FromDays(kReportsIntervalDays));
374 
375   // Remove elements from the cache that will no longer be used.
376   for (auto it = cache_.begin(); it != cache_.end();) {
377     const CacheState& cache_state = *it->second;
378     if (cache_state.is_phishing
379             ? cache_state.timestamp >
380                   base::Time::Now() - positive_cache_interval
381             : cache_state.timestamp >
382                   base::Time::Now() - negative_cache_interval) {
383       ++it;
384     } else {
385       cache_.erase(it++);
386     }
387   }
388 }
389 
OverPhishingReportLimit()390 bool ClientSideDetectionService::OverPhishingReportLimit() {
391   return GetPhishingNumReports() > kMaxReportsPerInterval;
392 }
393 
GetPhishingNumReports()394 int ClientSideDetectionService::GetPhishingNumReports() {
395   return GetNumReports(&phishing_report_times_);
396 }
397 
GetNumReports(base::queue<base::Time> * report_times)398 int ClientSideDetectionService::GetNumReports(
399     base::queue<base::Time>* report_times) {
400   base::Time cutoff =
401       base::Time::Now() - base::TimeDelta::FromDays(kReportsIntervalDays);
402 
403   // Erase items older than cutoff because we will never care about them again.
404   while (!report_times->empty() && report_times->front() < cutoff) {
405     report_times->pop();
406   }
407 
408   // Return the number of elements that are above the cutoff.
409   return report_times->size();
410 }
411 
412 // static
GetClientReportUrl(const std::string & report_url)413 GURL ClientSideDetectionService::GetClientReportUrl(
414     const std::string& report_url) {
415   GURL url(report_url);
416   std::string api_key = google_apis::GetAPIKey();
417   if (!api_key.empty())
418     url = url.Resolve("?key=" + net::EscapeQueryParamValue(api_key, true));
419 
420   return url;
421 }
422 
423 ModelLoader::ClientModelStatus
GetLastModelStatus()424 ClientSideDetectionService::GetLastModelStatus() {
425   // |model_loader_| can be null in tests
426   return model_loader_ ? model_loader_->last_client_model_status()
427                        : ModelLoader::MODEL_NEVER_FETCHED;
428 }
429 
GetModelStr()430 std::string ClientSideDetectionService::GetModelStr() {
431   return model_loader_ ? model_loader_->model_str() : "";
432 }
433 
SetModelLoaderFactoryForTesting(base::RepeatingCallback<std::unique_ptr<ModelLoader> ()> factory)434 void ClientSideDetectionService::SetModelLoaderFactoryForTesting(
435     base::RepeatingCallback<std::unique_ptr<ModelLoader>()> factory) {
436   model_factory_ = factory;
437 }
438 
SetURLLoaderFactoryForTesting(scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)439 void ClientSideDetectionService::SetURLLoaderFactoryForTesting(
440     scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
441   url_loader_factory_ = url_loader_factory;
442 }
443 
444 }  // namespace safe_browsing
445