1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/safe_browsing/client_side_detection_service.h"
6
7 #include <algorithm>
8 #include <memory>
9
10 #include "base/bind.h"
11 #include "base/containers/queue.h"
12 #include "base/location.h"
13 #include "base/memory/ptr_util.h"
14 #include "base/memory/scoped_refptr.h"
15 #include "base/metrics/histogram_functions.h"
16 #include "base/metrics/histogram_macros.h"
17 #include "base/single_thread_task_runner.h"
18 #include "base/stl_util.h"
19 #include "base/threading/thread_task_runner_handle.h"
20 #include "base/time/time.h"
21 #include "chrome/browser/browser_process.h"
22 #include "chrome/browser/policy/chrome_browser_policy_connector.h"
23 #include "chrome/browser/profiles/profile.h"
24 #include "chrome/browser/safe_browsing/client_side_detection_host.h"
25 #include "chrome/common/pref_names.h"
26 #include "components/prefs/pref_service.h"
27 #include "components/safe_browsing/content/common/safe_browsing.mojom.h"
28 #include "components/safe_browsing/core/common/safe_browsing_prefs.h"
29 #include "components/safe_browsing/core/common/utils.h"
30 #include "components/safe_browsing/core/proto/client_model.pb.h"
31 #include "components/safe_browsing/core/proto/csd.pb.h"
32 #include "content/public/browser/browser_thread.h"
33 #include "content/public/browser/notification_service.h"
34 #include "content/public/browser/notification_types.h"
35 #include "content/public/browser/render_process_host.h"
36 #include "crypto/sha2.h"
37 #include "google_apis/google_api_keys.h"
38 #include "mojo/public/cpp/bindings/remote.h"
39 #include "net/base/escape.h"
40 #include "net/base/ip_address.h"
41 #include "net/base/load_flags.h"
42 #include "net/http/http_response_headers.h"
43 #include "net/http/http_status_code.h"
44 #include "net/traffic_annotation/network_traffic_annotation.h"
45 #include "services/network/public/cpp/shared_url_loader_factory.h"
46 #include "services/network/public/cpp/simple_url_loader.h"
47 #include "url/gurl.h"
48
49 using content::BrowserThread;
50
51 namespace safe_browsing {
52
53 const int ClientSideDetectionService::kInitialClientModelFetchDelayMs = 10000;
54 const int ClientSideDetectionService::kReportsIntervalDays = 1;
55 const int ClientSideDetectionService::kMaxReportsPerInterval = 3;
56 const int ClientSideDetectionService::kNegativeCacheIntervalDays = 1;
57 const int ClientSideDetectionService::kPositiveCacheIntervalMinutes = 30;
58
59 const char ClientSideDetectionService::kClientReportPhishingUrl[] =
60 "https://sb-ssl.google.com/safebrowsing/clientreport/phishing";
61
62 struct ClientSideDetectionService::ClientPhishingReportInfo {
63 std::unique_ptr<network::SimpleURLLoader> loader;
64 ClientReportPhishingRequestCallback callback;
65 GURL phishing_url;
66 };
67
CacheState(bool phish,base::Time time)68 ClientSideDetectionService::CacheState::CacheState(bool phish, base::Time time)
69 : is_phishing(phish), timestamp(time) {}
70
ClientSideDetectionService(Profile * profile)71 ClientSideDetectionService::ClientSideDetectionService(Profile* profile)
72 : profile_(profile),
73 enabled_(false),
74 extended_reporting_(false),
75 url_loader_factory_(nullptr) {
76 // |profile_| can be null in unit tests
77 if (!profile_)
78 return;
79
80 if (g_browser_process->safe_browsing_service()) {
81 url_loader_factory_ =
82 g_browser_process->safe_browsing_service()->GetURLLoaderFactory(
83 profile);
84 }
85
86 pref_change_registrar_.Init(profile_->GetPrefs());
87 pref_change_registrar_.Add(
88 prefs::kSafeBrowsingEnabled,
89 base::BindRepeating(&ClientSideDetectionService::OnPrefsUpdated,
90 base::Unretained(this)));
91 pref_change_registrar_.Add(
92 prefs::kSafeBrowsingEnhanced,
93 base::BindRepeating(&ClientSideDetectionService::OnPrefsUpdated,
94 base::Unretained(this)));
95 pref_change_registrar_.Add(
96 prefs::kSafeBrowsingScoutReportingEnabled,
97 base::BindRepeating(&ClientSideDetectionService::OnPrefsUpdated,
98 base::Unretained(this)));
99
100 // Do an initial check of the prefs.
101 OnPrefsUpdated();
102 }
103
~ClientSideDetectionService()104 ClientSideDetectionService::~ClientSideDetectionService() {
105 weak_factory_.InvalidateWeakPtrs();
106 }
107
Shutdown()108 void ClientSideDetectionService::Shutdown() {
109 url_loader_factory_.reset();
110 }
111
OnPrefsUpdated()112 void ClientSideDetectionService::OnPrefsUpdated() {
113 DCHECK_CURRENTLY_ON(BrowserThread::UI);
114 bool enabled = IsSafeBrowsingEnabled(*profile_->GetPrefs());
115 bool extended_reporting =
116 IsEnhancedProtectionEnabled(*profile_->GetPrefs()) ||
117 IsExtendedReportingEnabled(*profile_->GetPrefs());
118 if (enabled == enabled_ && extended_reporting_ == extended_reporting)
119 return;
120
121 enabled_ = enabled;
122 extended_reporting_ = extended_reporting;
123
124 if (enabled_) {
125 if (!model_factory_.is_null()) {
126 model_loader_ = model_factory_.Run();
127 } else {
128 model_loader_ = std::make_unique<ModelLoader>(
129 base::BindRepeating(&ClientSideDetectionService::SendModelToRenderers,
130 base::Unretained(this)),
131 profile_->GetURLLoaderFactory(), extended_reporting_);
132 }
133 // Refresh the models when the service is enabled. This can happen when
134 // either of the preferences are toggled, or early during startup if
135 // safe browsing is already enabled. In a lot of cases the model will be
136 // in the cache so it won't actually be fetched from the network.
137 // We delay the first model fetches to avoid slowing down browser startup.
138 model_loader_->ScheduleFetch(kInitialClientModelFetchDelayMs);
139 } else {
140 if (model_loader_) {
141 // Cancel model loads in progress.
142 model_loader_->CancelFetcher();
143 }
144 // Invoke pending callbacks with a false verdict.
145 for (auto it = client_phishing_reports_.begin();
146 it != client_phishing_reports_.end(); ++it) {
147 ClientPhishingReportInfo* info = it->second.get();
148 if (!info->callback.is_null())
149 std::move(info->callback).Run(info->phishing_url, false);
150 }
151 client_phishing_reports_.clear();
152 cache_.clear();
153 }
154
155 SendModelToRenderers(); // always refresh the renderer state
156 }
157
SendClientReportPhishingRequest(std::unique_ptr<ClientPhishingRequest> verdict,bool is_extended_reporting,bool is_enhanced_reporting,ClientReportPhishingRequestCallback callback)158 void ClientSideDetectionService::SendClientReportPhishingRequest(
159 std::unique_ptr<ClientPhishingRequest> verdict,
160 bool is_extended_reporting,
161 bool is_enhanced_reporting,
162 ClientReportPhishingRequestCallback callback) {
163 DCHECK_CURRENTLY_ON(BrowserThread::UI);
164 base::ThreadTaskRunnerHandle::Get()->PostTask(
165 FROM_HERE,
166 base::BindOnce(
167 &ClientSideDetectionService::StartClientReportPhishingRequest,
168 weak_factory_.GetWeakPtr(), std::move(verdict), is_extended_reporting,
169 is_enhanced_reporting, std::move(callback)));
170 }
171
IsPrivateIPAddress(const std::string & ip_address) const172 bool ClientSideDetectionService::IsPrivateIPAddress(
173 const std::string& ip_address) const {
174 net::IPAddress address;
175 if (!address.AssignFromIPLiteral(ip_address)) {
176 // Err on the side of privacy and assume this might be private.
177 return true;
178 }
179
180 return !address.IsPubliclyRoutable();
181 }
182
AddClientSideDetectionHost(ClientSideDetectionHost * host)183 void ClientSideDetectionService::AddClientSideDetectionHost(
184 ClientSideDetectionHost* host) {
185 csd_hosts_.push_back(host);
186 }
187
RemoveClientSideDetectionHost(ClientSideDetectionHost * host)188 void ClientSideDetectionService::RemoveClientSideDetectionHost(
189 ClientSideDetectionHost* host) {
190 std::vector<ClientSideDetectionHost*>::iterator position =
191 std::find(csd_hosts_.begin(), csd_hosts_.end(), host);
192 if (position != csd_hosts_.end())
193 csd_hosts_.erase(position);
194 }
195
OnURLLoaderComplete(network::SimpleURLLoader * url_loader,std::unique_ptr<std::string> response_body)196 void ClientSideDetectionService::OnURLLoaderComplete(
197 network::SimpleURLLoader* url_loader,
198 std::unique_ptr<std::string> response_body) {
199 std::string data;
200 if (response_body)
201 data = std::move(*response_body.get());
202 int response_code = 0;
203 if (url_loader->ResponseInfo() && url_loader->ResponseInfo()->headers)
204 response_code = url_loader->ResponseInfo()->headers->response_code();
205
206 DCHECK(base::Contains(client_phishing_reports_, url_loader));
207 HandlePhishingVerdict(url_loader, url_loader->GetFinalURL(),
208 url_loader->NetError(), response_code, data);
209 }
210
SendModelToRenderers()211 void ClientSideDetectionService::SendModelToRenderers() {
212 for (ClientSideDetectionHost* host : csd_hosts_) {
213 host->SendModelToRenderFrame();
214 }
215 }
216
StartClientReportPhishingRequest(std::unique_ptr<ClientPhishingRequest> request,bool is_extended_reporting,bool is_enhanced_reporting,ClientReportPhishingRequestCallback callback)217 void ClientSideDetectionService::StartClientReportPhishingRequest(
218 std::unique_ptr<ClientPhishingRequest> request,
219 bool is_extended_reporting,
220 bool is_enhanced_reporting,
221 ClientReportPhishingRequestCallback callback) {
222 DCHECK_CURRENTLY_ON(BrowserThread::UI);
223
224 if (!enabled_) {
225 if (!callback.is_null())
226 std::move(callback).Run(GURL(request->url()), false);
227 return;
228 }
229
230 // Fill in metadata about which model we used.
231 request->set_model_filename(model_loader_->name());
232 if (is_extended_reporting || is_enhanced_reporting) {
233 if (is_enhanced_reporting) {
234 request->mutable_population()->set_user_population(
235 ChromeUserPopulation::ENHANCED_PROTECTION);
236 } else {
237 request->mutable_population()->set_user_population(
238 ChromeUserPopulation::EXTENDED_REPORTING);
239 }
240 } else {
241 request->mutable_population()->set_user_population(
242 ChromeUserPopulation::SAFE_BROWSING);
243 }
244 request->mutable_population()->set_profile_management_status(
245 GetProfileManagementStatus(
246 g_browser_process->browser_policy_connector()));
247
248 std::string request_data;
249 request->SerializeToString(&request_data);
250
251 net::NetworkTrafficAnnotationTag traffic_annotation =
252 net::DefineNetworkTrafficAnnotation(
253 "safe_browsing_client_side_phishing_detector", R"(
254 semantics {
255 sender: "Safe Browsing Client-Side Phishing Detector"
256 description:
257 "If the client-side phishing detector determines that the "
258 "current page contents are similar to phishing pages, it will "
259 "send a request to Safe Browsing to ask for a final verdict. If "
260 "Safe Browsing agrees the page is dangerous, Chrome will show a "
261 "full-page interstitial warning."
262 trigger:
263 "Whenever the clinet-side detector machine learning model "
264 "computes a phishy-ness score above a threshold, after page-load."
265 data:
266 "Top-level page URL without CGI parameters, boolean and double "
267 "features extracted from DOM, such as the number of resources "
268 "loaded in the page, if certain likely phishing and social "
269 "engineering terms found on the page, etc."
270 destination: GOOGLE_OWNED_SERVICE
271 }
272 policy {
273 cookies_allowed: YES
274 cookies_store: "Safe browsing cookie store"
275 setting:
276 "Users can enable or disable this feature by toggling 'Protect "
277 "you and your device from dangerous sites' in Chrome settings "
278 "under Privacy. This feature is enabled by default."
279 chrome_policy {
280 SafeBrowsingEnabled {
281 policy_options {mode: MANDATORY}
282 SafeBrowsingEnabled: false
283 }
284 }
285 })");
286 auto resource_request = std::make_unique<network::ResourceRequest>();
287 resource_request->url = GetClientReportUrl(kClientReportPhishingUrl);
288 resource_request->method = "POST";
289 resource_request->load_flags = net::LOAD_DISABLE_CACHE;
290 auto loader = network::SimpleURLLoader::Create(std::move(resource_request),
291 traffic_annotation);
292 loader->AttachStringForUpload(request_data, "application/octet-stream");
293 loader->DownloadToStringOfUnboundedSizeUntilCrashAndDie(
294 url_loader_factory_.get(),
295 base::BindOnce(&ClientSideDetectionService::OnURLLoaderComplete,
296 base::Unretained(this), loader.get()));
297
298 // Remember which callback and URL correspond to the current fetcher object.
299 std::unique_ptr<ClientPhishingReportInfo> info(new ClientPhishingReportInfo);
300 auto* loader_ptr = loader.get();
301 info->loader = std::move(loader);
302 info->callback = std::move(callback);
303 info->phishing_url = GURL(request->url());
304 client_phishing_reports_[loader_ptr] = std::move(info);
305
306 // Record that we made a request
307 phishing_report_times_.push(base::Time::Now());
308 }
309
HandlePhishingVerdict(network::SimpleURLLoader * source,const GURL & url,int net_error,int response_code,const std::string & data)310 void ClientSideDetectionService::HandlePhishingVerdict(
311 network::SimpleURLLoader* source,
312 const GURL& url,
313 int net_error,
314 int response_code,
315 const std::string& data) {
316 ClientPhishingResponse response;
317 std::unique_ptr<ClientPhishingReportInfo> info =
318 std::move(client_phishing_reports_[source]);
319 client_phishing_reports_.erase(source);
320
321 bool is_phishing = false;
322 if (net_error == net::OK && net::HTTP_OK == response_code &&
323 response.ParseFromString(data)) {
324 // Cache response, possibly flushing an old one.
325 cache_[info->phishing_url] =
326 base::WrapUnique(new CacheState(response.phishy(), base::Time::Now()));
327 is_phishing = response.phishy();
328 }
329 if (!info->callback.is_null())
330 std::move(info->callback).Run(info->phishing_url, is_phishing);
331 }
332
IsInCache(const GURL & url)333 bool ClientSideDetectionService::IsInCache(const GURL& url) {
334 UpdateCache();
335
336 return cache_.find(url) != cache_.end();
337 }
338
GetValidCachedResult(const GURL & url,bool * is_phishing)339 bool ClientSideDetectionService::GetValidCachedResult(const GURL& url,
340 bool* is_phishing) {
341 UpdateCache();
342
343 auto it = cache_.find(url);
344 if (it == cache_.end()) {
345 return false;
346 }
347
348 // We still need to check if the result is valid.
349 const CacheState& cache_state = *it->second;
350 if (cache_state.is_phishing
351 ? cache_state.timestamp >
352 base::Time::Now() -
353 base::TimeDelta::FromMinutes(kPositiveCacheIntervalMinutes)
354 : cache_state.timestamp >
355 base::Time::Now() -
356 base::TimeDelta::FromDays(kNegativeCacheIntervalDays)) {
357 *is_phishing = cache_state.is_phishing;
358 return true;
359 }
360 return false;
361 }
362
UpdateCache()363 void ClientSideDetectionService::UpdateCache() {
364 // Since we limit the number of requests but allow pass-through for cache
365 // refreshes, we don't want to remove elements from the cache if they
366 // could be used for this purpose even if we will not use the entry to
367 // satisfy the request from the cache.
368 base::TimeDelta positive_cache_interval =
369 std::max(base::TimeDelta::FromMinutes(kPositiveCacheIntervalMinutes),
370 base::TimeDelta::FromDays(kReportsIntervalDays));
371 base::TimeDelta negative_cache_interval =
372 std::max(base::TimeDelta::FromDays(kNegativeCacheIntervalDays),
373 base::TimeDelta::FromDays(kReportsIntervalDays));
374
375 // Remove elements from the cache that will no longer be used.
376 for (auto it = cache_.begin(); it != cache_.end();) {
377 const CacheState& cache_state = *it->second;
378 if (cache_state.is_phishing
379 ? cache_state.timestamp >
380 base::Time::Now() - positive_cache_interval
381 : cache_state.timestamp >
382 base::Time::Now() - negative_cache_interval) {
383 ++it;
384 } else {
385 cache_.erase(it++);
386 }
387 }
388 }
389
OverPhishingReportLimit()390 bool ClientSideDetectionService::OverPhishingReportLimit() {
391 return GetPhishingNumReports() > kMaxReportsPerInterval;
392 }
393
GetPhishingNumReports()394 int ClientSideDetectionService::GetPhishingNumReports() {
395 return GetNumReports(&phishing_report_times_);
396 }
397
GetNumReports(base::queue<base::Time> * report_times)398 int ClientSideDetectionService::GetNumReports(
399 base::queue<base::Time>* report_times) {
400 base::Time cutoff =
401 base::Time::Now() - base::TimeDelta::FromDays(kReportsIntervalDays);
402
403 // Erase items older than cutoff because we will never care about them again.
404 while (!report_times->empty() && report_times->front() < cutoff) {
405 report_times->pop();
406 }
407
408 // Return the number of elements that are above the cutoff.
409 return report_times->size();
410 }
411
412 // static
GetClientReportUrl(const std::string & report_url)413 GURL ClientSideDetectionService::GetClientReportUrl(
414 const std::string& report_url) {
415 GURL url(report_url);
416 std::string api_key = google_apis::GetAPIKey();
417 if (!api_key.empty())
418 url = url.Resolve("?key=" + net::EscapeQueryParamValue(api_key, true));
419
420 return url;
421 }
422
423 ModelLoader::ClientModelStatus
GetLastModelStatus()424 ClientSideDetectionService::GetLastModelStatus() {
425 // |model_loader_| can be null in tests
426 return model_loader_ ? model_loader_->last_client_model_status()
427 : ModelLoader::MODEL_NEVER_FETCHED;
428 }
429
GetModelStr()430 std::string ClientSideDetectionService::GetModelStr() {
431 return model_loader_ ? model_loader_->model_str() : "";
432 }
433
SetModelLoaderFactoryForTesting(base::RepeatingCallback<std::unique_ptr<ModelLoader> ()> factory)434 void ClientSideDetectionService::SetModelLoaderFactoryForTesting(
435 base::RepeatingCallback<std::unique_ptr<ModelLoader>()> factory) {
436 model_factory_ = factory;
437 }
438
SetURLLoaderFactoryForTesting(scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)439 void ClientSideDetectionService::SetURLLoaderFactoryForTesting(
440 scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
441 url_loader_factory_ = url_loader_factory;
442 }
443
444 } // namespace safe_browsing
445