1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
//! Use this module to open a new SQLite database connection.
//!
//! Usage:
//!    - Define a struct that implements ConnectionInitializer.  This handles:
//!      - Initializing the schema for a new database
//!      - Upgrading the schema for an existing database
//!      - Extra preparation/finishing steps, for example setting up SQLite functions
//!
//!    - Call open_database() in your database constructor:
//!      - The first method called is `prepare()`.  This is executed outside of a transaction
//!        and is suitable for executing pragmas (eg, `PRAGMA journal_mode=wal`), defining
//!        functions, etc.
//!      - If the database file is not present and the connection is writable, open_database()
//!        will create a new DB and call init(), then finish(). If the connection is not
//!        writable it will panic, meaning that if you support ReadOnly connections, they must
//!        be created after a writable connection is open.
//!      - If the database file exists and the connection is writable, open_database() will open
//!        it and call prepare(), upgrade_from() for each upgrade that needs to be applied, then
//!        finish(). As above, a read-only connection will panic if upgrades are necessary, so
//!        you should ensure the first connection opened is writable.
//!      - If the connection is not writable, `finish()` will be called (ie, `finish()`, like
//!        `prepare()`, is called for all connections)
//!
//!  See the autofill DB code for an example.
//!
30 use crate::ConnExt;
31 use rusqlite::{Connection, OpenFlags, Transaction, NO_PARAMS};
32 use std::path::Path;
33 use thiserror::Error;
34 
35 #[derive(Error, Debug)]
36 pub enum Error {
37     #[error("Incompatible database version: {0}")]
38     IncompatibleVersion(u32),
39     #[error("Error executing SQL: {0}")]
40     SqlError(#[from] rusqlite::Error),
41 }
42 
43 pub type Result<T> = std::result::Result<T, Error>;
44 
45 pub trait ConnectionInitializer {
46     // Name to display in the logs
47     const NAME: &'static str;
48 
49     // The version that the last upgrade function upgrades to.
50     const END_VERSION: u32;
51 
52     // Functions called only for writable connections all take a Transaction
53     // Initialize a newly created database to END_VERSION
init(&self, tx: &Transaction<'_>) -> Result<()>54     fn init(&self, tx: &Transaction<'_>) -> Result<()>;
55 
56     // Upgrade schema from version -> version + 1
upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()>57     fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()>;
58 
59     // Runs immediately after creation for all types of connections. If writable,
60     // will *not* be in the transaction created for the "only writable" functions above.
prepare(&self, _conn: &Connection) -> Result<()>61     fn prepare(&self, _conn: &Connection) -> Result<()> {
62         Ok(())
63     }
64 
65     // Runs for all types of connections. If a writable connection is being
66     // initialized, this will be called after all initialization functions,
67     // but inside their transaction.
finish(&self, _conn: &Connection) -> Result<()>68     fn finish(&self, _conn: &Connection) -> Result<()> {
69         Ok(())
70     }
71 }
72 
open_database<CI: ConnectionInitializer, P: AsRef<Path>>( path: P, connection_initializer: &CI, ) -> Result<Connection>73 pub fn open_database<CI: ConnectionInitializer, P: AsRef<Path>>(
74     path: P,
75     connection_initializer: &CI,
76 ) -> Result<Connection> {
77     open_database_with_flags(path, OpenFlags::default(), connection_initializer)
78 }
79 
open_memory_database<CI: ConnectionInitializer>( conn_initializer: &CI, ) -> Result<Connection>80 pub fn open_memory_database<CI: ConnectionInitializer>(
81     conn_initializer: &CI,
82 ) -> Result<Connection> {
83     open_memory_database_with_flags(OpenFlags::default(), conn_initializer)
84 }
85 
open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>( path: P, open_flags: OpenFlags, connection_initializer: &CI, ) -> Result<Connection>86 pub fn open_database_with_flags<CI: ConnectionInitializer, P: AsRef<Path>>(
87     path: P,
88     open_flags: OpenFlags,
89     connection_initializer: &CI,
90 ) -> Result<Connection> {
91     // Try running the migration logic with an existing file
92     log::debug!("{}: opening database", CI::NAME);
93     let mut conn = Connection::open_with_flags(path, open_flags)?;
94     let run_init = should_init(&conn)?;
95 
96     log::debug!("{}: preparing", CI::NAME);
97     connection_initializer.prepare(&conn)?;
98 
99     if open_flags.contains(OpenFlags::SQLITE_OPEN_READ_WRITE) {
100         let tx = conn.transaction()?;
101         if run_init {
102             log::debug!("{}: initializing new database", CI::NAME);
103             connection_initializer.init(&tx)?;
104         } else {
105             let mut current_version = get_schema_version(&tx)?;
106             if current_version > CI::END_VERSION {
107                 return Err(Error::IncompatibleVersion(current_version));
108             }
109             while current_version < CI::END_VERSION {
110                 log::debug!(
111                     "{}: upgrading database to {}",
112                     CI::NAME,
113                     current_version + 1
114                 );
115                 connection_initializer.upgrade_from(&tx, current_version)?;
116                 current_version += 1;
117             }
118         }
119         log::debug!("{}: finishing writable database open", CI::NAME);
120         connection_initializer.finish(&tx)?;
121         set_schema_version(&tx, CI::END_VERSION)?;
122         tx.commit()?;
123     } else {
124         // There's an implied requirement that the first connection to a DB is
125         // writable, so read-only connections do much less, but panic if stuff is wrong
126         assert!(!run_init, "existing writer must have initialized");
127         assert!(
128             get_schema_version(&conn)? == CI::END_VERSION,
129             "existing writer must have migrated"
130         );
131         log::debug!("{}: finishing readonly database open", CI::NAME);
132         connection_initializer.finish(&conn)?;
133     }
134     log::debug!("{}: database open successful", CI::NAME);
135     Ok(conn)
136 }
137 
open_memory_database_with_flags<CI: ConnectionInitializer>( flags: OpenFlags, conn_initializer: &CI, ) -> Result<Connection>138 pub fn open_memory_database_with_flags<CI: ConnectionInitializer>(
139     flags: OpenFlags,
140     conn_initializer: &CI,
141 ) -> Result<Connection> {
142     open_database_with_flags(":memory:", flags, conn_initializer)
143 }
144 
should_init(conn: &Connection) -> Result<bool>145 fn should_init(conn: &Connection) -> Result<bool> {
146     Ok(conn.query_one::<u32>("SELECT COUNT(*) FROM sqlite_master")? == 0)
147 }
148 
get_schema_version(conn: &Connection) -> Result<u32>149 fn get_schema_version(conn: &Connection) -> Result<u32> {
150     let version = conn.query_row_and_then("PRAGMA user_version", NO_PARAMS, |row| row.get(0))?;
151     Ok(version)
152 }
153 
set_schema_version(conn: &Connection, version: u32) -> Result<()>154 fn set_schema_version(conn: &Connection, version: u32) -> Result<()> {
155     conn.set_pragma("user_version", version)?;
156     Ok(())
157 }
158 
159 // It would be nice for this to be #[cfg(test)], but that doesn't allow it to be used in tests for
160 // our other crates.
161 pub mod test_utils {
162     use super::*;
163     use std::path::PathBuf;
164     use tempfile::TempDir;
165 
166     // Database file that we can programatically run upgrades on
167     //
168     // We purposefully don't keep a connection to the database around to force upgrades to always
169     // run against a newly opened DB, like they would in the real world.  See #4106 for
170     // details.
171     pub struct MigratedDatabaseFile<CI: ConnectionInitializer> {
172         // Keep around a TempDir to ensure the database file stays around until this struct is
173         // dropped
174         _tempdir: TempDir,
175         pub connection_initializer: CI,
176         pub path: PathBuf,
177     }
178 
179     impl<CI: ConnectionInitializer> MigratedDatabaseFile<CI> {
new(connection_initializer: CI, init_sql: &str) -> Self180         pub fn new(connection_initializer: CI, init_sql: &str) -> Self {
181             Self::new_with_flags(connection_initializer, init_sql, OpenFlags::default())
182         }
183 
new_with_flags( connection_initializer: CI, init_sql: &str, open_flags: OpenFlags, ) -> Self184         pub fn new_with_flags(
185             connection_initializer: CI,
186             init_sql: &str,
187             open_flags: OpenFlags,
188         ) -> Self {
189             let tempdir = tempfile::tempdir().unwrap();
190             let path = tempdir.path().join(Path::new("db.sql"));
191             let conn = Connection::open_with_flags(&path, open_flags).unwrap();
192             conn.execute_batch(init_sql).unwrap();
193             Self {
194                 _tempdir: tempdir,
195                 connection_initializer,
196                 path,
197             }
198         }
199 
upgrade_to(&self, version: u32)200         pub fn upgrade_to(&self, version: u32) {
201             let mut conn = self.open();
202             let tx = conn.transaction().unwrap();
203             let mut current_version = get_schema_version(&tx).unwrap();
204             while current_version < version {
205                 self.connection_initializer
206                     .upgrade_from(&tx, current_version)
207                     .unwrap();
208                 current_version += 1;
209             }
210             set_schema_version(&tx, current_version).unwrap();
211             self.connection_initializer.finish(&tx).unwrap();
212             tx.commit().unwrap();
213         }
214 
run_all_upgrades(&self)215         pub fn run_all_upgrades(&self) {
216             let current_version = get_schema_version(&self.open()).unwrap();
217             for version in current_version..CI::END_VERSION {
218                 self.upgrade_to(version + 1);
219             }
220         }
221 
open(&self) -> Connection222         pub fn open(&self) -> Connection {
223             Connection::open(&self.path).unwrap()
224         }
225     }
226 }
227 
#[cfg(test)]
mod test {
    use super::test_utils::MigratedDatabaseFile;
    use super::*;
    use std::cell::RefCell;

    /// `ConnectionInitializer` that records, in order, which hooks were invoked.
    struct TestConnectionInitializer {
        pub calls: RefCell<Vec<&'static str>>,
        // When set, the upgrade from v3 executes invalid SQL before its real migration.
        pub buggy_v3_upgrade: bool,
    }

    impl TestConnectionInitializer {
        // Shared body of the two public constructors.
        fn build(buggy_v3_upgrade: bool) -> Self {
            // Best-effort logger setup; the result is deliberately ignored.
            let _ = env_logger::try_init();
            Self {
                calls: RefCell::default(),
                buggy_v3_upgrade,
            }
        }

        pub fn new() -> Self {
            Self::build(false)
        }

        pub fn new_with_buggy_logic() -> Self {
            Self::build(true)
        }

        pub fn clear_calls(&self) {
            self.calls.borrow_mut().truncate(0);
        }

        pub fn push_call(&self, call: &'static str) {
            self.calls.borrow_mut().push(call);
        }

        pub fn check_calls(&self, expected: Vec<&'static str>) {
            let recorded = self.calls.borrow();
            assert_eq!(*recorded, expected);
        }
    }

    impl ConnectionInitializer for TestConnectionInitializer {
        const NAME: &'static str = "test db";
        const END_VERSION: u32 = 4;

        fn prepare(&self, conn: &Connection) -> Result<()> {
            self.push_call("prep");
            conn.execute_batch(
                "
                PRAGMA journal_mode = wal;
                ",
            )
            .map_err(Error::from)
        }

        fn init(&self, conn: &Transaction<'_>) -> Result<()> {
            self.push_call("init");
            conn.execute_batch(
                "
                CREATE TABLE prep_table(col);
                INSERT INTO prep_table(col) VALUES ('correct-value');
                CREATE TABLE my_table(col);
                ",
            )?;
            Ok(())
        }

        fn upgrade_from(&self, conn: &Transaction<'_>, version: u32) -> Result<()> {
            // Each arm records the call and yields the migration SQL for that version.
            let migration_sql = match version {
                2 => {
                    self.push_call("upgrade_from_v2");
                    "
                        ALTER TABLE my_old_table_name RENAME TO my_table;
                        "
                }
                3 => {
                    self.push_call("upgrade_from_v3");
                    if self.buggy_v3_upgrade {
                        conn.execute_batch("ILLEGAL_SQL_CODE")?;
                    }
                    "
                        ALTER TABLE my_table RENAME COLUMN old_col to col;
                        "
                }
                unexpected => panic!("Unexpected version: {}", unexpected),
            };
            conn.execute_batch(migration_sql)?;
            Ok(())
        }

        fn finish(&self, conn: &Connection) -> Result<()> {
            self.push_call("finish");
            conn.execute_batch(
                "
                INSERT INTO my_table(col) SELECT col FROM prep_table;
                ",
            )
            .map_err(Error::from)
        }
    }

    // Initialize the database to v2 to test upgrading from there
    static INIT_V2: &str = "
        CREATE TABLE prep_table(col);
        INSERT INTO prep_table(col) VALUES ('correct-value');
        CREATE TABLE my_old_table_name(old_col);
        PRAGMA user_version=2;
    ";

    type TestDbFile = MigratedDatabaseFile<TestConnectionInitializer>;

    // Run the full `open_database()` flow against the file behind `db_file`.
    fn open_migrated(db_file: &TestDbFile) -> Result<Connection> {
        open_database(&db_file.path, &db_file.connection_initializer)
    }

    // Add a row to the v2 table so the tests can check that it survives a failed open.
    fn insert_sentinel_row(db_file: &TestDbFile) {
        db_file
            .open()
            .execute(
                "INSERT INTO my_old_table_name(old_col) VALUES ('I should not be deleted')",
                NO_PARAMS,
            )
            .unwrap();
    }

    // Number of rows currently in the v2 table, read through a fresh connection.
    fn count_old_table_rows(db_file: &TestDbFile) -> i32 {
        db_file
            .open()
            .query_one("SELECT COUNT(*) FROM my_old_table_name")
            .unwrap()
    }

    // Check the state every successful open should leave behind: the value copied by
    // `finish()` is in `my_table` and the schema version is 4.
    fn check_final_data(conn: &Connection) {
        let stored: String = conn
            .query_row("SELECT col FROM my_table", NO_PARAMS, |row| row.get(0))
            .unwrap();
        assert_eq!(stored, "correct-value");
        assert_eq!(get_schema_version(conn).unwrap(), 4);
    }

    #[test]
    fn test_init() {
        let initializer = TestConnectionInitializer::new();
        let conn = open_memory_database(&initializer).unwrap();
        check_final_data(&conn);
        initializer.check_calls(vec!["prep", "init", "finish"]);
    }

    #[test]
    fn test_upgrades() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        let conn = open_migrated(&db_file).unwrap();
        check_final_data(&conn);
        let expected_calls = vec!["prep", "upgrade_from_v2", "upgrade_from_v3", "finish"];
        db_file.connection_initializer.check_calls(expected_calls);
    }

    #[test]
    fn test_open_current_version() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        db_file.upgrade_to(4);
        // Forget the calls made by `upgrade_to()`; only the open below is under test.
        db_file.connection_initializer.clear_calls();
        let conn = open_migrated(&db_file).unwrap();
        check_final_data(&conn);
        let expected_calls = vec!["prep", "finish"];
        db_file.connection_initializer.check_calls(expected_calls);
    }

    #[test]
    fn test_pragmas() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        let conn = open_migrated(&db_file).unwrap();
        let journal_mode: String = conn.query_one("PRAGMA journal_mode").unwrap();
        assert_eq!(journal_mode, "wal");
    }

    #[test]
    fn test_migration_error() {
        let db_file =
            MigratedDatabaseFile::new(TestConnectionInitializer::new_with_buggy_logic(), INIT_V2);
        insert_sentinel_row(&db_file);

        open_migrated(&db_file).unwrap_err();
        // Even though the upgrades failed, the data should still be there.  The rename done by
        // the upgrade from v2, which ran before the failing upgrade from v3, should have been
        // rolled back.
        assert_eq!(count_old_table_rows(&db_file), 1);
    }

    #[test]
    fn test_version_too_new() {
        let db_file = MigratedDatabaseFile::new(TestConnectionInitializer::new(), INIT_V2);
        set_schema_version(&db_file.open(), 5).unwrap();
        insert_sentinel_row(&db_file);

        assert!(matches!(
            open_migrated(&db_file),
            Err(Error::IncompatibleVersion(5))
        ));
        // Make sure that refusing to open a database with a newer schema leaves the existing
        // data in the file untouched.
        assert_eq!(count_old_table_rows(&db_file), 1);
    }
}
447