1 // Copyright 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/extensions/blocklist.h"
6 
#include <algorithm>
#include <iterator>
#include <memory>
#include <utility>
9 
10 #include "base/bind.h"
11 #include "base/callback_list.h"
12 #include "base/lazy_instance.h"
13 #include "base/macros.h"
14 #include "base/memory/ref_counted.h"
15 #include "base/single_thread_task_runner.h"
16 #include "base/stl_util.h"
17 #include "base/threading/thread_task_runner_handle.h"
18 #include "chrome/browser/browser_process.h"
19 #include "chrome/browser/extensions/blocklist_factory.h"
20 #include "chrome/browser/extensions/blocklist_state_fetcher.h"
21 #include "chrome/browser/safe_browsing/safe_browsing_service.h"
22 #include "components/prefs/pref_service.h"
23 #include "components/safe_browsing/buildflags.h"
24 #include "components/safe_browsing/core/db/util.h"
25 #include "content/public/browser/browser_task_traits.h"
26 #include "content/public/browser/browser_thread.h"
27 #include "extensions/browser/extension_prefs.h"
28 
29 using content::BrowserThread;
30 using safe_browsing::SafeBrowsingDatabaseManager;
31 
32 namespace extensions {
33 
34 namespace {
35 
36 // The safe browsing database manager to use. Make this a global/static variable
37 // rather than a member of Blocklist because Blocklist accesses the real
38 // database manager before it has a chance to get a fake one.
39 class LazySafeBrowsingDatabaseManager {
40  public:
LazySafeBrowsingDatabaseManager()41   LazySafeBrowsingDatabaseManager() {
42 #if BUILDFLAG(SAFE_BROWSING_DB_LOCAL)
43     if (g_browser_process && g_browser_process->safe_browsing_service()) {
44       instance_ =
45           g_browser_process->safe_browsing_service()->database_manager();
46     }
47 #endif
48   }
49 
get()50   scoped_refptr<SafeBrowsingDatabaseManager> get() { return instance_; }
51 
set(scoped_refptr<SafeBrowsingDatabaseManager> instance)52   void set(scoped_refptr<SafeBrowsingDatabaseManager> instance) {
53     instance_ = instance;
54     database_changed_callback_list_.Notify();
55   }
56 
57   std::unique_ptr<base::RepeatingClosureList::Subscription>
RegisterDatabaseChangedCallback(const base::RepeatingClosure & cb)58   RegisterDatabaseChangedCallback(const base::RepeatingClosure& cb) {
59     return database_changed_callback_list_.Add(cb);
60   }
61 
62  private:
63   scoped_refptr<SafeBrowsingDatabaseManager> instance_;
64   base::RepeatingClosureList database_changed_callback_list_;
65 };
66 
67 static base::LazyInstance<LazySafeBrowsingDatabaseManager>::DestructorAtExit
68     g_database_manager = LAZY_INSTANCE_INITIALIZER;
69 
70 // Implementation of SafeBrowsingDatabaseManager::Client, the class which is
71 // called back from safebrowsing queries.
72 //
73 // Constructed on any thread but lives on the IO from then on.
74 class SafeBrowsingClientImpl
75     : public SafeBrowsingDatabaseManager::Client,
76       public base::RefCountedThreadSafe<SafeBrowsingClientImpl> {
77  public:
78   using OnResultCallback =
79       base::OnceCallback<void(const std::set<std::string>&)>;
80 
81   // Constructs a client to query the database manager for |extension_ids| and
82   // run |callback| with the IDs of those which have been blocklisted.
Start(const std::set<std::string> & extension_ids,OnResultCallback callback)83   static void Start(const std::set<std::string>& extension_ids,
84                     OnResultCallback callback) {
85     auto safe_browsing_client = base::WrapRefCounted(
86         new SafeBrowsingClientImpl(extension_ids, std::move(callback)));
87     content::GetIOThreadTaskRunner({})->PostTask(
88         FROM_HERE,
89         base::BindOnce(&SafeBrowsingClientImpl::StartCheck,
90                        safe_browsing_client, g_database_manager.Get().get(),
91                        extension_ids));
92   }
93 
94  private:
95   friend class base::RefCountedThreadSafe<SafeBrowsingClientImpl>;
96 
SafeBrowsingClientImpl(const std::set<std::string> & extension_ids,OnResultCallback callback)97   SafeBrowsingClientImpl(const std::set<std::string>& extension_ids,
98                          OnResultCallback callback)
99       : callback_task_runner_(base::ThreadTaskRunnerHandle::Get()),
100         callback_(std::move(callback)) {}
101 
~SafeBrowsingClientImpl()102   ~SafeBrowsingClientImpl() override {}
103 
104   // Pass |database_manager| as a parameter to avoid touching
105   // SafeBrowsingService on the IO thread.
StartCheck(scoped_refptr<SafeBrowsingDatabaseManager> database_manager,const std::set<std::string> & extension_ids)106   void StartCheck(scoped_refptr<SafeBrowsingDatabaseManager> database_manager,
107                   const std::set<std::string>& extension_ids) {
108     DCHECK_CURRENTLY_ON(BrowserThread::IO);
109     if (database_manager->CheckExtensionIDs(extension_ids, this)) {
110       // Definitely not blocklisted. Callback immediately.
111       callback_task_runner_->PostTask(
112           FROM_HERE,
113           base::BindOnce(std::move(callback_), std::set<std::string>()));
114       return;
115     }
116     // Something might be blocklisted, response will come in
117     // OnCheckExtensionsResult.
118     AddRef();  // Balanced in OnCheckExtensionsResult
119   }
120 
OnCheckExtensionsResult(const std::set<std::string> & hits)121   void OnCheckExtensionsResult(const std::set<std::string>& hits) override {
122     DCHECK_CURRENTLY_ON(BrowserThread::IO);
123     callback_task_runner_->PostTask(FROM_HERE,
124                                     base::BindOnce(std::move(callback_), hits));
125     Release();  // Balanced in StartCheck.
126   }
127 
128   scoped_refptr<base::SingleThreadTaskRunner> callback_task_runner_;
129   OnResultCallback callback_;
130 
131   DISALLOW_COPY_AND_ASSIGN(SafeBrowsingClientImpl);
132 };
133 
CheckOneExtensionState(Blocklist::IsBlocklistedCallback callback,const Blocklist::BlocklistStateMap & state_map)134 void CheckOneExtensionState(Blocklist::IsBlocklistedCallback callback,
135                             const Blocklist::BlocklistStateMap& state_map) {
136   std::move(callback).Run(state_map.empty() ? NOT_BLOCKLISTED
137                                             : state_map.begin()->second);
138 }
139 
GetMalwareFromBlocklistStateMap(Blocklist::GetMalwareIDsCallback callback,const Blocklist::BlocklistStateMap & state_map)140 void GetMalwareFromBlocklistStateMap(
141     Blocklist::GetMalwareIDsCallback callback,
142     const Blocklist::BlocklistStateMap& state_map) {
143   std::set<std::string> malware;
144   for (const auto& state_pair : state_map) {
145     // TODO(oleg): UNKNOWN is treated as MALWARE for backwards compatibility.
146     // In future GetMalwareIDs will be removed and the caller will have to
147     // deal with BLOCKLISTED_UNKNOWN state returned from GetBlocklistedIDs.
148     if (state_pair.second == BLOCKLISTED_MALWARE ||
149         state_pair.second == BLOCKLISTED_UNKNOWN) {
150       malware.insert(state_pair.first);
151     }
152   }
153   std::move(callback).Run(malware);
154 }
155 
156 }  // namespace
157 
Observer(Blocklist * blocklist)158 Blocklist::Observer::Observer(Blocklist* blocklist) : blocklist_(blocklist) {
159   blocklist_->AddObserver(this);
160 }
161 
~Observer()162 Blocklist::Observer::~Observer() {
163   blocklist_->RemoveObserver(this);
164 }
165 
ScopedDatabaseManagerForTest(scoped_refptr<SafeBrowsingDatabaseManager> database_manager)166 Blocklist::ScopedDatabaseManagerForTest::ScopedDatabaseManagerForTest(
167     scoped_refptr<SafeBrowsingDatabaseManager> database_manager)
168     : original_(GetDatabaseManager()) {
169   SetDatabaseManager(database_manager);
170 }
171 
~ScopedDatabaseManagerForTest()172 Blocklist::ScopedDatabaseManagerForTest::~ScopedDatabaseManagerForTest() {
173   SetDatabaseManager(original_);
174 }
175 
Blocklist(ExtensionPrefs * prefs)176 Blocklist::Blocklist(ExtensionPrefs* prefs) {
177   auto& lazy_database_manager = g_database_manager.Get();
178   // Using base::Unretained is safe because when this object goes away, the
179   // subscription will automatically be destroyed.
180   database_changed_subscription_ =
181       lazy_database_manager.RegisterDatabaseChangedCallback(base::BindRepeating(
182           &Blocklist::ObserveNewDatabase, base::Unretained(this)));
183 
184   ObserveNewDatabase();
185 }
186 
~Blocklist()187 Blocklist::~Blocklist() {}
188 
189 // static
Get(content::BrowserContext * context)190 Blocklist* Blocklist::Get(content::BrowserContext* context) {
191   return BlocklistFactory::GetForBrowserContext(context);
192 }
193 
GetBlocklistedIDs(const std::set<std::string> & ids,GetBlocklistedIDsCallback callback)194 void Blocklist::GetBlocklistedIDs(const std::set<std::string>& ids,
195                                   GetBlocklistedIDsCallback callback) {
196   DCHECK_CURRENTLY_ON(BrowserThread::UI);
197 
198   if (ids.empty() || !GetDatabaseManager().get()) {
199     base::ThreadTaskRunnerHandle::Get()->PostTask(
200         FROM_HERE, base::BindOnce(std::move(callback), BlocklistStateMap()));
201     return;
202   }
203 
204   // Constructing the SafeBrowsingClientImpl begins the process of asking
205   // safebrowsing for the blocklisted extensions. The set of blocklisted
206   // extensions returned by SafeBrowsing will then be passed to
207   // GetBlocklistStateIDs to get the particular BlocklistState for each id.
208   SafeBrowsingClientImpl::Start(
209       ids, base::BindOnce(&Blocklist::GetBlocklistStateForIDs, AsWeakPtr(),
210                           std::move(callback)));
211 }
212 
GetMalwareIDs(const std::set<std::string> & ids,GetMalwareIDsCallback callback)213 void Blocklist::GetMalwareIDs(const std::set<std::string>& ids,
214                               GetMalwareIDsCallback callback) {
215   GetBlocklistedIDs(ids, base::BindOnce(&GetMalwareFromBlocklistStateMap,
216                                         std::move(callback)));
217 }
218 
IsBlocklisted(const std::string & extension_id,IsBlocklistedCallback callback)219 void Blocklist::IsBlocklisted(const std::string& extension_id,
220                               IsBlocklistedCallback callback) {
221   std::set<std::string> check;
222   check.insert(extension_id);
223   GetBlocklistedIDs(
224       check, base::BindOnce(&CheckOneExtensionState, std::move(callback)));
225 }
226 
GetBlocklistStateForIDs(GetBlocklistedIDsCallback callback,const std::set<std::string> & blocklisted_ids)227 void Blocklist::GetBlocklistStateForIDs(
228     GetBlocklistedIDsCallback callback,
229     const std::set<std::string>& blocklisted_ids) {
230   DCHECK_CURRENTLY_ON(BrowserThread::UI);
231 
232   std::set<std::string> ids_unknown_state;
233   BlocklistStateMap extensions_state;
234   for (const auto& blocklisted_id : blocklisted_ids) {
235     auto cache_it = blocklist_state_cache_.find(blocklisted_id);
236     if (cache_it == blocklist_state_cache_.end() ||
237         cache_it->second ==
238             BLOCKLISTED_UNKNOWN) {  // Do not return UNKNOWN
239                                     // from cache, retry request.
240       ids_unknown_state.insert(blocklisted_id);
241     } else {
242       extensions_state[blocklisted_id] = cache_it->second;
243     }
244   }
245 
246   if (ids_unknown_state.empty()) {
247     std::move(callback).Run(extensions_state);
248   } else {
249     // After the extension blocklist states have been downloaded, call this
250     // functions again, but prevent infinite cycle in case server is offline
251     // or some other reason prevents us from receiving the blocklist state for
252     // these extensions.
253     RequestExtensionsBlocklistState(
254         ids_unknown_state,
255         base::BindOnce(&Blocklist::ReturnBlocklistStateMap, AsWeakPtr(),
256                        std::move(callback), blocklisted_ids));
257   }
258 }
259 
ReturnBlocklistStateMap(GetBlocklistedIDsCallback callback,const std::set<std::string> & blocklisted_ids)260 void Blocklist::ReturnBlocklistStateMap(
261     GetBlocklistedIDsCallback callback,
262     const std::set<std::string>& blocklisted_ids) {
263   BlocklistStateMap extensions_state;
264   for (const auto& blocklisted_id : blocklisted_ids) {
265     auto cache_it = blocklist_state_cache_.find(blocklisted_id);
266     if (cache_it != blocklist_state_cache_.end())
267       extensions_state[blocklisted_id] = cache_it->second;
268     // If for some reason we still haven't cached the state of this extension,
269     // we silently skip it.
270   }
271 
272   std::move(callback).Run(extensions_state);
273 }
274 
RequestExtensionsBlocklistState(const std::set<std::string> & ids,base::OnceClosure callback)275 void Blocklist::RequestExtensionsBlocklistState(
276     const std::set<std::string>& ids,
277     base::OnceClosure callback) {
278   DCHECK_CURRENTLY_ON(BrowserThread::UI);
279   if (!state_fetcher_)
280     state_fetcher_.reset(new BlocklistStateFetcher());
281 
282   state_requests_.emplace_back(std::vector<std::string>(ids.begin(), ids.end()),
283                                std::move(callback));
284   for (const auto& id : ids) {
285     state_fetcher_->Request(
286         id, base::Bind(&Blocklist::OnBlocklistStateReceived, AsWeakPtr(), id));
287   }
288 }
289 
OnBlocklistStateReceived(const std::string & id,BlocklistState state)290 void Blocklist::OnBlocklistStateReceived(const std::string& id,
291                                          BlocklistState state) {
292   DCHECK_CURRENTLY_ON(BrowserThread::UI);
293   blocklist_state_cache_[id] = state;
294 
295   // Go through the opened requests and call the callbacks for those requests
296   // for which we already got all the required blocklist states.
297   auto requests_it = state_requests_.begin();
298   while (requests_it != state_requests_.end()) {
299     const std::vector<std::string>& ids = requests_it->first;
300 
301     bool have_all_in_cache = true;
302     for (const auto& id : ids) {
303       if (!base::Contains(blocklist_state_cache_, id)) {
304         have_all_in_cache = false;
305         break;
306       }
307     }
308 
309     if (have_all_in_cache) {
310       std::move(requests_it->second).Run();
311       requests_it = state_requests_.erase(requests_it);  // returns next element
312     } else {
313       ++requests_it;
314     }
315   }
316 }
317 
SetBlocklistStateFetcherForTest(BlocklistStateFetcher * fetcher)318 void Blocklist::SetBlocklistStateFetcherForTest(
319     BlocklistStateFetcher* fetcher) {
320   state_fetcher_.reset(fetcher);
321 }
322 
ResetBlocklistStateFetcherForTest()323 BlocklistStateFetcher* Blocklist::ResetBlocklistStateFetcherForTest() {
324   return state_fetcher_.release();
325 }
326 
ResetDatabaseUpdatedListenerForTest()327 void Blocklist::ResetDatabaseUpdatedListenerForTest() {
328   database_updated_subscription_.reset();
329 }
330 
AddObserver(Observer * observer)331 void Blocklist::AddObserver(Observer* observer) {
332   DCHECK_CURRENTLY_ON(BrowserThread::UI);
333   observers_.AddObserver(observer);
334 }
335 
RemoveObserver(Observer * observer)336 void Blocklist::RemoveObserver(Observer* observer) {
337   DCHECK_CURRENTLY_ON(BrowserThread::UI);
338   observers_.RemoveObserver(observer);
339 }
340 
341 // static
SetDatabaseManager(scoped_refptr<SafeBrowsingDatabaseManager> database_manager)342 void Blocklist::SetDatabaseManager(
343     scoped_refptr<SafeBrowsingDatabaseManager> database_manager) {
344   g_database_manager.Get().set(database_manager);
345 }
346 
347 // static
GetDatabaseManager()348 scoped_refptr<SafeBrowsingDatabaseManager> Blocklist::GetDatabaseManager() {
349   return g_database_manager.Get().get();
350 }
351 
ObserveNewDatabase()352 void Blocklist::ObserveNewDatabase() {
353   auto database_manager = GetDatabaseManager();
354   if (database_manager.get()) {
355     // Using base::Unretained is safe because when this object goes away, the
356     // subscription to the callback list will automatically be destroyed.
357     database_updated_subscription_ =
358         database_manager.get()->RegisterDatabaseUpdatedCallback(
359             base::BindRepeating(&Blocklist::NotifyObservers,
360                                 base::Unretained(this)));
361   } else {
362     database_updated_subscription_.reset();
363   }
364 }
365 
NotifyObservers()366 void Blocklist::NotifyObservers() {
367   for (auto& observer : observers_)
368     observer.OnBlocklistUpdated();
369 }
370 
371 }  // namespace extensions
372