1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "sql/test/test_helpers.h"
6 
7 #include <stddef.h>
8 #include <stdint.h>
9 
10 #include <memory>
11 #include <string>
12 
13 #include "base/check.h"
14 #include "base/check_op.h"
15 #include "base/files/file_util.h"
16 #include "base/files/scoped_file.h"
17 #include "base/threading/thread_restrictions.h"
18 #include "sql/database.h"
19 #include "sql/statement.h"
20 #include "testing/gtest/include/gtest/gtest.h"
21 #include "third_party/sqlite/sqlite3.h"
22 
23 namespace {
24 
CountSQLItemsOfType(sql::Database * db,const char * type)25 size_t CountSQLItemsOfType(sql::Database* db, const char* type) {
26   static const char kTypeSQL[] =
27       "SELECT COUNT(*) FROM sqlite_master WHERE type = ?";
28   sql::Statement s(db->GetUniqueStatement(kTypeSQL));
29   s.BindCString(0, type);
30   EXPECT_TRUE(s.Step());
31   return s.ColumnInt(0);
32 }
33 
34 // Get page size for the database.
GetPageSize(sql::Database * db,int * page_size)35 bool GetPageSize(sql::Database* db, int* page_size) {
36   sql::Statement s(db->GetUniqueStatement("PRAGMA page_size"));
37   if (!s.Step())
38     return false;
39   *page_size = s.ColumnInt(0);
40   return true;
41 }
42 
43 // Get |name|'s root page number in the database.
GetRootPage(sql::Database * db,const char * name,int * page_number)44 bool GetRootPage(sql::Database* db, const char* name, int* page_number) {
45   static const char kPageSql[] =
46       "SELECT rootpage FROM sqlite_master WHERE name = ?";
47   sql::Statement s(db->GetUniqueStatement(kPageSql));
48   s.BindString(0, name);
49   if (!s.Step())
50     return false;
51   *page_number = s.ColumnInt(0);
52   return true;
53 }
54 
// Helper for reading a number from the SQLite header.
// Decodes the first |bytes| bytes of |buf| as an unsigned big-endian integer
// (most significant byte first). Returns 0 without touching |buf| when |bytes|
// is 0. If |bytes| exceeds sizeof(unsigned), the earliest bytes are shifted
// out of the result.
// See base/big_endian.h.
unsigned ReadBigEndian(const unsigned char* buf, size_t bytes) {
  unsigned r = 0;
  for (size_t i = 0; i < bytes; i++) {
    r <<= 8;
    r |= buf[i];
  }
  return r;
}
65 
// Helper for writing a number to the SQLite header.
// Stores the low |bytes| bytes of |val| into |buf| in big-endian order: the
// least significant byte goes to buf[bytes - 1], the next to buf[bytes - 2],
// and so on. Writes nothing when |bytes| is 0.
void WriteBigEndian(unsigned val, unsigned char* buf, size_t bytes) {
  size_t pos = bytes;
  while (pos > 0) {
    --pos;
    buf[pos] = static_cast<unsigned char>(val & 0xFFu);
    val >>= 8;
  }
}
73 
IsWalDatabase(const base::FilePath & db_path)74 bool IsWalDatabase(const base::FilePath& db_path) {
75   // The SQLite header is documented at:
76   //   https://www.sqlite.org/fileformat.html#the_database_header
77   //
78   // Read the entire header.
79   constexpr int kHeaderSize = 100;
80   uint8_t header[kHeaderSize];
81   base::ReadFile(db_path, reinterpret_cast<char*>(header), sizeof(header));
82   constexpr int kWriteVersionHeaderOffset = 18;
83   constexpr int kReadVersionHeaderOffset = 19;
84   // If the read version is unsupported, we can't rely on our ability to
85   // interpret anything else in the header.
86   DCHECK_LE(header[kReadVersionHeaderOffset], 2)
87       << "Unsupported SQLite file format";
88   return header[kWriteVersionHeaderOffset] == 2;
89 }
90 
91 }  // namespace
92 
93 namespace sql {
94 namespace test {
95 
CorruptSizeInHeader(const base::FilePath & db_path)96 bool CorruptSizeInHeader(const base::FilePath& db_path) {
97   if (IsWalDatabase(db_path)) {
98     // Checkpoint the WAL file in Truncate mode before corrupting to ensure that
99     // any future transaction always touches the DB file and not just the WAL
100     // file.
101     base::ScopedAllowBlockingForTesting allow_blocking;
102     // TODO: This function doesn't reliably work if connections to the DB are
103     // still open. Change any uses to ensure that we close all database
104     // connections before calling this function.
105     sql::Database db({.exclusive_locking = false, .wal_mode = true});
106     if (!db.Open(db_path))
107       return false;
108     int wal_log_size = 0;
109     int checkpointed_frame_count = 0;
110     int truncate_result = sqlite3_wal_checkpoint_v2(
111         db.db(InternalApiToken()), /*zDb=*/nullptr, SQLITE_CHECKPOINT_TRUNCATE,
112         &wal_log_size, &checkpointed_frame_count);
113     // A successful checkpoint in truncate mode sets these to zero.
114     DCHECK(wal_log_size == 0);
115     DCHECK(checkpointed_frame_count == 0);
116     if (truncate_result != SQLITE_OK)
117       return false;
118     db.Close();
119   }
120 
121   // See http://www.sqlite.org/fileformat.html#database_header
122   const size_t kHeaderSize = 100;
123 
124   unsigned char header[kHeaderSize];
125 
126   base::ScopedFILE file(base::OpenFile(db_path, "rb+"));
127   if (!file.get())
128     return false;
129 
130   if (0 != fseek(file.get(), 0, SEEK_SET))
131     return false;
132   if (1u != fread(header, sizeof(header), 1, file.get()))
133     return false;
134 
135   int64_t db_size = 0;
136   if (!base::GetFileSize(db_path, &db_size))
137     return false;
138 
139   CorruptSizeInHeaderMemory(header, db_size);
140 
141   if (0 != fseek(file.get(), 0, SEEK_SET))
142     return false;
143   if (1u != fwrite(header, sizeof(header), 1, file.get()))
144     return false;
145 
146   return true;
147 }
148 
CorruptSizeInHeaderWithLock(const base::FilePath & db_path)149 bool CorruptSizeInHeaderWithLock(const base::FilePath& db_path) {
150   base::ScopedAllowBlockingForTesting allow_blocking;
151   sql::Database db;
152   if (!db.Open(db_path))
153     return false;
154 
155   // Prevent anyone else from using the database.  The transaction is
156   // rolled back when |db| is destroyed.
157   if (!db.Execute("BEGIN EXCLUSIVE"))
158     return false;
159 
160   return CorruptSizeInHeader(db_path);
161 }
162 
CorruptSizeInHeaderMemory(unsigned char * header,int64_t db_size)163 void CorruptSizeInHeaderMemory(unsigned char* header, int64_t db_size) {
164   const size_t kPageSizeOffset = 16;
165   const size_t kFileChangeCountOffset = 24;
166   const size_t kPageCountOffset = 28;
167   const size_t kVersionValidForOffset = 92;  // duplicate kFileChangeCountOffset
168 
169   const unsigned page_size = ReadBigEndian(header + kPageSizeOffset, 2);
170 
171   // One larger than the expected size.
172   const unsigned page_count =
173       static_cast<unsigned>((db_size + page_size) / page_size);
174   WriteBigEndian(page_count, header + kPageCountOffset, 4);
175 
176   // Update change count so outstanding readers know the info changed.
177   // Both spots must match for the page count to be considered valid.
178   unsigned change_count = ReadBigEndian(header + kFileChangeCountOffset, 4);
179   WriteBigEndian(change_count + 1, header + kFileChangeCountOffset, 4);
180   WriteBigEndian(change_count + 1, header + kVersionValidForOffset, 4);
181 }
182 
CorruptTableOrIndex(const base::FilePath & db_path,const char * tree_name,const char * update_sql)183 bool CorruptTableOrIndex(const base::FilePath& db_path,
184                          const char* tree_name,
185                          const char* update_sql) {
186   sql::Database db;
187   if (!db.Open(db_path))
188     return false;
189 
190   int page_size = db.page_size();
191   if (!GetPageSize(&db, &page_size))
192     return false;
193 
194   int page_number = 0;
195   if (!GetRootPage(&db, tree_name, &page_number))
196     return false;
197 
198   // SQLite uses 1-based page numbering.
199   const long int page_ofs = (page_number - 1) * page_size;
200   std::unique_ptr<char[]> page_buf(new char[page_size]);
201 
202   // Get the page into page_buf.
203   base::ScopedFILE file(base::OpenFile(db_path, "rb+"));
204   if (!file.get())
205     return false;
206   if (0 != fseek(file.get(), page_ofs, SEEK_SET))
207     return false;
208   if (1u != fread(page_buf.get(), page_size, 1, file.get()))
209     return false;
210 
211   // Require the page to be a leaf node.  A multilevel tree would be
212   // very hard to restore correctly.
213   if (page_buf[0] != 0xD && page_buf[0] != 0xA)
214     return false;
215 
216   // The update has to work, and make changes.
217   if (!db.Execute(update_sql))
218     return false;
219   if (db.GetLastChangeCount() == 0)
220     return false;
221 
222   // Ensure that the database is fully flushed.
223   db.Close();
224 
225   // Check that the stored page actually changed.  This catches usage
226   // errors where |update_sql| is not related to |tree_name|.
227   std::unique_ptr<char[]> check_page_buf(new char[page_size]);
228   // The on-disk data should have changed.
229   if (0 != fflush(file.get()))
230     return false;
231   if (0 != fseek(file.get(), page_ofs, SEEK_SET))
232     return false;
233   if (1u != fread(check_page_buf.get(), page_size, 1, file.get()))
234     return false;
235   if (!memcmp(check_page_buf.get(), page_buf.get(), page_size))
236     return false;
237 
238   // Put the original page back.
239   if (0 != fseek(file.get(), page_ofs, SEEK_SET))
240     return false;
241   if (1u != fwrite(page_buf.get(), page_size, 1, file.get()))
242     return false;
243 
244   return true;
245 }
246 
CountSQLTables(sql::Database * db)247 size_t CountSQLTables(sql::Database* db) {
248   return CountSQLItemsOfType(db, "table");
249 }
250 
CountSQLIndices(sql::Database * db)251 size_t CountSQLIndices(sql::Database* db) {
252   return CountSQLItemsOfType(db, "index");
253 }
254 
CountTableColumns(sql::Database * db,const char * table)255 size_t CountTableColumns(sql::Database* db, const char* table) {
256   // TODO(shess): sql::Database::QuoteForSQL() would make sense.
257   std::string quoted_table;
258   {
259     static const char kQuoteSQL[] = "SELECT quote(?)";
260     sql::Statement s(db->GetUniqueStatement(kQuoteSQL));
261     s.BindCString(0, table);
262     EXPECT_TRUE(s.Step());
263     quoted_table = s.ColumnString(0);
264   }
265 
266   std::string sql = "PRAGMA table_info(" + quoted_table + ")";
267   sql::Statement s(db->GetUniqueStatement(sql.c_str()));
268   size_t rows = 0;
269   while (s.Step()) {
270     ++rows;
271   }
272   EXPECT_TRUE(s.Succeeded());
273   return rows;
274 }
275 
CountTableRows(sql::Database * db,const char * table,size_t * count)276 bool CountTableRows(sql::Database* db, const char* table, size_t* count) {
277   // TODO(shess): Table should probably be quoted with [] or "".  See
278   // http://www.sqlite.org/lang_keywords.html .  Meanwhile, odd names
279   // will throw an error.
280   std::string sql = "SELECT COUNT(*) FROM ";
281   sql += table;
282   sql::Statement s(db->GetUniqueStatement(sql.c_str()));
283   if (!s.Step())
284     return false;
285 
286   *count = static_cast<size_t>(s.ColumnInt64(0));
287   return true;
288 }
289 
CreateDatabaseFromSQL(const base::FilePath & db_path,const base::FilePath & sql_path)290 bool CreateDatabaseFromSQL(const base::FilePath& db_path,
291                            const base::FilePath& sql_path) {
292   if (base::PathExists(db_path) || !base::PathExists(sql_path))
293     return false;
294 
295   std::string sql;
296   if (!base::ReadFileToString(sql_path, &sql))
297     return false;
298 
299   sql::Database db;
300   if (!db.Open(db_path))
301     return false;
302 
303   // TODO(shess): Android defaults to auto_vacuum mode.
304   // Unfortunately, this makes certain kinds of tests which manipulate
305   // the raw database hard/impossible to write.
306   // http://crbug.com/307303 is for exploring this test issue.
307   ignore_result(db.Execute("PRAGMA auto_vacuum = 0"));
308 
309   return db.Execute(sql.c_str());
310 }
311 
IntegrityCheck(sql::Database * db)312 std::string IntegrityCheck(sql::Database* db) {
313   sql::Statement statement(db->GetUniqueStatement("PRAGMA integrity_check"));
314 
315   // SQLite should always return a row of data.
316   EXPECT_TRUE(statement.Step());
317 
318   return statement.ColumnString(0);
319 }
320 
ExecuteWithResult(sql::Database * db,const char * sql)321 std::string ExecuteWithResult(sql::Database* db, const char* sql) {
322   sql::Statement s(db->GetUniqueStatement(sql));
323   return s.Step() ? s.ColumnString(0) : std::string();
324 }
325 
ExecuteWithResults(sql::Database * db,const char * sql,const char * column_sep,const char * row_sep)326 std::string ExecuteWithResults(sql::Database* db,
327                                const char* sql,
328                                const char* column_sep,
329                                const char* row_sep) {
330   sql::Statement s(db->GetUniqueStatement(sql));
331   std::string ret;
332   while (s.Step()) {
333     if (!ret.empty())
334       ret += row_sep;
335     for (int i = 0; i < s.ColumnCount(); ++i) {
336       if (i > 0)
337         ret += column_sep;
338       ret += s.ColumnString(i);
339     }
340   }
341   return ret;
342 }
343 
GetPageCount(sql::Database * db)344 int GetPageCount(sql::Database* db) {
345   sql::Statement statement(db->GetUniqueStatement("PRAGMA page_count"));
346   CHECK(statement.Step());
347   return statement.ColumnInt(0);
348 }
349 
350 // static
Create(sql::Database * db,const std::string & db_name,const std::string & table_name,const std::string & column_name)351 ColumnInfo ColumnInfo::Create(sql::Database* db,
352                               const std::string& db_name,
353                               const std::string& table_name,
354                               const std::string& column_name) {
355   sqlite3* const sqlite3_db = db->db(InternalApiToken());
356 
357   const char* data_type;
358   const char* collation_sequence;
359   int not_null;
360   int primary_key;
361   int auto_increment;
362   int status = sqlite3_table_column_metadata(
363       sqlite3_db, db_name.c_str(), table_name.c_str(), column_name.c_str(),
364       &data_type, &collation_sequence, &not_null, &primary_key,
365       &auto_increment);
366   CHECK_EQ(status, SQLITE_OK) << "SQLite error: " << sqlite3_errmsg(sqlite3_db);
367 
368   // This happens when virtual tables report no type information.
369   if (data_type == nullptr)
370     data_type = "(nullptr)";
371 
372   return {std::string(data_type), std::string(collation_sequence),
373           not_null != 0, primary_key != 0, auto_increment != 0};
374 }
375 
376 }  // namespace test
377 }  // namespace sql
378