1 // Copyright 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/extensions/blocklist.h"
6
#include <algorithm>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>
9
10 #include "base/bind.h"
11 #include "base/callback_list.h"
12 #include "base/lazy_instance.h"
13 #include "base/macros.h"
14 #include "base/memory/ref_counted.h"
15 #include "base/single_thread_task_runner.h"
16 #include "base/stl_util.h"
17 #include "base/threading/thread_task_runner_handle.h"
18 #include "chrome/browser/browser_process.h"
19 #include "chrome/browser/extensions/blocklist_factory.h"
20 #include "chrome/browser/extensions/blocklist_state_fetcher.h"
21 #include "chrome/browser/safe_browsing/safe_browsing_service.h"
22 #include "components/prefs/pref_service.h"
23 #include "components/safe_browsing/buildflags.h"
24 #include "components/safe_browsing/core/db/util.h"
25 #include "content/public/browser/browser_task_traits.h"
26 #include "content/public/browser/browser_thread.h"
27 #include "extensions/browser/extension_prefs.h"
28
29 using content::BrowserThread;
30 using safe_browsing::SafeBrowsingDatabaseManager;
31
32 namespace extensions {
33
34 namespace {
35
36 // The safe browsing database manager to use. Make this a global/static variable
37 // rather than a member of Blocklist because Blocklist accesses the real
38 // database manager before it has a chance to get a fake one.
39 class LazySafeBrowsingDatabaseManager {
40 public:
LazySafeBrowsingDatabaseManager()41 LazySafeBrowsingDatabaseManager() {
42 #if BUILDFLAG(SAFE_BROWSING_DB_LOCAL)
43 if (g_browser_process && g_browser_process->safe_browsing_service()) {
44 instance_ =
45 g_browser_process->safe_browsing_service()->database_manager();
46 }
47 #endif
48 }
49
get()50 scoped_refptr<SafeBrowsingDatabaseManager> get() { return instance_; }
51
set(scoped_refptr<SafeBrowsingDatabaseManager> instance)52 void set(scoped_refptr<SafeBrowsingDatabaseManager> instance) {
53 instance_ = instance;
54 database_changed_callback_list_.Notify();
55 }
56
57 std::unique_ptr<base::RepeatingClosureList::Subscription>
RegisterDatabaseChangedCallback(const base::RepeatingClosure & cb)58 RegisterDatabaseChangedCallback(const base::RepeatingClosure& cb) {
59 return database_changed_callback_list_.Add(cb);
60 }
61
62 private:
63 scoped_refptr<SafeBrowsingDatabaseManager> instance_;
64 base::RepeatingClosureList database_changed_callback_list_;
65 };
66
67 static base::LazyInstance<LazySafeBrowsingDatabaseManager>::DestructorAtExit
68 g_database_manager = LAZY_INSTANCE_INITIALIZER;
69
70 // Implementation of SafeBrowsingDatabaseManager::Client, the class which is
71 // called back from safebrowsing queries.
72 //
73 // Constructed on any thread but lives on the IO from then on.
74 class SafeBrowsingClientImpl
75 : public SafeBrowsingDatabaseManager::Client,
76 public base::RefCountedThreadSafe<SafeBrowsingClientImpl> {
77 public:
78 using OnResultCallback =
79 base::OnceCallback<void(const std::set<std::string>&)>;
80
81 // Constructs a client to query the database manager for |extension_ids| and
82 // run |callback| with the IDs of those which have been blocklisted.
Start(const std::set<std::string> & extension_ids,OnResultCallback callback)83 static void Start(const std::set<std::string>& extension_ids,
84 OnResultCallback callback) {
85 auto safe_browsing_client = base::WrapRefCounted(
86 new SafeBrowsingClientImpl(extension_ids, std::move(callback)));
87 content::GetIOThreadTaskRunner({})->PostTask(
88 FROM_HERE,
89 base::BindOnce(&SafeBrowsingClientImpl::StartCheck,
90 safe_browsing_client, g_database_manager.Get().get(),
91 extension_ids));
92 }
93
94 private:
95 friend class base::RefCountedThreadSafe<SafeBrowsingClientImpl>;
96
SafeBrowsingClientImpl(const std::set<std::string> & extension_ids,OnResultCallback callback)97 SafeBrowsingClientImpl(const std::set<std::string>& extension_ids,
98 OnResultCallback callback)
99 : callback_task_runner_(base::ThreadTaskRunnerHandle::Get()),
100 callback_(std::move(callback)) {}
101
~SafeBrowsingClientImpl()102 ~SafeBrowsingClientImpl() override {}
103
104 // Pass |database_manager| as a parameter to avoid touching
105 // SafeBrowsingService on the IO thread.
StartCheck(scoped_refptr<SafeBrowsingDatabaseManager> database_manager,const std::set<std::string> & extension_ids)106 void StartCheck(scoped_refptr<SafeBrowsingDatabaseManager> database_manager,
107 const std::set<std::string>& extension_ids) {
108 DCHECK_CURRENTLY_ON(BrowserThread::IO);
109 if (database_manager->CheckExtensionIDs(extension_ids, this)) {
110 // Definitely not blocklisted. Callback immediately.
111 callback_task_runner_->PostTask(
112 FROM_HERE,
113 base::BindOnce(std::move(callback_), std::set<std::string>()));
114 return;
115 }
116 // Something might be blocklisted, response will come in
117 // OnCheckExtensionsResult.
118 AddRef(); // Balanced in OnCheckExtensionsResult
119 }
120
OnCheckExtensionsResult(const std::set<std::string> & hits)121 void OnCheckExtensionsResult(const std::set<std::string>& hits) override {
122 DCHECK_CURRENTLY_ON(BrowserThread::IO);
123 callback_task_runner_->PostTask(FROM_HERE,
124 base::BindOnce(std::move(callback_), hits));
125 Release(); // Balanced in StartCheck.
126 }
127
128 scoped_refptr<base::SingleThreadTaskRunner> callback_task_runner_;
129 OnResultCallback callback_;
130
131 DISALLOW_COPY_AND_ASSIGN(SafeBrowsingClientImpl);
132 };
133
CheckOneExtensionState(Blocklist::IsBlocklistedCallback callback,const Blocklist::BlocklistStateMap & state_map)134 void CheckOneExtensionState(Blocklist::IsBlocklistedCallback callback,
135 const Blocklist::BlocklistStateMap& state_map) {
136 std::move(callback).Run(state_map.empty() ? NOT_BLOCKLISTED
137 : state_map.begin()->second);
138 }
139
GetMalwareFromBlocklistStateMap(Blocklist::GetMalwareIDsCallback callback,const Blocklist::BlocklistStateMap & state_map)140 void GetMalwareFromBlocklistStateMap(
141 Blocklist::GetMalwareIDsCallback callback,
142 const Blocklist::BlocklistStateMap& state_map) {
143 std::set<std::string> malware;
144 for (const auto& state_pair : state_map) {
145 // TODO(oleg): UNKNOWN is treated as MALWARE for backwards compatibility.
146 // In future GetMalwareIDs will be removed and the caller will have to
147 // deal with BLOCKLISTED_UNKNOWN state returned from GetBlocklistedIDs.
148 if (state_pair.second == BLOCKLISTED_MALWARE ||
149 state_pair.second == BLOCKLISTED_UNKNOWN) {
150 malware.insert(state_pair.first);
151 }
152 }
153 std::move(callback).Run(malware);
154 }
155
156 } // namespace
157
Observer(Blocklist * blocklist)158 Blocklist::Observer::Observer(Blocklist* blocklist) : blocklist_(blocklist) {
159 blocklist_->AddObserver(this);
160 }
161
~Observer()162 Blocklist::Observer::~Observer() {
163 blocklist_->RemoveObserver(this);
164 }
165
ScopedDatabaseManagerForTest(scoped_refptr<SafeBrowsingDatabaseManager> database_manager)166 Blocklist::ScopedDatabaseManagerForTest::ScopedDatabaseManagerForTest(
167 scoped_refptr<SafeBrowsingDatabaseManager> database_manager)
168 : original_(GetDatabaseManager()) {
169 SetDatabaseManager(database_manager);
170 }
171
~ScopedDatabaseManagerForTest()172 Blocklist::ScopedDatabaseManagerForTest::~ScopedDatabaseManagerForTest() {
173 SetDatabaseManager(original_);
174 }
175
Blocklist(ExtensionPrefs * prefs)176 Blocklist::Blocklist(ExtensionPrefs* prefs) {
177 auto& lazy_database_manager = g_database_manager.Get();
178 // Using base::Unretained is safe because when this object goes away, the
179 // subscription will automatically be destroyed.
180 database_changed_subscription_ =
181 lazy_database_manager.RegisterDatabaseChangedCallback(base::BindRepeating(
182 &Blocklist::ObserveNewDatabase, base::Unretained(this)));
183
184 ObserveNewDatabase();
185 }
186
~Blocklist()187 Blocklist::~Blocklist() {}
188
189 // static
Get(content::BrowserContext * context)190 Blocklist* Blocklist::Get(content::BrowserContext* context) {
191 return BlocklistFactory::GetForBrowserContext(context);
192 }
193
GetBlocklistedIDs(const std::set<std::string> & ids,GetBlocklistedIDsCallback callback)194 void Blocklist::GetBlocklistedIDs(const std::set<std::string>& ids,
195 GetBlocklistedIDsCallback callback) {
196 DCHECK_CURRENTLY_ON(BrowserThread::UI);
197
198 if (ids.empty() || !GetDatabaseManager().get()) {
199 base::ThreadTaskRunnerHandle::Get()->PostTask(
200 FROM_HERE, base::BindOnce(std::move(callback), BlocklistStateMap()));
201 return;
202 }
203
204 // Constructing the SafeBrowsingClientImpl begins the process of asking
205 // safebrowsing for the blocklisted extensions. The set of blocklisted
206 // extensions returned by SafeBrowsing will then be passed to
207 // GetBlocklistStateIDs to get the particular BlocklistState for each id.
208 SafeBrowsingClientImpl::Start(
209 ids, base::BindOnce(&Blocklist::GetBlocklistStateForIDs, AsWeakPtr(),
210 std::move(callback)));
211 }
212
GetMalwareIDs(const std::set<std::string> & ids,GetMalwareIDsCallback callback)213 void Blocklist::GetMalwareIDs(const std::set<std::string>& ids,
214 GetMalwareIDsCallback callback) {
215 GetBlocklistedIDs(ids, base::BindOnce(&GetMalwareFromBlocklistStateMap,
216 std::move(callback)));
217 }
218
IsBlocklisted(const std::string & extension_id,IsBlocklistedCallback callback)219 void Blocklist::IsBlocklisted(const std::string& extension_id,
220 IsBlocklistedCallback callback) {
221 std::set<std::string> check;
222 check.insert(extension_id);
223 GetBlocklistedIDs(
224 check, base::BindOnce(&CheckOneExtensionState, std::move(callback)));
225 }
226
GetBlocklistStateForIDs(GetBlocklistedIDsCallback callback,const std::set<std::string> & blocklisted_ids)227 void Blocklist::GetBlocklistStateForIDs(
228 GetBlocklistedIDsCallback callback,
229 const std::set<std::string>& blocklisted_ids) {
230 DCHECK_CURRENTLY_ON(BrowserThread::UI);
231
232 std::set<std::string> ids_unknown_state;
233 BlocklistStateMap extensions_state;
234 for (const auto& blocklisted_id : blocklisted_ids) {
235 auto cache_it = blocklist_state_cache_.find(blocklisted_id);
236 if (cache_it == blocklist_state_cache_.end() ||
237 cache_it->second ==
238 BLOCKLISTED_UNKNOWN) { // Do not return UNKNOWN
239 // from cache, retry request.
240 ids_unknown_state.insert(blocklisted_id);
241 } else {
242 extensions_state[blocklisted_id] = cache_it->second;
243 }
244 }
245
246 if (ids_unknown_state.empty()) {
247 std::move(callback).Run(extensions_state);
248 } else {
249 // After the extension blocklist states have been downloaded, call this
250 // functions again, but prevent infinite cycle in case server is offline
251 // or some other reason prevents us from receiving the blocklist state for
252 // these extensions.
253 RequestExtensionsBlocklistState(
254 ids_unknown_state,
255 base::BindOnce(&Blocklist::ReturnBlocklistStateMap, AsWeakPtr(),
256 std::move(callback), blocklisted_ids));
257 }
258 }
259
ReturnBlocklistStateMap(GetBlocklistedIDsCallback callback,const std::set<std::string> & blocklisted_ids)260 void Blocklist::ReturnBlocklistStateMap(
261 GetBlocklistedIDsCallback callback,
262 const std::set<std::string>& blocklisted_ids) {
263 BlocklistStateMap extensions_state;
264 for (const auto& blocklisted_id : blocklisted_ids) {
265 auto cache_it = blocklist_state_cache_.find(blocklisted_id);
266 if (cache_it != blocklist_state_cache_.end())
267 extensions_state[blocklisted_id] = cache_it->second;
268 // If for some reason we still haven't cached the state of this extension,
269 // we silently skip it.
270 }
271
272 std::move(callback).Run(extensions_state);
273 }
274
RequestExtensionsBlocklistState(const std::set<std::string> & ids,base::OnceClosure callback)275 void Blocklist::RequestExtensionsBlocklistState(
276 const std::set<std::string>& ids,
277 base::OnceClosure callback) {
278 DCHECK_CURRENTLY_ON(BrowserThread::UI);
279 if (!state_fetcher_)
280 state_fetcher_.reset(new BlocklistStateFetcher());
281
282 state_requests_.emplace_back(std::vector<std::string>(ids.begin(), ids.end()),
283 std::move(callback));
284 for (const auto& id : ids) {
285 state_fetcher_->Request(
286 id, base::Bind(&Blocklist::OnBlocklistStateReceived, AsWeakPtr(), id));
287 }
288 }
289
OnBlocklistStateReceived(const std::string & id,BlocklistState state)290 void Blocklist::OnBlocklistStateReceived(const std::string& id,
291 BlocklistState state) {
292 DCHECK_CURRENTLY_ON(BrowserThread::UI);
293 blocklist_state_cache_[id] = state;
294
295 // Go through the opened requests and call the callbacks for those requests
296 // for which we already got all the required blocklist states.
297 auto requests_it = state_requests_.begin();
298 while (requests_it != state_requests_.end()) {
299 const std::vector<std::string>& ids = requests_it->first;
300
301 bool have_all_in_cache = true;
302 for (const auto& id : ids) {
303 if (!base::Contains(blocklist_state_cache_, id)) {
304 have_all_in_cache = false;
305 break;
306 }
307 }
308
309 if (have_all_in_cache) {
310 std::move(requests_it->second).Run();
311 requests_it = state_requests_.erase(requests_it); // returns next element
312 } else {
313 ++requests_it;
314 }
315 }
316 }
317
SetBlocklistStateFetcherForTest(BlocklistStateFetcher * fetcher)318 void Blocklist::SetBlocklistStateFetcherForTest(
319 BlocklistStateFetcher* fetcher) {
320 state_fetcher_.reset(fetcher);
321 }
322
ResetBlocklistStateFetcherForTest()323 BlocklistStateFetcher* Blocklist::ResetBlocklistStateFetcherForTest() {
324 return state_fetcher_.release();
325 }
326
ResetDatabaseUpdatedListenerForTest()327 void Blocklist::ResetDatabaseUpdatedListenerForTest() {
328 database_updated_subscription_.reset();
329 }
330
AddObserver(Observer * observer)331 void Blocklist::AddObserver(Observer* observer) {
332 DCHECK_CURRENTLY_ON(BrowserThread::UI);
333 observers_.AddObserver(observer);
334 }
335
RemoveObserver(Observer * observer)336 void Blocklist::RemoveObserver(Observer* observer) {
337 DCHECK_CURRENTLY_ON(BrowserThread::UI);
338 observers_.RemoveObserver(observer);
339 }
340
341 // static
SetDatabaseManager(scoped_refptr<SafeBrowsingDatabaseManager> database_manager)342 void Blocklist::SetDatabaseManager(
343 scoped_refptr<SafeBrowsingDatabaseManager> database_manager) {
344 g_database_manager.Get().set(database_manager);
345 }
346
347 // static
GetDatabaseManager()348 scoped_refptr<SafeBrowsingDatabaseManager> Blocklist::GetDatabaseManager() {
349 return g_database_manager.Get().get();
350 }
351
ObserveNewDatabase()352 void Blocklist::ObserveNewDatabase() {
353 auto database_manager = GetDatabaseManager();
354 if (database_manager.get()) {
355 // Using base::Unretained is safe because when this object goes away, the
356 // subscription to the callback list will automatically be destroyed.
357 database_updated_subscription_ =
358 database_manager.get()->RegisterDatabaseUpdatedCallback(
359 base::BindRepeating(&Blocklist::NotifyObservers,
360 base::Unretained(this)));
361 } else {
362 database_updated_subscription_.reset();
363 }
364 }
365
NotifyObservers()366 void Blocklist::NotifyObservers() {
367 for (auto& observer : observers_)
368 observer.OnBlocklistUpdated();
369 }
370
371 } // namespace extensions
372