1 /* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
//! Use this module to open a new SQLite database connection.
//!
//! Usage:
//! - Define a struct that implements `ConnectionInitializer`. This handles:
//!   - Initializing the schema for a new database
//!   - Upgrading the schema for an existing database
//!   - Extra preparation/finishing steps, for example setting up SQLite functions
//!
//! - Call `open_database()` in your database constructor:
//!   - The first method called is `prepare()`. This is executed outside of a transaction
//!     and is suitable for executing pragmas (e.g. `PRAGMA journal_mode=wal`), defining
//!     functions, etc.
//!   - If the database file is not present (or contains no schema yet) and the connection is
//!     writable, `open_database()` will create a new DB and call `init()`, then `finish()`.
//!     If the connection is not writable it will panic, meaning that if you support read-only
//!     connections, they must be created after a writable connection is open.
//!   - If the database file exists and the connection is writable, `open_database()` will open
//!     it and call `prepare()`, `upgrade_from()` for each upgrade that needs to be applied, then
//!     `finish()`. As above, a read-only connection will panic if upgrades are necessary, so
//!     you should ensure the first connection opened is writable.
//!   - If the connection is not writable, only `finish()` is called after `prepare()`
//!     (i.e. `finish()`, like `prepare()`, is called for all connections).
//!
//! For writable connections, `init()`/`upgrade_from()` and `finish()` all run inside a single
//! transaction that is only committed if every step succeeds. If the stored schema version is
//! newer than `ConnectionInitializer::END_VERSION`, opening fails with
//! `Error::IncompatibleVersion`.
//!
//! See the autofill DB code for an example.
use crate::ConnExt;
use rusqlite::{Connection, OpenFlags, Transaction, TransactionBehavior, NO_PARAMS};
use std::path::Path;
use thiserror::Error;
34
35 #[derive(Error, Debug)]
36 pub enum Error {
37 #[error("Incompatible database version: {0}")]
38 IncompatibleVersion(u32),
39 #[error("Error executing SQL: {0}")]
40 SqlError(#[from] rusqlite::Error),
41 }
42
43 pub type Result<T> = std::result::Result<T, Error>;
44
45 pub trait ConnectionInitializer {
46 // Name to display in the logs
47 const NAME: &'static str;
48
49 // The version that the last upgrade function upgrades to.
50 const END_VERSION: u32;
51
52 // Functions called only for writable connections all take a Transaction
53 // Initialize a newly created database to END_VERSION
init(&self, tx: &Transaction<'_>) -> Result<()>54 fn init(&self, tx: &Transaction<'_>) -> Result<()>;
55
56 // Upgrade schema from version -> version + 1
upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()>57 fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()>;
58
59 // Runs immediately after creation for all types of connections. If writable,
60 // will *not* be in the transaction created for the "only writable" functions above.
prepare(&self, _conn: &Connection) -> Result<()>61 fn prepare(&self, _conn: &Connection) -> Result<()> {
62 Ok(())
63 }
64
65 // Runs for all types of connections. If a writable connection is being
66 // initialized, this will be called after all initialization functions,
67 // but inside their transaction.
finish(&self, _conn: &Connection) -> Result<()>68 fn finish(&self, _conn: &Connection) -> Result<()> {
69 Ok(())
70 }
71 }
72
open_database<CI: ConnectionInitializer, P: AsRef<Path>>( path: P, connection_initializer: &CI, ) -> Result<Connection>73 pub fn open_database<CI: ConnectionInitializer, P: AsRef<Path>>(
74 path: P,
75 connection_initializer: &CI,
76 ) -> Result<Connection> {
77 open_database_with_flags(path, OpenFlags::default(), connection_initializer)
78 }
79
open_memory_database<CI: ConnectionInitializer>( conn_initializer: &CI, ) -> Result<Connection>80 pub fn open_memory_database<CI: ConnectionInitializer>(
81 conn_initializer: &CI,
82 ) -> Result<Connection> {
83 open_memory_database_with_flags(OpenFlags::default(), conn_initializer)
84 }
85
open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>( path: P, open_flags: OpenFlags, connection_initializer: &CI, ) -> Result<Connection>86 pub fn open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>(
87 path: P,
88 open_flags: OpenFlags,
89 connection_initializer: &CI,
90 ) -> Result<Connection> {
91 // Try running the migration logic with an existing file
92 log::debug!("{}: opening database", CI::NAME);
93 let mut conn = Connection::open_with_flags(path, open_flags)?;
94 let run_init = should_init(&conn)?;
95
96 log::debug!("{}: preparing", CI::NAME);
97 connection_initializer.prepare(&conn)?;
98
99 if open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) {
100 let tx = conn.transaction()?;
101 if run_init {
102 log::debug!("{}: initializing new database", CI::NAME);
103 connection_initializer.init(&tx)?;
104 } else {
105 let mut current_version = get_schema_version(&tx)?;
106 if current_version > CI::END_VERSION {
107 return Err(Error::IncompatibleVersion(current_version));
108 }
109 while current_version < CI::END_VERSION {
110 log::debug!(
111 "{}: upgrading database to {}",
112 CI::NAME,
113 current_version + 1
114 );
115 connection_initializer.upgrade_from(&tx, current_version)?;
116 current_version += 1;
117 }
118 }
119 log::debug!("{}: finishing writable database open", CI::NAME);
120 connection_initializer.finish(&tx)?;
121 set_schema_version(&tx, CI::END_VERSION)?;
122 tx.commit()?;
123 } else {
124 // There's an implied requirement that the first connection to a DB is
125 // writable, so read-only connections do much less, but panic if stuff is wrong
126 assert!(!run_init, "existing writer must have initialized");
127 assert!(
128 get_schema_version(&conn)? == CI::END_VERSION,
129 "existing writer must have migrated"
130 );
131 log::debug!("{}: finishing readonly database open", CI::NAME);
132 connection_initializer.finish(&conn)?;
133 }
134 log::debug!("{}: database open successful", CI::NAME);
135 Ok(conn)
136 }
137
open_memory_database_with_flags<CI: ConnectionInitializer>( flags: OpenFlags, conn_initializer: &CI, ) -> Result<Connection>138 pub fn open_memory_database_with_flags<CI: ConnectionInitializer>(
139 flags: OpenFlags,
140 conn_initializer: &CI,
141 ) -> Result<Connection> {
142 open_database_with_flags(":memory:", flags, conn_initializer)
143 }
144
should_init(conn: &Connection) -> Result<bool>145 fn should_init(conn: &Connection) -> Result<bool> {
146 Ok(conn.query_one::<u32>("SELECT COUNT(*) FROM sqlite_master")? == 0)
147 }
148
get_schema_version(conn: &Connection) -> Result<u32>149 fn get_schema_version(conn: &Connection) -> Result<u32> {
150 let version = conn.query_row_and_then("PRAGMA user_version", NO_PARAMS, |row| row.get(0))?;
151 Ok(version)
152 }
153
set_schema_version(conn: &Connection, version: u32) -> Result<()>154 fn set_schema_version(conn: &Connection, version: u32) -> Result<()> {
155 conn.set_pragma("user_version", version)?;
156 Ok(())
157 }
158
159 // It would be nice for this to be #[cfg(test)], but that doesn't allow it to be used in tests for
160 // our other crates.
161 pub mod test_utils {
162 use super::*;
163 use std::path::PathBuf;
164 use tempfile::TempDir;
165
166 // Database file that we can programatically run upgrades on
167 //
168 // We purposefully don't keep a connection to the database around to force upgrades to always
169 // run against a newly opened DB, like they would in the real world. See #4106 for
170 // details.
171 pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
172 // Keep around a TempDir to ensure the database file stays around until this struct is
173 // dropped
174 _tempdir: TempDir,
175 pub connection_initializer: CI,
176 pub path: PathBuf,
177 }
178
179 impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
new(connection_initializer: CI, init_sql: &str) -> Self180 pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
181 Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
182 }
183
new_with_flags( connection_initializer: CI, init_sql: &str, open_flags: OpenFlags, ) -> Self184 pub fn new_with_flags(
185 connection_initializer: CI,
186 init_sql: &str,
187 open_flags: OpenFlags,
188 ) -> Self {
189 let tempdir = tempfile::tempdir().unwrap();
190 let path = tempdir.path().join(Path::new("db.sql"));
191 let conn = Connection::open_with_flags(&path, open_flags).unwrap();
192 conn.execute_batch(init_sql).unwrap();
193 Self {
194 _tempdir: tempdir,
195 connection_initializer,
196 path,
197 }
198 }
199
upgrade_to(&self, version: u32)200 pub fn upgrade_to(&self, version: u32) {
201 let mut conn = self.open();
202 let tx = conn.transaction().unwrap();
203 let mut current_version = get_schema_version(&tx).unwrap();
204 while current_version < version {
205 self.connection_initializer
206 .upgrade_from(&tx, current_version)
207 .unwrap();
208 current_version += 1;
209 }
210 set_schema_version(&tx, current_version).unwrap();
211 self.connection_initializer.finish(&tx).unwrap();
212 tx.commit().unwrap();
213 }
214
run_all_upgrades(&self)215 pub fn run_all_upgrades(&self) {
216 let current_version = get_schema_version(&self.open()).unwrap();
217 for version in current_version..CI::END_VERSION {
218 self.upgrade_to(version + 1);
219 }
220 }
221
open(&self) -> Connection222 pub fn open(&self) -> Connection {
223 Connection::open(&self.path).unwrap()
224 }
225 }
226 }
227
#[cfg(test)]
mod test {
    use super::test_utils::MigratedDatabaseFile;
    use super::*;
    use std::cell::RefCell;

    /// Test initializer that records, in order, which of its hooks were called.
    struct TestConnectionInitializer {
        pub calls: RefCell<Vec<&'static str>>,
        /// When set, the v3 -> v4 upgrade executes invalid SQL and therefore always fails.
        pub buggy_v3_upgrade: bool,
    }

    impl TestConnectionInitializer {
        pub fn new() -> Self {
            let _ = env_logger::try_init();
            Self {
                calls: RefCell::new(Vec::new()),
                buggy_v3_upgrade: false,
            }
        }

        pub fn new_with_buggy_logic() -> Self {
            let _ = env_logger::try_init();
            Self {
                calls: RefCell::new(Vec::new()),
                buggy_v3_upgrade: true,
            }
        }

        pub fn clear_calls(&self) {
            self.calls.borrow_mut().clear();
        }

        pub fn push_call(&self, call: &'static str) {
            self.calls.borrow_mut().push(call);
        }

        pub fn check_calls(&self, expected: Vec<&'static str>) {
            assert_eq!(*self.calls.borrow(), expected);
        }
    }

    impl ConnectionInitializer for TestConnectionInitializer {
        const NAME: &'static str = "test db";
        const END_VERSION: u32 = 4;

        fn prepare(&self, conn: &Connection) -> Result<()> {
            self.push_call("prep");
            conn.execute_batch(
                "
                PRAGMA journal_mode = wal;
                ",
            )?;
            Ok(())
        }

        fn init(&self, conn: &Transaction<'_>) -> Result<()> {
            self.push_call("init");
            conn.execute_batch(
                "
                CREATE TABLE prep_table(col);
                INSERT INTO prep_table(col) VALUES ('correct-value');
                CREATE TABLE my_table(col);
                ",
            )
            .map_err(|e| e.into())
        }

        fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()> {
            match version {
                2 => {
                    self.push_call("upgrade_from_v2");
                    conn.execute_batch(
                        "
                        ALTER TABLE my_old_table_name RENAME TO my_table;
                        ",
                    )?;
                    Ok(())
                }
                3 => {
                    self.push_call("upgrade_from_v3");

                    if self.buggy_v3_upgrade {
                        conn.execute_batch("ILLEGAL_SQL_CODE")?;
                    }

                    conn.execute_batch(
                        "
                        ALTER TABLE my_table RENAME COLUMN old_col to col;
                        ",
                    )?;
                    Ok(())
                }
                _ => {
                    panic!("Unexpected version: {}", version);
                }
            }
        }

        fn finish(&self, conn: &Connection) -> Result<()> {
            self.push_call("finish");
            conn.execute_batch(
                "
                INSERT INTO my_table(col) SELECT col FROM prep_table;
                ",
            )?;
            Ok(())
        }
    }

    // Initialize the database to v2 to test upgrading from there
    static INIT_V2: &str = "
        CREATE TABLE prep_table(col);
        INSERT INTO prep_table(col) VALUES ('correct-value');
        CREATE TABLE my_old_table_name(old_col);
        PRAGMA user_version=2;
    ";

    /// Asserts the database ended up with the expected data at schema version 4.
    fn check_final_data(conn: &Connection) {
        let value: String = conn
            .query_row("SELECT col FROM my_table", NO_PARAMS, |r| r.get(0))
            .unwrap();
        assert_eq!(value, "correct-value");
        assert_eq!(get_schema_version(conn).unwrap(), 4);
    }

    #[test]
    fn test_init() {
        let connection_initializer = TestConnectionInitializer::new();
        let conn = open_memory_database(&connection_initializer).unwrap();
        check_final_data(&conn);
        connection_initializer.check_calls(vec!["prep", "init", "finish"]);
    }

    #[test]
    fn test_upgrades() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap();
        check_final_data(&conn);
        db_file.connection_initializer.check_calls(vec![
            "prep",
            "upgrade_from_v2",
            "upgrade_from_v3",
            "finish",
        ]);
    }

    #[test]
    fn test_open_current_version() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        db_file.upgrade_to(4);
        db_file.connection_initializer.clear_calls();
        let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap();
        check_final_data(&conn);
        db_file
            .connection_initializer
            .check_calls(vec!["prep", "finish"]);
    }

    #[test]
    fn test_pragmas() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        let conn = open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap();
        assert_eq!(
            conn.query_one::<String>("PRAGMA journal_mode").unwrap(),
            "wal"
        );
    }

    #[test]
    fn test_migration_error() {
        let db_file =
            MigratedDatabaseFile::new(TestConnectionInitializer::new_with_buggy_logic(), INIT_V2);
        db_file
            .open()
            .execute(
                "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')",
                NO_PARAMS,
            )
            .unwrap();

        open_database(db_file.path.clone(), &db_file.connection_initializer).unwrap_err();
        // Even though the upgrades failed, the data should still be there. The v2 -> v3
        // upgrade (the table rename) succeeded before the v3 -> v4 upgrade failed, so its
        // changes must have been rolled back too, leaving `my_old_table_name` intact.
        assert_eq!(
            db_file
                .open()
                .query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name")
                .unwrap(),
            1
        );
    }

    #[test]
    fn test_version_too_new() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        set_schema_version(&db_file.open(), 5).unwrap();

        db_file
            .open()
            .execute(
                "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')",
                NO_PARAMS,
            )
            .unwrap();

        assert!(matches!(
            open_database(db_file.path.clone(), &db_file.connection_initializer),
            Err(Error::IncompatibleVersion(5))
        ));
        // Make sure that a database with a newer schema is left untouched: the existing data
        // must still be there.
        assert_eq!(
            db_file
                .open()
                .query_one::<i32>("SELECT COUNT(*) FROM my_old_table_name")
                .unwrap(),
            1
        );
    }
}
447