1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "sql/test/test_helpers.h"
6
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include <memory>
#include <string>
#include <vector>

#include "base/check.h"
#include "base/check_op.h"
#include "base/files/file_util.h"
#include "base/files/scoped_file.h"
#include "base/threading/thread_restrictions.h"
#include "sql/database.h"
#include "sql/statement.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/sqlite/sqlite3.h"
22
23 namespace {
24
CountSQLItemsOfType(sql::Database * db,const char * type)25 size_t CountSQLItemsOfType(sql::Database* db, const char* type) {
26 static const char kTypeSQL[] =
27 "SELECT COUNT(*) FROM sqlite_master WHERE type = ?";
28 sql::Statement s(db->GetUniqueStatement(kTypeSQL));
29 s.BindCString(0, type);
30 EXPECT_TRUE(s.Step());
31 return s.ColumnInt(0);
32 }
33
34 // Get page size for the database.
GetPageSize(sql::Database * db,int * page_size)35 bool GetPageSize(sql::Database* db, int* page_size) {
36 sql::Statement s(db->GetUniqueStatement("PRAGMA page_size"));
37 if (!s.Step())
38 return false;
39 *page_size = s.ColumnInt(0);
40 return true;
41 }
42
43 // Get |name|'s root page number in the database.
GetRootPage(sql::Database * db,const char * name,int * page_number)44 bool GetRootPage(sql::Database* db, const char* name, int* page_number) {
45 static const char kPageSql[] =
46 "SELECT rootpage FROM sqlite_master WHERE name = ?";
47 sql::Statement s(db->GetUniqueStatement(kPageSql));
48 s.BindString(0, name);
49 if (!s.Step())
50 return false;
51 *page_number = s.ColumnInt(0);
52 return true;
53 }
54
// Helper for reading an unsigned big-endian number that is |bytes| bytes long
// (at most sizeof(unsigned)) from the SQLite header at |buf|. Returns 0 for
// |bytes| == 0 without reading |buf| at all. See base/big_endian.h.
unsigned ReadBigEndian(const unsigned char* buf, size_t bytes) {
  unsigned result = 0;
  for (size_t i = 0; i < bytes; ++i)
    result = (result << 8) | buf[i];
  return result;
}
65
// Helper for writing the low |bytes| bytes of |val| into |buf| in big-endian
// order (most significant byte first), as the SQLite header stores numbers.
void WriteBigEndian(unsigned val, unsigned char* buf, size_t bytes) {
  size_t pos = bytes;
  while (pos > 0) {
    --pos;
    buf[pos] = static_cast<unsigned char>(val & 0xFF);
    val >>= 8;
  }
}
73
IsWalDatabase(const base::FilePath & db_path)74 bool IsWalDatabase(const base::FilePath& db_path) {
75 // The SQLite header is documented at:
76 // https://www.sqlite.org/fileformat.html#the_database_header
77 //
78 // Read the entire header.
79 constexpr int kHeaderSize = 100;
80 uint8_t header[kHeaderSize];
81 base::ReadFile(db_path, reinterpret_cast<char*>(header), sizeof(header));
82 constexpr int kWriteVersionHeaderOffset = 18;
83 constexpr int kReadVersionHeaderOffset = 19;
84 // If the read version is unsupported, we can't rely on our ability to
85 // interpret anything else in the header.
86 DCHECK_LE(header[kReadVersionHeaderOffset], 2)
87 << "Unsupported SQLite file format";
88 return header[kWriteVersionHeaderOffset] == 2;
89 }
90
91 } // namespace
92
93 namespace sql {
94 namespace test {
95
CorruptSizeInHeader(const base::FilePath & db_path)96 bool CorruptSizeInHeader(const base::FilePath& db_path) {
97 if (IsWalDatabase(db_path)) {
98 // Checkpoint the WAL file in Truncate mode before corrupting to ensure that
99 // any future transaction always touches the DB file and not just the WAL
100 // file.
101 base::ScopedAllowBlockingForTesting allow_blocking;
102 // TODO: This function doesn't reliably work if connections to the DB are
103 // still open. Change any uses to ensure that we close all database
104 // connections before calling this function.
105 sql::Database db({.exclusive_locking = false, .wal_mode = true});
106 if (!db.Open(db_path))
107 return false;
108 int wal_log_size = 0;
109 int checkpointed_frame_count = 0;
110 int truncate_result = sqlite3_wal_checkpoint_v2(
111 db.db(InternalApiToken()), /*zDb=*/nullptr, SQLITE_CHECKPOINT_TRUNCATE,
112 &wal_log_size, &checkpointed_frame_count);
113 // A successful checkpoint in truncate mode sets these to zero.
114 DCHECK(wal_log_size == 0);
115 DCHECK(checkpointed_frame_count == 0);
116 if (truncate_result != SQLITE_OK)
117 return false;
118 db.Close();
119 }
120
121 // See http://www.sqlite.org/fileformat.html#database_header
122 const size_t kHeaderSize = 100;
123
124 unsigned char header[kHeaderSize];
125
126 base::ScopedFILE file(base::OpenFile(db_path, "rb+"));
127 if (!file.get())
128 return false;
129
130 if (0 != fseek(file.get(), 0, SEEK_SET))
131 return false;
132 if (1u != fread(header, sizeof(header), 1, file.get()))
133 return false;
134
135 int64_t db_size = 0;
136 if (!base::GetFileSize(db_path, &db_size))
137 return false;
138
139 CorruptSizeInHeaderMemory(header, db_size);
140
141 if (0 != fseek(file.get(), 0, SEEK_SET))
142 return false;
143 if (1u != fwrite(header, sizeof(header), 1, file.get()))
144 return false;
145
146 return true;
147 }
148
CorruptSizeInHeaderWithLock(const base::FilePath & db_path)149 bool CorruptSizeInHeaderWithLock(const base::FilePath& db_path) {
150 base::ScopedAllowBlockingForTesting allow_blocking;
151 sql::Database db;
152 if (!db.Open(db_path))
153 return false;
154
155 // Prevent anyone else from using the database. The transaction is
156 // rolled back when |db| is destroyed.
157 if (!db.Execute("BEGIN EXCLUSIVE"))
158 return false;
159
160 return CorruptSizeInHeader(db_path);
161 }
162
CorruptSizeInHeaderMemory(unsigned char * header,int64_t db_size)163 void CorruptSizeInHeaderMemory(unsigned char* header, int64_t db_size) {
164 const size_t kPageSizeOffset = 16;
165 const size_t kFileChangeCountOffset = 24;
166 const size_t kPageCountOffset = 28;
167 const size_t kVersionValidForOffset = 92; // duplicate kFileChangeCountOffset
168
169 const unsigned page_size = ReadBigEndian(header + kPageSizeOffset, 2);
170
171 // One larger than the expected size.
172 const unsigned page_count =
173 static_cast<unsigned>((db_size + page_size) / page_size);
174 WriteBigEndian(page_count, header + kPageCountOffset, 4);
175
176 // Update change count so outstanding readers know the info changed.
177 // Both spots must match for the page count to be considered valid.
178 unsigned change_count = ReadBigEndian(header + kFileChangeCountOffset, 4);
179 WriteBigEndian(change_count + 1, header + kFileChangeCountOffset, 4);
180 WriteBigEndian(change_count + 1, header + kVersionValidForOffset, 4);
181 }
182
CorruptTableOrIndex(const base::FilePath & db_path,const char * tree_name,const char * update_sql)183 bool CorruptTableOrIndex(const base::FilePath& db_path,
184 const char* tree_name,
185 const char* update_sql) {
186 sql::Database db;
187 if (!db.Open(db_path))
188 return false;
189
190 int page_size = db.page_size();
191 if (!GetPageSize(&db, &page_size))
192 return false;
193
194 int page_number = 0;
195 if (!GetRootPage(&db, tree_name, &page_number))
196 return false;
197
198 // SQLite uses 1-based page numbering.
199 const long int page_ofs = (page_number - 1) * page_size;
200 std::unique_ptr<char[]> page_buf(new char[page_size]);
201
202 // Get the page into page_buf.
203 base::ScopedFILE file(base::OpenFile(db_path, "rb+"));
204 if (!file.get())
205 return false;
206 if (0 != fseek(file.get(), page_ofs, SEEK_SET))
207 return false;
208 if (1u != fread(page_buf.get(), page_size, 1, file.get()))
209 return false;
210
211 // Require the page to be a leaf node. A multilevel tree would be
212 // very hard to restore correctly.
213 if (page_buf[0] != 0xD && page_buf[0] != 0xA)
214 return false;
215
216 // The update has to work, and make changes.
217 if (!db.Execute(update_sql))
218 return false;
219 if (db.GetLastChangeCount() == 0)
220 return false;
221
222 // Ensure that the database is fully flushed.
223 db.Close();
224
225 // Check that the stored page actually changed. This catches usage
226 // errors where |update_sql| is not related to |tree_name|.
227 std::unique_ptr<char[]> check_page_buf(new char[page_size]);
228 // The on-disk data should have changed.
229 if (0 != fflush(file.get()))
230 return false;
231 if (0 != fseek(file.get(), page_ofs, SEEK_SET))
232 return false;
233 if (1u != fread(check_page_buf.get(), page_size, 1, file.get()))
234 return false;
235 if (!memcmp(check_page_buf.get(), page_buf.get(), page_size))
236 return false;
237
238 // Put the original page back.
239 if (0 != fseek(file.get(), page_ofs, SEEK_SET))
240 return false;
241 if (1u != fwrite(page_buf.get(), page_size, 1, file.get()))
242 return false;
243
244 return true;
245 }
246
CountSQLTables(sql::Database * db)247 size_t CountSQLTables(sql::Database* db) {
248 return CountSQLItemsOfType(db, "table");
249 }
250
CountSQLIndices(sql::Database * db)251 size_t CountSQLIndices(sql::Database* db) {
252 return CountSQLItemsOfType(db, "index");
253 }
254
CountTableColumns(sql::Database * db,const char * table)255 size_t CountTableColumns(sql::Database* db, const char* table) {
256 // TODO(shess): sql::Database::QuoteForSQL() would make sense.
257 std::string quoted_table;
258 {
259 static const char kQuoteSQL[] = "SELECT quote(?)";
260 sql::Statement s(db->GetUniqueStatement(kQuoteSQL));
261 s.BindCString(0, table);
262 EXPECT_TRUE(s.Step());
263 quoted_table = s.ColumnString(0);
264 }
265
266 std::string sql = "PRAGMA table_info(" + quoted_table + ")";
267 sql::Statement s(db->GetUniqueStatement(sql.c_str()));
268 size_t rows = 0;
269 while (s.Step()) {
270 ++rows;
271 }
272 EXPECT_TRUE(s.Succeeded());
273 return rows;
274 }
275
CountTableRows(sql::Database * db,const char * table,size_t * count)276 bool CountTableRows(sql::Database* db, const char* table, size_t* count) {
277 // TODO(shess): Table should probably be quoted with [] or "". See
278 // http://www.sqlite.org/lang_keywords.html . Meanwhile, odd names
279 // will throw an error.
280 std::string sql = "SELECT COUNT(*) FROM ";
281 sql += table;
282 sql::Statement s(db->GetUniqueStatement(sql.c_str()));
283 if (!s.Step())
284 return false;
285
286 *count = static_cast<size_t>(s.ColumnInt64(0));
287 return true;
288 }
289
CreateDatabaseFromSQL(const base::FilePath & db_path,const base::FilePath & sql_path)290 bool CreateDatabaseFromSQL(const base::FilePath& db_path,
291 const base::FilePath& sql_path) {
292 if (base::PathExists(db_path) || !base::PathExists(sql_path))
293 return false;
294
295 std::string sql;
296 if (!base::ReadFileToString(sql_path, &sql))
297 return false;
298
299 sql::Database db;
300 if (!db.Open(db_path))
301 return false;
302
303 // TODO(shess): Android defaults to auto_vacuum mode.
304 // Unfortunately, this makes certain kinds of tests which manipulate
305 // the raw database hard/impossible to write.
306 // http://crbug.com/307303 is for exploring this test issue.
307 ignore_result(db.Execute("PRAGMA auto_vacuum = 0"));
308
309 return db.Execute(sql.c_str());
310 }
311
IntegrityCheck(sql::Database * db)312 std::string IntegrityCheck(sql::Database* db) {
313 sql::Statement statement(db->GetUniqueStatement("PRAGMA integrity_check"));
314
315 // SQLite should always return a row of data.
316 EXPECT_TRUE(statement.Step());
317
318 return statement.ColumnString(0);
319 }
320
ExecuteWithResult(sql::Database * db,const char * sql)321 std::string ExecuteWithResult(sql::Database* db, const char* sql) {
322 sql::Statement s(db->GetUniqueStatement(sql));
323 return s.Step() ? s.ColumnString(0) : std::string();
324 }
325
ExecuteWithResults(sql::Database * db,const char * sql,const char * column_sep,const char * row_sep)326 std::string ExecuteWithResults(sql::Database* db,
327 const char* sql,
328 const char* column_sep,
329 const char* row_sep) {
330 sql::Statement s(db->GetUniqueStatement(sql));
331 std::string ret;
332 while (s.Step()) {
333 if (!ret.empty())
334 ret += row_sep;
335 for (int i = 0; i < s.ColumnCount(); ++i) {
336 if (i > 0)
337 ret += column_sep;
338 ret += s.ColumnString(i);
339 }
340 }
341 return ret;
342 }
343
GetPageCount(sql::Database * db)344 int GetPageCount(sql::Database* db) {
345 sql::Statement statement(db->GetUniqueStatement("PRAGMA page_count"));
346 CHECK(statement.Step());
347 return statement.ColumnInt(0);
348 }
349
350 // static
Create(sql::Database * db,const std::string & db_name,const std::string & table_name,const std::string & column_name)351 ColumnInfo ColumnInfo::Create(sql::Database* db,
352 const std::string& db_name,
353 const std::string& table_name,
354 const std::string& column_name) {
355 sqlite3* const sqlite3_db = db->db(InternalApiToken());
356
357 const char* data_type;
358 const char* collation_sequence;
359 int not_null;
360 int primary_key;
361 int auto_increment;
362 int status = sqlite3_table_column_metadata(
363 sqlite3_db, db_name.c_str(), table_name.c_str(), column_name.c_str(),
364 &data_type, &collation_sequence, ¬_null, &primary_key,
365 &auto_increment);
366 CHECK_EQ(status, SQLITE_OK) << "SQLite error: " << sqlite3_errmsg(sqlite3_db);
367
368 // This happens when virtual tables report no type information.
369 if (data_type == nullptr)
370 data_type = "(nullptr)";
371
372 return {std::string(data_type), std::string(collation_sequence),
373 not_null != 0, primary_key != 0, auto_increment != 0};
374 }
375
376 } // namespace test
377 } // namespace sql
378