1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "storage/browser/quota/client_usage_tracker.h"
6 
7 #include <stdint.h>
8 
9 #include "base/bind.h"
10 #include "base/callback_helpers.h"
11 #include "base/metrics/histogram_macros.h"
12 #include "base/stl_util.h"
13 
14 namespace storage {
15 
16 namespace {
17 
18 using OriginSetByHost = ClientUsageTracker::OriginSetByHost;
19 
DidGetHostUsage(UsageCallback callback,int64_t limited_usage,int64_t unlimited_usage)20 void DidGetHostUsage(UsageCallback callback,
21                      int64_t limited_usage,
22                      int64_t unlimited_usage) {
23   DCHECK_GE(limited_usage, 0);
24   DCHECK_GE(unlimited_usage, 0);
25   std::move(callback).Run(limited_usage + unlimited_usage);
26 }
27 
EraseOriginFromOriginSet(OriginSetByHost * origins_by_host,const std::string & host,const url::Origin & origin)28 bool EraseOriginFromOriginSet(OriginSetByHost* origins_by_host,
29                               const std::string& host,
30                               const url::Origin& origin) {
31   auto it = origins_by_host->find(host);
32   if (it == origins_by_host->end())
33     return false;
34 
35   if (!it->second.erase(origin))
36     return false;
37 
38   if (it->second.empty())
39     origins_by_host->erase(host);
40   return true;
41 }
42 
OriginSetContainsOrigin(const OriginSetByHost & origins,const std::string & host,const url::Origin & origin)43 bool OriginSetContainsOrigin(const OriginSetByHost& origins,
44                              const std::string& host,
45                              const url::Origin& origin) {
46   auto itr = origins.find(host);
47   return itr != origins.end() && base::Contains(itr->second, origin);
48 }
49 
RecordSkippedOriginHistogram(const InvalidOriginReason reason)50 void RecordSkippedOriginHistogram(const InvalidOriginReason reason) {
51   UMA_HISTOGRAM_ENUMERATION("Quota.SkippedInvalidOriginUsage", reason);
52 }
53 
54 }  // namespace
55 
56 struct ClientUsageTracker::AccumulateInfo {
57   AccumulateInfo() = default;
58   ~AccumulateInfo() = default;
59 
60   AccumulateInfo(const AccumulateInfo&) = delete;
61   AccumulateInfo& operator=(const AccumulateInfo&) = delete;
62 
63   size_t pending_jobs = 0;
64   int64_t limited_usage = 0;
65   int64_t unlimited_usage = 0;
66 };
67 
ClientUsageTracker(UsageTracker * tracker,scoped_refptr<QuotaClient> client,blink::mojom::StorageType type,scoped_refptr<SpecialStoragePolicy> special_storage_policy)68 ClientUsageTracker::ClientUsageTracker(
69     UsageTracker* tracker,
70     scoped_refptr<QuotaClient> client,
71     blink::mojom::StorageType type,
72     scoped_refptr<SpecialStoragePolicy> special_storage_policy)
73     : client_(std::move(client)),
74       type_(type),
75       global_limited_usage_(0),
76       global_unlimited_usage_(0),
77       global_usage_retrieved_(false),
78       special_storage_policy_(std::move(special_storage_policy)) {
79   DCHECK(client_);
80   if (special_storage_policy_.get())
81     special_storage_policy_->AddObserver(this);
82 }
83 
~ClientUsageTracker()84 ClientUsageTracker::~ClientUsageTracker() {
85   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
86   if (special_storage_policy_.get())
87     special_storage_policy_->RemoveObserver(this);
88 }
89 
GetGlobalUsage(GlobalUsageCallback callback)90 void ClientUsageTracker::GetGlobalUsage(GlobalUsageCallback callback) {
91   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
92   if (global_usage_retrieved_ &&
93       non_cached_limited_origins_by_host_.empty() &&
94       non_cached_unlimited_origins_by_host_.empty()) {
95     std::move(callback).Run(global_limited_usage_ + global_unlimited_usage_,
96                             global_unlimited_usage_);
97     return;
98   }
99 
100   client_->GetOriginsForType(
101       type_, base::BindOnce(&ClientUsageTracker::DidGetOriginsForGlobalUsage,
102                             weak_factory_.GetWeakPtr(), std::move(callback)));
103 }
104 
GetHostUsage(const std::string & host,UsageCallback callback)105 void ClientUsageTracker::GetHostUsage(const std::string& host,
106                                       UsageCallback callback) {
107   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
108   if (base::Contains(cached_hosts_, host) &&
109       !base::Contains(non_cached_limited_origins_by_host_, host) &&
110       !base::Contains(non_cached_unlimited_origins_by_host_, host)) {
111     // TODO(kinuko): Drop host_usage_map_ cache periodically.
112     std::move(callback).Run(GetCachedHostUsage(host));
113     return;
114   }
115 
116   if (!host_usage_accumulators_.Add(
117           host, base::BindOnce(&DidGetHostUsage, std::move(callback))))
118     return;
119   client_->GetOriginsForHost(
120       type_, host,
121       base::BindOnce(&ClientUsageTracker::DidGetOriginsForHostUsage,
122                      weak_factory_.GetWeakPtr(), host));
123 }
124 
UpdateUsageCache(const url::Origin & origin,int64_t delta)125 void ClientUsageTracker::UpdateUsageCache(const url::Origin& origin,
126                                           int64_t delta) {
127   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
128   const std::string& host = origin.host();
129   if (base::Contains(cached_hosts_, host)) {
130     if (!IsUsageCacheEnabledForOrigin(origin))
131       return;
132 
133     // Constrain |delta| to avoid negative usage values.
134     // TODO(michaeln): crbug/463729
135     delta = std::max(delta, -cached_usage_by_host_[host][origin]);
136     cached_usage_by_host_[host][origin] += delta;
137     UpdateGlobalUsageValue(IsStorageUnlimited(origin) ? &global_unlimited_usage_
138                                                       : &global_limited_usage_,
139                            delta);
140 
141     return;
142   }
143 
144   // We call GetHostUsage() so that the cache still updates, but we don't need
145   // to do anything else with the usage so we do not pass a callback.
146   GetHostUsage(host, base::DoNothing());
147 }
148 
GetCachedUsage() const149 int64_t ClientUsageTracker::GetCachedUsage() const {
150   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
151   int64_t usage = 0;
152   for (const auto& host_and_usage_map : cached_usage_by_host_) {
153     for (const auto& origin_and_usage : host_and_usage_map.second)
154       usage += origin_and_usage.second;
155   }
156   return usage;
157 }
158 
GetCachedHostsUsage() const159 std::map<std::string, int64_t> ClientUsageTracker::GetCachedHostsUsage() const {
160   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
161   std::map<std::string, int64_t> host_usage;
162   for (const auto& host_and_usage_map : cached_usage_by_host_) {
163     const std::string& host = host_and_usage_map.first;
164     host_usage[host] += GetCachedHostUsage(host);
165   }
166   return host_usage;
167 }
168 
GetCachedOriginsUsage() const169 std::map<url::Origin, int64_t> ClientUsageTracker::GetCachedOriginsUsage()
170     const {
171   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
172   std::map<url::Origin, int64_t> origin_usage;
173   for (const auto& host_and_usage_map : cached_usage_by_host_) {
174     for (const auto& origin_and_usage : host_and_usage_map.second)
175       origin_usage[origin_and_usage.first] += origin_and_usage.second;
176   }
177   return origin_usage;
178 }
179 
GetCachedOrigins() const180 std::set<url::Origin> ClientUsageTracker::GetCachedOrigins() const {
181   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
182   std::set<url::Origin> origins;
183   for (const auto& host_and_usage_map : cached_usage_by_host_) {
184     for (const auto& origin_and_usage : host_and_usage_map.second)
185       origins.insert(origin_and_usage.first);
186   }
187   return origins;
188 }
189 
SetUsageCacheEnabled(const url::Origin & origin,bool enabled)190 void ClientUsageTracker::SetUsageCacheEnabled(const url::Origin& origin,
191                                               bool enabled) {
192   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
193   const std::string& host = origin.host();
194   if (!enabled) {
195     // Erase |origin| from cache and subtract its usage.
196     auto host_it = cached_usage_by_host_.find(host);
197     if (host_it != cached_usage_by_host_.end()) {
198       UsageMap& cached_usage_for_host = host_it->second;
199 
200       auto origin_it = cached_usage_for_host.find(origin);
201       if (origin_it != cached_usage_for_host.end()) {
202         int64_t usage = origin_it->second;
203         UpdateUsageCache(origin, -usage);
204         cached_usage_for_host.erase(origin_it);
205         if (cached_usage_for_host.empty()) {
206           cached_usage_by_host_.erase(host_it);
207           cached_hosts_.erase(host);
208         }
209       }
210     }
211 
212     if (IsStorageUnlimited(origin))
213       non_cached_unlimited_origins_by_host_[host].insert(origin);
214     else
215       non_cached_limited_origins_by_host_[host].insert(origin);
216   } else {
217     // Erase |origin| from |non_cached_origins_| and invalidate the usage cache
218     // for the host.
219     if (EraseOriginFromOriginSet(&non_cached_limited_origins_by_host_,
220                                  host, origin) ||
221         EraseOriginFromOriginSet(&non_cached_unlimited_origins_by_host_,
222                                  host, origin)) {
223       cached_hosts_.erase(host);
224       global_usage_retrieved_ = false;
225     }
226   }
227 }
228 
DidGetOriginsForGlobalUsage(GlobalUsageCallback callback,const std::vector<url::Origin> & origins)229 void ClientUsageTracker::DidGetOriginsForGlobalUsage(
230     GlobalUsageCallback callback,
231     const std::vector<url::Origin>& origins) {
232   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
233   std::map<std::string, std::vector<url::Origin>> origins_by_host;
234   for (const auto& origin : origins)
235     origins_by_host[origin.host()].push_back(origin);
236 
237   AccumulateInfo* info = new AccumulateInfo;
238   // Getting host usage may synchronously return the result if the usage is
239   // cached, which may in turn dispatch the completion callback before we finish
240   // looping over all hosts (because info->pending_jobs may reach 0 during the
241   // loop).  To avoid this, we add one more pending host as a sentinel and
242   // fire the sentinel callback at the end.
243   info->pending_jobs = origins_by_host.size() + 1;
244   auto accumulator = base::BindRepeating(
245       &ClientUsageTracker::AccumulateHostUsage, weak_factory_.GetWeakPtr(),
246       base::Owned(info), base::AdaptCallbackForRepeating(std::move(callback)));
247 
248   for (const auto& host_and_origins : origins_by_host) {
249     const std::string& host = host_and_origins.first;
250     const std::vector<url::Origin>& origins = host_and_origins.second;
251     if (host_usage_accumulators_.Add(host, accumulator))
252       GetUsageForOrigins(host, origins);
253   }
254 
255   // Fire the sentinel as we've now called GetUsageForOrigins for all clients.
256   accumulator.Run(0, 0);
257 }
258 
AccumulateHostUsage(AccumulateInfo * info,GlobalUsageCallback callback,int64_t limited_usage,int64_t unlimited_usage)259 void ClientUsageTracker::AccumulateHostUsage(AccumulateInfo* info,
260                                              GlobalUsageCallback callback,
261                                              int64_t limited_usage,
262                                              int64_t unlimited_usage) {
263   DCHECK_GT(info->pending_jobs, 0U);
264   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
265   info->limited_usage += limited_usage;
266   info->unlimited_usage += unlimited_usage;
267   if (--info->pending_jobs)
268     return;
269 
270   DCHECK_GE(info->limited_usage, 0);
271   DCHECK_GE(info->unlimited_usage, 0);
272 
273   global_usage_retrieved_ = true;
274   std::move(callback).Run(info->limited_usage + info->unlimited_usage,
275                           info->unlimited_usage);
276 }
277 
DidGetOriginsForHostUsage(const std::string & host,const std::vector<url::Origin> & origins)278 void ClientUsageTracker::DidGetOriginsForHostUsage(
279     const std::string& host,
280     const std::vector<url::Origin>& origins) {
281   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
282   GetUsageForOrigins(host, origins);
283 }
284 
GetUsageForOrigins(const std::string & host,const std::vector<url::Origin> & origins)285 void ClientUsageTracker::GetUsageForOrigins(
286     const std::string& host,
287     const std::vector<url::Origin>& origins) {
288   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
289   AccumulateInfo* info = new AccumulateInfo;
290   // Getting origin usage may synchronously return the result if the usage is
291   // cached, which may in turn dispatch the completion callback before we finish
292   // looping over all origins (because info->pending_jobs may reach 0 during the
293   // loop).  To avoid this, we add one more pending origin as a sentinel and
294   // fire the sentinel callback at the end.
295   info->pending_jobs = origins.size() + 1;
296   auto accumulator =
297       base::BindRepeating(&ClientUsageTracker::AccumulateOriginUsage,
298                           weak_factory_.GetWeakPtr(), base::Owned(info), host);
299 
300   for (const auto& origin : origins) {
301     DCHECK_EQ(host, origin.host());
302 
303     int64_t origin_usage = 0;
304     if (GetCachedOriginUsage(origin, &origin_usage)) {
305       accumulator.Run(origin, origin_usage);
306     } else {
307       client_->GetOriginUsage(origin, type_,
308                               base::BindOnce(accumulator, origin));
309     }
310   }
311 
312   // Fire the sentinel as we've now called GetOriginUsage for all clients.
313   accumulator.Run(base::nullopt, 0);
314 }
315 
AccumulateOriginUsage(AccumulateInfo * info,const std::string & host,const base::Optional<url::Origin> & origin,int64_t usage)316 void ClientUsageTracker::AccumulateOriginUsage(
317     AccumulateInfo* info,
318     const std::string& host,
319     const base::Optional<url::Origin>& origin,
320     int64_t usage) {
321   DCHECK_GT(info->pending_jobs, 0U);
322   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
323   if (origin.has_value()) {
324     // TODO(https://crbug.com/941480): |origin| should not be opaque or have an
325     // empty url, but sometimes it is.
326     if (origin->opaque()) {
327       DVLOG(1) << "AccumulateOriginUsage for opaque origin!";
328       RecordSkippedOriginHistogram(InvalidOriginReason::kIsOpaque);
329     } else if (origin->GetURL().is_empty()) {
330       DVLOG(1) << "AccumulateOriginUsage for origin with empty url!";
331       RecordSkippedOriginHistogram(InvalidOriginReason::kIsEmpty);
332     } else {
333       if (usage < 0)
334         usage = 0;
335 
336       if (IsStorageUnlimited(*origin))
337         info->unlimited_usage += usage;
338       else
339         info->limited_usage += usage;
340       if (IsUsageCacheEnabledForOrigin(*origin))
341         AddCachedOrigin(*origin, usage);
342     }
343   }
344   if (--info->pending_jobs)
345     return;
346 
347   AddCachedHost(host);
348   host_usage_accumulators_.Run(
349       host, info->limited_usage, info->unlimited_usage);
350 }
351 
AddCachedOrigin(const url::Origin & origin,int64_t new_usage)352 void ClientUsageTracker::AddCachedOrigin(const url::Origin& origin,
353                                          int64_t new_usage) {
354   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
355   DCHECK(IsUsageCacheEnabledForOrigin(origin));
356 
357   const std::string& host = origin.host();
358   int64_t* usage = &cached_usage_by_host_[host][origin];
359   int64_t delta = new_usage - *usage;
360   *usage = new_usage;
361   if (delta) {
362     UpdateGlobalUsageValue(IsStorageUnlimited(origin) ? &global_unlimited_usage_
363                                                       : &global_limited_usage_,
364                            delta);
365   }
366 }
367 
AddCachedHost(const std::string & host)368 void ClientUsageTracker::AddCachedHost(const std::string& host) {
369   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
370   cached_hosts_.insert(host);
371 }
372 
GetCachedHostUsage(const std::string & host) const373 int64_t ClientUsageTracker::GetCachedHostUsage(const std::string& host) const {
374   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
375   auto it = cached_usage_by_host_.find(host);
376   if (it == cached_usage_by_host_.end())
377     return 0;
378 
379   int64_t usage = 0;
380   const UsageMap& usage_map = it->second;
381   for (const auto& origin_and_usage : usage_map)
382     usage += origin_and_usage.second;
383   return usage;
384 }
385 
GetCachedOriginUsage(const url::Origin & origin,int64_t * usage) const386 bool ClientUsageTracker::GetCachedOriginUsage(const url::Origin& origin,
387                                               int64_t* usage) const {
388   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
389   const std::string& host = origin.host();
390   auto host_it = cached_usage_by_host_.find(host);
391   if (host_it == cached_usage_by_host_.end())
392     return false;
393 
394   auto origin_it = host_it->second.find(origin);
395   if (origin_it == host_it->second.end())
396     return false;
397 
398   DCHECK(IsUsageCacheEnabledForOrigin(origin));
399   *usage = origin_it->second;
400   return true;
401 }
402 
IsUsageCacheEnabledForOrigin(const url::Origin & origin) const403 bool ClientUsageTracker::IsUsageCacheEnabledForOrigin(
404     const url::Origin& origin) const {
405   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
406   const std::string& host = origin.host();
407   return !OriginSetContainsOrigin(non_cached_limited_origins_by_host_,
408                                   host, origin) &&
409       !OriginSetContainsOrigin(non_cached_unlimited_origins_by_host_,
410                                host, origin);
411 }
412 
OnGranted(const url::Origin & origin,int change_flags)413 void ClientUsageTracker::OnGranted(const url::Origin& origin,
414                                    int change_flags) {
415   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
416   if (change_flags & SpecialStoragePolicy::STORAGE_UNLIMITED) {
417     int64_t usage = 0;
418     if (GetCachedOriginUsage(origin, &usage)) {
419       global_unlimited_usage_ += usage;
420       global_limited_usage_ -= usage;
421     }
422 
423     const std::string& host = origin.host();
424     if (EraseOriginFromOriginSet(&non_cached_limited_origins_by_host_,
425                                  host, origin))
426       non_cached_unlimited_origins_by_host_[host].insert(origin);
427   }
428 }
429 
OnRevoked(const url::Origin & origin,int change_flags)430 void ClientUsageTracker::OnRevoked(const url::Origin& origin,
431                                    int change_flags) {
432   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
433   if (change_flags & SpecialStoragePolicy::STORAGE_UNLIMITED) {
434     int64_t usage = 0;
435     if (GetCachedOriginUsage(origin, &usage)) {
436       global_unlimited_usage_ -= usage;
437       global_limited_usage_ += usage;
438     }
439 
440     const std::string& host = origin.host();
441     if (EraseOriginFromOriginSet(&non_cached_unlimited_origins_by_host_,
442                                  host, origin))
443       non_cached_limited_origins_by_host_[host].insert(origin);
444   }
445 }
446 
OnCleared()447 void ClientUsageTracker::OnCleared() {
448   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
449   global_limited_usage_ += global_unlimited_usage_;
450   global_unlimited_usage_ = 0;
451 
452   for (const auto& host_and_origins : non_cached_unlimited_origins_by_host_) {
453     const auto& host = host_and_origins.first;
454     for (const auto& origin : host_and_origins.second)
455       non_cached_limited_origins_by_host_[host].insert(origin);
456   }
457   non_cached_unlimited_origins_by_host_.clear();
458 }
459 
UpdateGlobalUsageValue(int64_t * usage_value,int64_t delta)460 void ClientUsageTracker::UpdateGlobalUsageValue(int64_t* usage_value,
461                                                 int64_t delta) {
462   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
463   *usage_value += delta;
464   if (*usage_value >= 0)
465     return;
466 
467   // If we have a negative global usage value, recalculate them.
468   // TODO(michaeln): There are book keeping bugs, crbug/463729
469   global_limited_usage_ = 0;
470   global_unlimited_usage_ = 0;
471   for (const auto& host_and_usage_map : cached_usage_by_host_) {
472     for (const auto& origin_and_usage : host_and_usage_map.second) {
473       if (IsStorageUnlimited(origin_and_usage.first))
474         global_unlimited_usage_ += origin_and_usage.second;
475       else
476         global_limited_usage_ += origin_and_usage.second;
477     }
478   }
479 }
480 
IsStorageUnlimited(const url::Origin & origin) const481 bool ClientUsageTracker::IsStorageUnlimited(const url::Origin& origin) const {
482   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
483   if (type_ == blink::mojom::StorageType::kSyncable)
484     return false;
485   return special_storage_policy_.get() &&
486          special_storage_policy_->IsStorageUnlimited(origin.GetURL());
487 }
488 
489 }  // namespace storage
490