1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "components/offline_pages/core/background/request_queue_store.h"
6 
7 #include <string>
8 #include <unordered_set>
9 #include <utility>
10 
11 #include "base/bind.h"
12 #include "base/files/file_path.h"
13 #include "base/files/file_util.h"
14 #include "base/location.h"
15 #include "base/logging.h"
16 #include "base/sequenced_task_runner.h"
17 #include "base/single_thread_task_runner.h"
18 #include "base/task_runner_util.h"
19 #include "base/threading/thread_task_runner_handle.h"
20 #include "components/offline_pages/core/background/save_page_request.h"
21 #include "components/offline_pages/core/offline_page_item_utils.h"
22 #include "components/offline_pages/core/offline_store_utils.h"
23 #include "sql/database.h"
24 #include "sql/statement.h"
25 #include "sql/transaction.h"
26 
27 namespace offline_pages {
28 
29 template class StoreUpdateResult<SavePageRequest>;
30 
31 namespace {
32 
33 using SuccessCallback = base::OnceCallback<void(bool)>;
34 
// The table name and column list are macros rather than constants so they can
// be spliced into the SQL string literals below through adjacent-literal
// concatenation at compile time.
#define REQUEST_QUEUE_TABLE_NAME "request_queue_v1"

// Every column of the request queue table, in the order MakeSavePageRequest()
// reads them and InsertSync() binds them.
#define REQUEST_QUEUE_FIELDS  \
  "request_id, "              \
  "creation_time, "           \
  "activation_time, "         \
  "last_attempt_time, "       \
  "started_attempt_count, "   \
  "completed_attempt_count, " \
  "state, "                   \
  "url, "                     \
  "client_namespace, "        \
  "client_id, "               \
  "original_url, "            \
  "request_origin, "          \
  "fail_state, "              \
  "auto_fetch_notification_state"

// Passed as the user-requested flag for every request rebuilt from the
// database.
constexpr bool kUserRequested = true;
46 
CreateRequestQueueTable(sql::Database * db)47 bool CreateRequestQueueTable(sql::Database* db) {
48   static const char kSql[] =
49       "CREATE TABLE IF NOT EXISTS " REQUEST_QUEUE_TABLE_NAME
50       " (request_id INTEGER PRIMARY KEY NOT NULL,"
51       " creation_time INTEGER NOT NULL,"
52       " activation_time INTEGER NOT NULL DEFAULT 0,"
53       " last_attempt_time INTEGER NOT NULL DEFAULT 0,"
54       " started_attempt_count INTEGER NOT NULL,"
55       " completed_attempt_count INTEGER NOT NULL,"
56       " state INTEGER NOT NULL DEFAULT 0,"
57       " url VARCHAR NOT NULL,"
58       " client_namespace VARCHAR NOT NULL,"
59       " client_id VARCHAR NOT NULL,"
60       " original_url VARCHAR NOT NULL DEFAULT '',"
61       " request_origin VARCHAR NOT NULL DEFAULT '',"
62       " fail_state INTEGER NOT NULL DEFAULT 0,"
63       " auto_fetch_notification_state INTEGER NOT NULL DEFAULT 0"
64       ")";
65   return db->Execute(kSql);
66 }
67 
68 // Upgrades an old version of the request queue table to the new version.
69 //
70 // The upgrade is done by renaming the existing table to
71 // 'temp_request_queue_v1', reinserting data from the temporary table back into
72 // 'request_queue_v1', and finally dropping the temporary table.
73 //
74 // |upgrade_sql| is the SQL statement that copies data from the temporary
75 // table back into the primary table.
UpgradeWithQuery(sql::Database * db,const char * upgrade_sql)76 bool UpgradeWithQuery(sql::Database* db, const char* upgrade_sql) {
77   if (!db->Execute("ALTER TABLE " REQUEST_QUEUE_TABLE_NAME
78                    " RENAME TO temp_" REQUEST_QUEUE_TABLE_NAME)) {
79     return false;
80   }
81   if (!CreateRequestQueueTable(db))
82     return false;
83   if (!db->Execute(upgrade_sql))
84     return false;
85   return db->Execute("DROP TABLE IF EXISTS temp_" REQUEST_QUEUE_TABLE_NAME);
86 }
87 
UpgradeFrom57(sql::Database * db)88 bool UpgradeFrom57(sql::Database* db) {
89   static const char kSql[] =
90       "INSERT INTO " REQUEST_QUEUE_TABLE_NAME
91       " (request_id, creation_time, activation_time, last_attempt_time, "
92       "started_attempt_count, completed_attempt_count, state, url, "
93       "client_namespace, client_id) "
94       "SELECT "
95       "request_id, creation_time, activation_time, last_attempt_time, "
96       "started_attempt_count, completed_attempt_count, state, url, "
97       "client_namespace, client_id "
98       "FROM temp_" REQUEST_QUEUE_TABLE_NAME;
99   return UpgradeWithQuery(db, kSql);
100 }
101 
UpgradeFrom58(sql::Database * db)102 bool UpgradeFrom58(sql::Database* db) {
103   static const char kSql[] =
104       "INSERT INTO " REQUEST_QUEUE_TABLE_NAME
105       " (request_id, creation_time, activation_time, last_attempt_time, "
106       "started_attempt_count, completed_attempt_count, state, url, "
107       "client_namespace, client_id, original_url) "
108       "SELECT "
109       "request_id, creation_time, activation_time, last_attempt_time, "
110       "started_attempt_count, completed_attempt_count, state, url, "
111       "client_namespace, client_id, original_url "
112       "FROM temp_" REQUEST_QUEUE_TABLE_NAME;
113   return UpgradeWithQuery(db, kSql);
114 }
115 
UpgradeFrom61(sql::Database * db)116 bool UpgradeFrom61(sql::Database* db) {
117   static const char kSql[] =
118       "INSERT INTO " REQUEST_QUEUE_TABLE_NAME
119       " (request_id, creation_time, activation_time, last_attempt_time, "
120       "started_attempt_count, completed_attempt_count, state, url, "
121       "client_namespace, client_id, original_url, request_origin) "
122       "SELECT "
123       "request_id, creation_time, activation_time, last_attempt_time, "
124       "started_attempt_count, completed_attempt_count, state, url, "
125       "client_namespace, client_id, original_url, request_origin "
126       "FROM temp_" REQUEST_QUEUE_TABLE_NAME;
127   return UpgradeWithQuery(db, kSql);
128 }
129 
UpgradeFrom72(sql::Database * db)130 bool UpgradeFrom72(sql::Database* db) {
131   static const char kSql[] =
132       "INSERT INTO " REQUEST_QUEUE_TABLE_NAME
133       " (request_id, creation_time, activation_time, last_attempt_time, "
134       "started_attempt_count, completed_attempt_count, state, url, "
135       "client_namespace, client_id, original_url, request_origin, fail_state) "
136       "SELECT "
137       "request_id, creation_time, activation_time, last_attempt_time, "
138       "started_attempt_count, completed_attempt_count, state, url, "
139       "client_namespace, client_id, original_url, request_origin, fail_state "
140       "FROM temp_" REQUEST_QUEUE_TABLE_NAME;
141   return UpgradeWithQuery(db, kSql);
142 }
143 
CreateSchemaSync(sql::Database * db)144 bool CreateSchemaSync(sql::Database* db) {
145   sql::Transaction transaction(db);
146   if (!transaction.Begin())
147     return false;
148 
149   if (!db->DoesTableExist(REQUEST_QUEUE_TABLE_NAME)) {
150     if (!CreateRequestQueueTable(db))
151       return false;
152   }
153 
154   // If there is not already a state column, we need to drop the old table. We
155   // are choosing to drop instead of upgrade since the feature is not yet
156   // released, so we don't try to migrate it.
157   if (!db->DoesColumnExist(REQUEST_QUEUE_TABLE_NAME, "state")) {
158     if (!db->Execute("DROP TABLE IF EXISTS " REQUEST_QUEUE_TABLE_NAME))
159       return false;
160   }
161 
162   if (!db->DoesColumnExist(REQUEST_QUEUE_TABLE_NAME, "original_url")) {
163     if (!UpgradeFrom57(db))
164       return false;
165   } else if (!db->DoesColumnExist(REQUEST_QUEUE_TABLE_NAME, "request_origin")) {
166     if (!UpgradeFrom58(db))
167       return false;
168   } else if (!db->DoesColumnExist(REQUEST_QUEUE_TABLE_NAME, "fail_state")) {
169     if (!UpgradeFrom61(db))
170       return false;
171   } else if (!db->DoesColumnExist(REQUEST_QUEUE_TABLE_NAME,
172                                   "auto_fetch_notification_state")) {
173     if (!UpgradeFrom72(db))
174       return false;
175   }
176 
177   // This would be a great place to add indices when we need them.
178   return transaction.Commit();
179 }
180 
181 // Enum conversion code. Database corruption is possible, so make sure enum
182 // values are in the domain. Because corruption is rare, there is not robust
183 // error handling.
184 
AutoFetchNotificationStateFromInt(int value)185 SavePageRequest::AutoFetchNotificationState AutoFetchNotificationStateFromInt(
186     int value) {
187   switch (static_cast<SavePageRequest::AutoFetchNotificationState>(value)) {
188     case SavePageRequest::AutoFetchNotificationState::kUnknown:
189     case SavePageRequest::AutoFetchNotificationState::kShown:
190       return static_cast<SavePageRequest::AutoFetchNotificationState>(value);
191   }
192   DLOG(ERROR) << "Invalid AutoFetchNotificationState value: " << value;
193   return SavePageRequest::AutoFetchNotificationState::kUnknown;
194 }
195 
ToRequestState(int value)196 SavePageRequest::RequestState ToRequestState(int value) {
197   switch (static_cast<SavePageRequest::RequestState>(value)) {
198     case SavePageRequest::RequestState::AVAILABLE:
199     case SavePageRequest::RequestState::PAUSED:
200     case SavePageRequest::RequestState::OFFLINING:
201       return static_cast<SavePageRequest::RequestState>(value);
202   }
203   DLOG(ERROR) << "Invalid RequestState value: " << value;
204   return SavePageRequest::RequestState::AVAILABLE;
205 }
206 
ToFailState(int value)207 offline_items_collection::FailState ToFailState(int value) {
208   offline_items_collection::FailState state = FailState::NO_FAILURE;
209   if (!offline_items_collection::ToFailState(value, &state)) {
210     DLOG(ERROR) << "Invalid FailState: " << value;
211   }
212 
213   return state;
214 }
215 
216 // Create a save page request from the first row of an SQL result. The result
217 // must have the exact columns from the |REQUEST_QUEUE_FIELDS| macro.
MakeSavePageRequest(const sql::Statement & statement)218 std::unique_ptr<SavePageRequest> MakeSavePageRequest(
219     const sql::Statement& statement) {
220   const int64_t id = statement.ColumnInt64(0);
221   const base::Time creation_time =
222       store_utils::FromDatabaseTime(statement.ColumnInt64(1));
223   const base::Time last_attempt_time =
224       store_utils::FromDatabaseTime(statement.ColumnInt64(3));
225   const int64_t started_attempt_count = statement.ColumnInt64(4);
226   const int64_t completed_attempt_count = statement.ColumnInt64(5);
227   const SavePageRequest::RequestState state =
228       ToRequestState(statement.ColumnInt64(6));
229   const GURL url(statement.ColumnString(7));
230   const ClientId client_id(statement.ColumnString(8),
231                            statement.ColumnString(9));
232   const GURL original_url(statement.ColumnString(10));
233   const std::string request_origin(statement.ColumnString(11));
234 
235   DVLOG(2) << "making save page request - id " << id << " url " << url
236            << " client_id " << client_id.name_space << "-" << client_id.id
237            << " creation time " << creation_time << " user requested "
238            << kUserRequested << " original_url " << original_url
239            << " request_origin " << request_origin;
240 
241   std::unique_ptr<SavePageRequest> request(new SavePageRequest(
242       id, std::move(url), std::move(client_id), creation_time, kUserRequested));
243   request->set_last_attempt_time(last_attempt_time);
244   request->set_started_attempt_count(started_attempt_count);
245   request->set_completed_attempt_count(completed_attempt_count);
246   request->set_request_state(state);
247   request->set_original_url(std::move(original_url));
248   request->set_request_origin(std::move(request_origin));
249   request->set_fail_state(ToFailState(statement.ColumnInt64(12)));
250   request->set_auto_fetch_notification_state(
251       AutoFetchNotificationStateFromInt(statement.ColumnInt(13)));
252   return request;
253 }
254 
255 // Get a request for a specific id.
GetOneRequestSync(sql::Database * db,const int64_t request_id)256 std::unique_ptr<SavePageRequest> GetOneRequestSync(sql::Database* db,
257                                                    const int64_t request_id) {
258   static const char kSql[] =
259       "SELECT " REQUEST_QUEUE_FIELDS " FROM " REQUEST_QUEUE_TABLE_NAME
260       " WHERE request_id=?";
261 
262   sql::Statement statement(db->GetCachedStatement(SQL_FROM_HERE, kSql));
263   statement.BindInt64(0, request_id);
264 
265   if (statement.Step())
266     return MakeSavePageRequest(statement);
267   return {};
268 }
269 
DeleteRequestByIdSync(sql::Database * db,int64_t request_id)270 ItemActionStatus DeleteRequestByIdSync(sql::Database* db, int64_t request_id) {
271   static const char kSql[] =
272       "DELETE FROM " REQUEST_QUEUE_TABLE_NAME " WHERE request_id=?";
273   sql::Statement statement(db->GetCachedStatement(SQL_FROM_HERE, kSql));
274   statement.BindInt64(0, request_id);
275   if (!statement.Run())
276     return ItemActionStatus::STORE_ERROR;
277   if (db->GetLastChangeCount() == 0)
278     return ItemActionStatus::NOT_FOUND;
279   return ItemActionStatus::SUCCESS;
280 }
281 
InsertSync(sql::Database * db,const SavePageRequest & request)282 AddRequestResult InsertSync(sql::Database* db, const SavePageRequest& request) {
283   static const char kSql[] = "INSERT OR IGNORE INTO " REQUEST_QUEUE_TABLE_NAME
284                              " (" REQUEST_QUEUE_FIELDS
285                              ") VALUES"
286                              " (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
287 
288   sql::Statement statement(db->GetCachedStatement(SQL_FROM_HERE, kSql));
289   statement.BindInt64(0, request.request_id());
290   statement.BindInt64(1, store_utils::ToDatabaseTime(request.creation_time()));
291   statement.BindInt64(2, 0);
292   statement.BindInt64(3,
293                       store_utils::ToDatabaseTime(request.last_attempt_time()));
294   statement.BindInt64(4, request.started_attempt_count());
295   statement.BindInt64(5, request.completed_attempt_count());
296   statement.BindInt64(6, static_cast<int64_t>(request.request_state()));
297   statement.BindString(7, request.url().spec());
298   statement.BindString(8, request.client_id().name_space);
299   statement.BindString(9, request.client_id().id);
300   statement.BindString(10, request.original_url().spec());
301   statement.BindString(11, request.request_origin());
302   statement.BindInt64(12, static_cast<int64_t>(request.fail_state()));
303   statement.BindInt64(
304       13, static_cast<int64_t>(request.auto_fetch_notification_state()));
305 
306   if (!statement.Run())
307     return AddRequestResult::STORE_FAILURE;
308   if (db->GetLastChangeCount() == 0)
309     return AddRequestResult::ALREADY_EXISTS;
310   return AddRequestResult::SUCCESS;
311 }
312 
UpdateSync(sql::Database * db,const SavePageRequest & request)313 ItemActionStatus UpdateSync(sql::Database* db, const SavePageRequest& request) {
314   static const char kSql[] =
315       "UPDATE OR IGNORE " REQUEST_QUEUE_TABLE_NAME
316       " SET creation_time = ?, activation_time = ?, last_attempt_time = ?,"
317       " started_attempt_count = ?, completed_attempt_count = ?, state = ?,"
318       " url = ?, client_namespace = ?, client_id = ?, original_url = ?,"
319       " request_origin = ?, fail_state = ?, auto_fetch_notification_state = ?"
320       " WHERE request_id = ?";
321 
322   sql::Statement statement(db->GetCachedStatement(SQL_FROM_HERE, kSql));
323   // SET columns:
324   statement.BindInt64(0, store_utils::ToDatabaseTime(request.creation_time()));
325   statement.BindInt64(1, 0);
326   statement.BindInt64(2,
327                       store_utils::ToDatabaseTime(request.last_attempt_time()));
328   statement.BindInt64(3, request.started_attempt_count());
329   statement.BindInt64(4, request.completed_attempt_count());
330   statement.BindInt64(5, static_cast<int64_t>(request.request_state()));
331   statement.BindString(6, request.url().spec());
332   statement.BindString(7, request.client_id().name_space);
333   statement.BindString(8, request.client_id().id);
334   statement.BindString(9, request.original_url().spec());
335   statement.BindString(10, request.request_origin());
336   statement.BindInt64(11, static_cast<int64_t>(request.fail_state()));
337   statement.BindInt64(
338       12, static_cast<int64_t>(request.auto_fetch_notification_state()));
339   // WHERE:
340   statement.BindInt64(13, request.request_id());
341 
342   if (!statement.Run())
343     return ItemActionStatus::STORE_ERROR;
344   if (db->GetLastChangeCount() == 0)
345     return ItemActionStatus::NOT_FOUND;
346   return ItemActionStatus::SUCCESS;
347 }
348 
StoreUpdateResultForIds(StoreState store_state,const std::vector<int64_t> & item_ids,ItemActionStatus action_status)349 UpdateRequestsResult StoreUpdateResultForIds(
350     StoreState store_state,
351     const std::vector<int64_t>& item_ids,
352     ItemActionStatus action_status) {
353   UpdateRequestsResult result(store_state);
354   for (const auto& item_id : item_ids)
355     result.item_statuses.emplace_back(item_id, action_status);
356   return result;
357 }
358 
StoreErrorForAllRequests(const std::vector<SavePageRequest> & items)359 UpdateRequestsResult StoreErrorForAllRequests(
360     const std::vector<SavePageRequest>& items) {
361   std::vector<int64_t> item_ids;
362   for (const auto& item : items)
363     item_ids.push_back(item.request_id());
364   return StoreUpdateResultForIds(StoreState::LOADED, item_ids,
365                                  ItemActionStatus::STORE_ERROR);
366 }
367 
StoreErrorForAllIds(const std::vector<int64_t> & item_ids)368 UpdateRequestsResult StoreErrorForAllIds(const std::vector<int64_t>& item_ids) {
369   return StoreUpdateResultForIds(StoreState::LOADED, item_ids,
370                                  ItemActionStatus::STORE_ERROR);
371 }
372 
InitDatabaseSync(sql::Database * db,const base::FilePath & path)373 bool InitDatabaseSync(sql::Database* db, const base::FilePath& path) {
374   db->set_page_size(4096);
375   db->set_cache_size(500);
376   db->set_histogram_tag("BackgroundRequestQueue");
377   db->set_exclusive_locking();
378 
379   if (path.empty()) {
380     if (!db->OpenInMemory())
381       return false;
382   } else {
383     base::File::Error err;
384     if (!base::CreateDirectoryAndGetError(path.DirName(), &err))
385       return false;
386     if (!db->Open(path))
387       return false;
388   }
389   db->Preload();
390 
391   return CreateSchemaSync(db);
392 }
393 
394 base::Optional<std::vector<std::unique_ptr<SavePageRequest>>>
GetAllRequestsSync(sql::Database * db)395 GetAllRequestsSync(sql::Database* db) {
396   static const char kSql[] =
397       "SELECT " REQUEST_QUEUE_FIELDS " FROM " REQUEST_QUEUE_TABLE_NAME;
398   sql::Statement statement(db->GetCachedStatement(SQL_FROM_HERE, kSql));
399   std::vector<std::unique_ptr<SavePageRequest>> requests;
400   while (statement.Step())
401     requests.push_back(MakeSavePageRequest(statement));
402   if (!statement.Succeeded())
403     return base::nullopt;
404   return requests;
405 }
406 
407 // Calls |callback| with the result of |requests|.
InvokeGetRequestsCallback(RequestQueueStore::GetRequestsCallback callback,base::Optional<std::vector<std::unique_ptr<SavePageRequest>>> requests)408 void InvokeGetRequestsCallback(
409     RequestQueueStore::GetRequestsCallback callback,
410     base::Optional<std::vector<std::unique_ptr<SavePageRequest>>> requests) {
411   if (requests) {
412     std::move(callback).Run(true, std::move(requests).value());
413   } else {
414     std::move(callback).Run(false, {});
415   }
416 }
417 
GetRequestsByIdsSync(sql::Database * db,const std::vector<int64_t> & request_ids)418 UpdateRequestsResult GetRequestsByIdsSync(
419     sql::Database* db,
420     const std::vector<int64_t>& request_ids) {
421   UpdateRequestsResult result(StoreState::LOADED);
422 
423   // If you create a transaction but don't Commit() it is automatically
424   // rolled back by its destructor when it falls out of scope.
425   sql::Transaction transaction(db);
426   if (!transaction.Begin())
427     return StoreErrorForAllIds(request_ids);
428 
429   // Make sure not to include the same request multiple times, preserving the
430   // order of non-duplicated IDs in the result.
431   std::unordered_set<int64_t> processed_ids;
432   for (int64_t request_id : request_ids) {
433     if (!processed_ids.insert(request_id).second)
434       continue;
435     std::unique_ptr<SavePageRequest> request =
436         GetOneRequestSync(db, request_id);
437     if (request)
438       result.updated_items.push_back(*request);
439     ItemActionStatus status =
440         request ? ItemActionStatus::SUCCESS : ItemActionStatus::NOT_FOUND;
441     result.item_statuses.emplace_back(request_id, status);
442   }
443 
444   if (!transaction.Commit())
445     return StoreErrorForAllIds(request_ids);
446 
447   return result;
448 }
449 
AddRequestSync(sql::Database * db,const SavePageRequest & request,const RequestQueueStore::AddOptions & options)450 AddRequestResult AddRequestSync(sql::Database* db,
451                                 const SavePageRequest& request,
452                                 const RequestQueueStore::AddOptions& options) {
453   // If we need to check preconditions, read the set of active requests and
454   // check preconditions.
455   if (options.maximum_in_flight_requests_for_namespace > 0 ||
456       options.disallow_duplicate_requests) {
457     base::Optional<std::vector<std::unique_ptr<SavePageRequest>>> requests =
458         GetAllRequestsSync(db);
459     if (!requests)
460       return AddRequestResult::STORE_FAILURE;
461 
462     if (options.maximum_in_flight_requests_for_namespace > 0) {
463       int existing_requests = 0;
464       for (const std::unique_ptr<SavePageRequest>& existing_request :
465            requests.value()) {
466         if (existing_request->client_id().name_space ==
467             request.client_id().name_space)
468           ++existing_requests;
469       }
470       if (existing_requests >= options.maximum_in_flight_requests_for_namespace)
471         return AddRequestResult::REQUEST_QUOTA_HIT;
472     }
473 
474     if (options.disallow_duplicate_requests) {
475       for (const std::unique_ptr<SavePageRequest>& existing_request :
476            requests.value()) {
477         if (existing_request->client_id().name_space ==
478                 request.client_id().name_space &&
479             EqualsIgnoringFragment(existing_request->url(), request.url()))
480           return AddRequestResult::DUPLICATE_URL;
481       }
482     }
483   }
484   return InsertSync(db, request);
485 }
486 
UpdateRequestsSync(sql::Database * db,const std::vector<SavePageRequest> & requests)487 UpdateRequestsResult UpdateRequestsSync(
488     sql::Database* db,
489     const std::vector<SavePageRequest>& requests) {
490   sql::Transaction transaction(db);
491   if (!transaction.Begin())
492     return StoreErrorForAllRequests(requests);
493 
494   UpdateRequestsResult result(StoreState::LOADED);
495   for (const auto& request : requests) {
496     ItemActionStatus status = UpdateSync(db, request);
497     result.item_statuses.emplace_back(request.request_id(), status);
498     if (status == ItemActionStatus::SUCCESS)
499       result.updated_items.push_back(request);
500   }
501 
502   if (!transaction.Commit())
503     return StoreErrorForAllRequests(requests);
504 
505   return result;
506 }
507 
RemoveRequestsSync(sql::Database * db,const std::vector<int64_t> & request_ids)508 UpdateRequestsResult RemoveRequestsSync(
509     sql::Database* db,
510     const std::vector<int64_t>& request_ids) {
511   UpdateRequestsResult result(StoreState::LOADED);
512 
513   // If you create a transaction but don't Commit() it is automatically
514   // rolled back by its destructor when it falls out of scope.
515   sql::Transaction transaction(db);
516   if (!transaction.Begin()) {
517     return StoreErrorForAllIds(request_ids);
518   }
519 
520   // Read the request before we delete it, and if the delete worked, put it on
521   // the queue of requests that got deleted.
522   for (int64_t request_id : request_ids) {
523     std::unique_ptr<SavePageRequest> request =
524         GetOneRequestSync(db, request_id);
525     ItemActionStatus status = DeleteRequestByIdSync(db, request_id);
526     result.item_statuses.push_back(std::make_pair(request_id, status));
527     if (status == ItemActionStatus::SUCCESS)
528       result.updated_items.push_back(*request);
529   }
530 
531   if (!transaction.Commit()) {
532     return StoreErrorForAllIds(request_ids);
533   }
534 
535   return result;
536 }
537 
ResetSync(sql::Database * db,const base::FilePath & db_file_path)538 bool ResetSync(sql::Database* db, const base::FilePath& db_file_path) {
539   // This method deletes the content of the whole store and reinitializes it.
540   bool success = true;
541   if (db) {
542     success = db->Raze();
543     db->Close();
544   }
545   return base::DeletePathRecursively(db_file_path) && success;
546 }
547 
SetAutoFetchNotificationStateSync(sql::Database * db,int64_t request_id,SavePageRequest::AutoFetchNotificationState state)548 bool SetAutoFetchNotificationStateSync(
549     sql::Database* db,
550     int64_t request_id,
551     SavePageRequest::AutoFetchNotificationState state) {
552   std::unique_ptr<SavePageRequest> request = GetOneRequestSync(db, request_id);
553   if (!request)
554     return false;
555 
556   request->set_auto_fetch_notification_state(state);
557   return UpdateSync(db, *request) == ItemActionStatus::SUCCESS;
558 }
559 
RemoveRequestsIfSync(sql::Database * db,const base::RepeatingCallback<bool (const SavePageRequest &)> & remove_predicate)560 UpdateRequestsResult RemoveRequestsIfSync(
561     sql::Database* db,
562     const base::RepeatingCallback<bool(const SavePageRequest&)>&
563         remove_predicate) {
564   base::Optional<std::vector<std::unique_ptr<SavePageRequest>>> requests =
565       GetAllRequestsSync(db);
566   if (!requests)
567     return UpdateRequestsResult(StoreState::LOADED);
568 
569   std::vector<int64_t> ids_to_remove;
570   for (const std::unique_ptr<SavePageRequest>& request : requests.value()) {
571     if (remove_predicate.Run(*request))
572       ids_to_remove.push_back(request->request_id());
573   }
574   return RemoveRequestsSync(db, ids_to_remove);
575 }
576 
577 }  // anonymous namespace
578 
RequestQueueStore(scoped_refptr<base::SequencedTaskRunner> background_task_runner)579 RequestQueueStore::RequestQueueStore(
580     scoped_refptr<base::SequencedTaskRunner> background_task_runner)
581     : background_task_runner_(std::move(background_task_runner)),
582       state_(StoreState::NOT_LOADED) {}
583 
RequestQueueStore(scoped_refptr<base::SequencedTaskRunner> background_task_runner,const base::FilePath & path)584 RequestQueueStore::RequestQueueStore(
585     scoped_refptr<base::SequencedTaskRunner> background_task_runner,
586     const base::FilePath& path)
587     : RequestQueueStore(background_task_runner) {
588   DCHECK(!path.empty());
589   db_file_path_ = path.AppendASCII("RequestQueue.db");
590 }
591 
~RequestQueueStore()592 RequestQueueStore::~RequestQueueStore() {
593   if (db_)
594     background_task_runner_->DeleteSoon(FROM_HERE, db_.release());
595 }
596 
Initialize(InitializeCallback callback)597 void RequestQueueStore::Initialize(InitializeCallback callback) {
598   DCHECK(!db_);
599   db_.reset(new sql::Database());
600 
601   base::PostTaskAndReplyWithResult(
602       background_task_runner_.get(), FROM_HERE,
603       base::BindOnce(&InitDatabaseSync, db_.get(), db_file_path_),
604       base::BindOnce(&RequestQueueStore::OnOpenConnectionDone,
605                      weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
606 }
607 
GetRequests(GetRequestsCallback callback)608 void RequestQueueStore::GetRequests(GetRequestsCallback callback) {
609   DCHECK(db_);
610   if (!CheckDb()) {
611     std::vector<std::unique_ptr<SavePageRequest>> requests;
612     base::ThreadTaskRunnerHandle::Get()->PostTask(
613         FROM_HERE,
614         base::BindOnce(std::move(callback), false, std::move(requests)));
615     return;
616   }
617   base::PostTaskAndReplyWithResult(
618       background_task_runner_.get(), FROM_HERE,
619       base::BindOnce(&GetAllRequestsSync, db_.get()),
620       base::BindOnce(&InvokeGetRequestsCallback, std::move(callback)));
621 }
622 
GetRequestsByIds(const std::vector<int64_t> & request_ids,UpdateCallback callback)623 void RequestQueueStore::GetRequestsByIds(
624     const std::vector<int64_t>& request_ids,
625     UpdateCallback callback) {
626   if (!CheckDb()) {
627     base::ThreadTaskRunnerHandle::Get()->PostTask(
628         FROM_HERE,
629         base::BindOnce(std::move(callback),
630                        StoreUpdateResultForIds(StoreState::LOADED, request_ids,
631                                                ItemActionStatus::STORE_ERROR)));
632     return;
633   }
634 
635   base::PostTaskAndReplyWithResult(
636       background_task_runner_.get(), FROM_HERE,
637       base::BindOnce(&GetRequestsByIdsSync, db_.get(), request_ids),
638       std::move(callback));
639 }
640 
AddRequest(const SavePageRequest & request,AddOptions options,AddCallback callback)641 void RequestQueueStore::AddRequest(const SavePageRequest& request,
642                                    AddOptions options,
643                                    AddCallback callback) {
644   if (!CheckDb()) {
645     base::ThreadTaskRunnerHandle::Get()->PostTask(
646         FROM_HERE,
647         base::BindOnce(std::move(callback), AddRequestResult::STORE_FAILURE));
648     return;
649   }
650 
651   base::PostTaskAndReplyWithResult(
652       background_task_runner_.get(), FROM_HERE,
653       base::BindOnce(&AddRequestSync, db_.get(), request, options),
654       std::move(callback));
655 }
656 
UpdateRequests(const std::vector<SavePageRequest> & requests,UpdateCallback callback)657 void RequestQueueStore::UpdateRequests(
658     const std::vector<SavePageRequest>& requests,
659     UpdateCallback callback) {
660   if (!CheckDb()) {
661     base::ThreadTaskRunnerHandle::Get()->PostTask(
662         FROM_HERE, base::BindOnce(std::move(callback),
663                                   StoreErrorForAllRequests(requests)));
664     return;
665   }
666 
667   base::PostTaskAndReplyWithResult(
668       background_task_runner_.get(), FROM_HERE,
669       base::BindOnce(&UpdateRequestsSync, db_.get(), requests),
670       std::move(callback));
671 }
672 
RemoveRequests(const std::vector<int64_t> & request_ids,UpdateCallback callback)673 void RequestQueueStore::RemoveRequests(const std::vector<int64_t>& request_ids,
674                                        UpdateCallback callback) {
675   if (!CheckDb()) {
676     base::ThreadTaskRunnerHandle::Get()->PostTask(
677         FROM_HERE,
678         base::BindOnce(std::move(callback),
679                        StoreUpdateResultForIds(StoreState::LOADED, request_ids,
680                                                ItemActionStatus::STORE_ERROR)));
681     return;
682   }
683 
684   base::PostTaskAndReplyWithResult(
685       background_task_runner_.get(), FROM_HERE,
686       base::BindOnce(&RemoveRequestsSync, db_.get(), request_ids),
687       std::move(callback));
688 }
689 
RemoveRequestsIf(const base::RepeatingCallback<bool (const SavePageRequest &)> & remove_predicate,UpdateCallback callback)690 void RequestQueueStore::RemoveRequestsIf(
691     const base::RepeatingCallback<bool(const SavePageRequest&)>&
692         remove_predicate,
693     UpdateCallback callback) {
694   base::PostTaskAndReplyWithResult(
695       background_task_runner_.get(), FROM_HERE,
696       base::BindOnce(RemoveRequestsIfSync, db_.get(), remove_predicate),
697       std::move(callback));
698 }
699 
SetAutoFetchNotificationState(int64_t request_id,SavePageRequest::AutoFetchNotificationState state,base::OnceCallback<void (bool updated)> callback)700 void RequestQueueStore::SetAutoFetchNotificationState(
701     int64_t request_id,
702     SavePageRequest::AutoFetchNotificationState state,
703     base::OnceCallback<void(bool updated)> callback) {
704   base::PostTaskAndReplyWithResult(
705       background_task_runner_.get(), FROM_HERE,
706       base::BindOnce(SetAutoFetchNotificationStateSync, db_.get(), request_id,
707                      state),
708       std::move(callback));
709 }
710 
Reset(ResetCallback callback)711 void RequestQueueStore::Reset(ResetCallback callback) {
712   base::PostTaskAndReplyWithResult(
713       background_task_runner_.get(), FROM_HERE,
714       base::BindOnce(ResetSync, db_.get(), db_file_path_),
715       base::BindOnce(&RequestQueueStore::OnResetDone,
716                      weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
717 }
718 
state() const719 StoreState RequestQueueStore::state() const {
720   return state_;
721 }
722 
SetStateForTesting(StoreState state,bool reset_db)723 void RequestQueueStore::SetStateForTesting(StoreState state, bool reset_db) {
724   state_ = state;
725   if (reset_db)
726     db_.reset(nullptr);
727 }
728 
OnOpenConnectionDone(InitializeCallback callback,bool success)729 void RequestQueueStore::OnOpenConnectionDone(InitializeCallback callback,
730                                              bool success) {
731   DCHECK(db_);
732   state_ = success ? StoreState::LOADED : StoreState::FAILED_LOADING;
733   std::move(callback).Run(success);
734 }
735 
OnResetDone(ResetCallback callback,bool success)736 void RequestQueueStore::OnResetDone(ResetCallback callback, bool success) {
737   state_ = success ? StoreState::NOT_LOADED : StoreState::FAILED_RESET;
738   db_.reset();
739   base::ThreadTaskRunnerHandle::Get()->PostTask(
740       FROM_HERE, base::BindOnce(std::move(callback), success));
741 }
742 
CheckDb() const743 bool RequestQueueStore::CheckDb() const {
744   return db_ && state_ == StoreState::LOADED;
745 }
746 
747 }  // namespace offline_pages
748