1 use crate::{Connection, Result};
2 use std::ops::Deref;
3 
/// Options for transaction behavior. See [BEGIN
/// TRANSACTION](https://www.sqlite.org/lang_transaction.html) for details.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TransactionBehavior {
    /// Maps to `BEGIN DEFERRED`: per the SQLite documentation, the
    /// transaction does not acquire any lock until the database is first
    /// accessed.
    Deferred,
    /// Maps to `BEGIN IMMEDIATE`: per the SQLite documentation, a write
    /// transaction is started immediately, without waiting for the first
    /// write statement.
    Immediate,
    /// Maps to `BEGIN EXCLUSIVE`: per the SQLite documentation, outside of
    /// WAL mode this also prevents other connections from reading the
    /// database while the transaction is underway.
    Exclusive,
}
12 
/// What a `Transaction` or `Savepoint` does with its pending changes when it
/// goes out of scope without having been finished explicitly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DropBehavior {
    /// Discard the pending changes. Newly created transactions and savepoints
    /// use this behavior.
    Rollback,

    /// Persist the pending changes.
    Commit,

    /// Neither commit nor roll back. The transaction or savepoint is left
    /// open on the connection, so use this deliberately.
    Ignore,

    /// Panic when dropped. Useful during development to catch transactions
    /// or savepoints that were not finished on purpose.
    Panic,
}
29 
30 /// Represents a transaction on a database connection.
31 ///
32 /// ## Note
33 ///
34 /// Transactions will roll back by default. Use `commit` method to explicitly
35 /// commit the transaction, or use `set_drop_behavior` to change what happens
36 /// when the transaction is dropped.
37 ///
38 /// ## Example
39 ///
40 /// ```rust,no_run
41 /// # use rusqlite::{Connection, Result};
42 /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) }
43 /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) }
44 /// fn perform_queries(conn: &mut Connection) -> Result<()> {
45 ///     let tx = conn.transaction()?;
46 ///
47 ///     do_queries_part_1(&tx)?; // tx causes rollback if this fails
48 ///     do_queries_part_2(&tx)?; // tx causes rollback if this fails
49 ///
50 ///     tx.commit()
51 /// }
52 /// ```
53 #[derive(Debug)]
54 pub struct Transaction<'conn> {
55     conn: &'conn Connection,
56     drop_behavior: DropBehavior,
57 }
58 
59 /// Represents a savepoint on a database connection.
60 ///
61 /// ## Note
62 ///
63 /// Savepoints will roll back by default. Use `commit` method to explicitly
64 /// commit the savepoint, or use `set_drop_behavior` to change what happens
65 /// when the savepoint is dropped.
66 ///
67 /// ## Example
68 ///
69 /// ```rust,no_run
70 /// # use rusqlite::{Connection, Result};
71 /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) }
72 /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) }
73 /// fn perform_queries(conn: &mut Connection) -> Result<()> {
74 ///     let sp = conn.savepoint()?;
75 ///
76 ///     do_queries_part_1(&sp)?; // sp causes rollback if this fails
77 ///     do_queries_part_2(&sp)?; // sp causes rollback if this fails
78 ///
79 ///     sp.commit()
80 /// }
81 /// ```
82 pub struct Savepoint<'conn> {
83     conn: &'conn Connection,
84     name: String,
85     depth: u32,
86     drop_behavior: DropBehavior,
87     committed: bool,
88 }
89 
90 impl Transaction<'_> {
91     /// Begin a new transaction. Cannot be nested; see `savepoint` for nested
92     /// transactions.
93     /// Even though we don't mutate the connection, we take a `&mut Connection`
94     /// so as to prevent nested or concurrent transactions on the same
95     /// connection.
new(conn: &mut Connection, behavior: TransactionBehavior) -> Result<Transaction<'_>>96     pub fn new(conn: &mut Connection, behavior: TransactionBehavior) -> Result<Transaction<'_>> {
97         let query = match behavior {
98             TransactionBehavior::Deferred => "BEGIN DEFERRED",
99             TransactionBehavior::Immediate => "BEGIN IMMEDIATE",
100             TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE",
101         };
102         conn.execute_batch(query).map(move |_| Transaction {
103             conn,
104             drop_behavior: DropBehavior::Rollback,
105         })
106     }
107 
108     /// Starts a new [savepoint](http://www.sqlite.org/lang_savepoint.html), allowing nested
109     /// transactions.
110     ///
111     /// ## Note
112     ///
113     /// Just like outer level transactions, savepoint transactions rollback by
114     /// default.
115     ///
116     /// ## Example
117     ///
118     /// ```rust,no_run
119     /// # use rusqlite::{Connection, Result};
120     /// # fn perform_queries_part_1_succeeds(_conn: &Connection) -> bool { true }
121     /// fn perform_queries(conn: &mut Connection) -> Result<()> {
122     ///     let mut tx = conn.transaction()?;
123     ///
124     ///     {
125     ///         let sp = tx.savepoint()?;
126     ///         if perform_queries_part_1_succeeds(&sp) {
127     ///             sp.commit()?;
128     ///         }
129     ///         // otherwise, sp will rollback
130     ///     }
131     ///
132     ///     tx.commit()
133     /// }
134     /// ```
savepoint(&mut self) -> Result<Savepoint<'_>>135     pub fn savepoint(&mut self) -> Result<Savepoint<'_>> {
136         Savepoint::with_depth(self.conn, 1)
137     }
138 
139     /// Create a new savepoint with a custom savepoint name. See `savepoint()`.
savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>>140     pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> {
141         Savepoint::with_depth_and_name(self.conn, 1, name)
142     }
143 
144     /// Get the current setting for what happens to the transaction when it is
145     /// dropped.
drop_behavior(&self) -> DropBehavior146     pub fn drop_behavior(&self) -> DropBehavior {
147         self.drop_behavior
148     }
149 
150     /// Configure the transaction to perform the specified action when it is
151     /// dropped.
set_drop_behavior(&mut self, drop_behavior: DropBehavior)152     pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) {
153         self.drop_behavior = drop_behavior
154     }
155 
156     /// A convenience method which consumes and commits a transaction.
commit(mut self) -> Result<()>157     pub fn commit(mut self) -> Result<()> {
158         self.commit_()
159     }
160 
commit_(&mut self) -> Result<()>161     fn commit_(&mut self) -> Result<()> {
162         self.conn.execute_batch("COMMIT")?;
163         Ok(())
164     }
165 
166     /// A convenience method which consumes and rolls back a transaction.
rollback(mut self) -> Result<()>167     pub fn rollback(mut self) -> Result<()> {
168         self.rollback_()
169     }
170 
rollback_(&mut self) -> Result<()>171     fn rollback_(&mut self) -> Result<()> {
172         self.conn.execute_batch("ROLLBACK")?;
173         Ok(())
174     }
175 
176     /// Consumes the transaction, committing or rolling back according to the
177     /// current setting (see `drop_behavior`).
178     ///
179     /// Functionally equivalent to the `Drop` implementation, but allows
180     /// callers to see any errors that occur.
finish(mut self) -> Result<()>181     pub fn finish(mut self) -> Result<()> {
182         self.finish_()
183     }
184 
finish_(&mut self) -> Result<()>185     fn finish_(&mut self) -> Result<()> {
186         if self.conn.is_autocommit() {
187             return Ok(());
188         }
189         match self.drop_behavior() {
190             DropBehavior::Commit => self.commit_().or_else(|_| self.rollback_()),
191             DropBehavior::Rollback => self.rollback_(),
192             DropBehavior::Ignore => Ok(()),
193             DropBehavior::Panic => panic!("Transaction dropped unexpectedly."),
194         }
195     }
196 }
197 
198 impl Deref for Transaction<'_> {
199     type Target = Connection;
200 
deref(&self) -> &Connection201     fn deref(&self) -> &Connection {
202         self.conn
203     }
204 }
205 
206 #[allow(unused_must_use)]
207 impl Drop for Transaction<'_> {
drop(&mut self)208     fn drop(&mut self) {
209         self.finish_();
210     }
211 }
212 
213 impl Savepoint<'_> {
with_depth_and_name<T: Into<String>>( conn: &Connection, depth: u32, name: T, ) -> Result<Savepoint<'_>>214     fn with_depth_and_name<T: Into<String>>(
215         conn: &Connection,
216         depth: u32,
217         name: T,
218     ) -> Result<Savepoint<'_>> {
219         let name = name.into();
220         conn.execute_batch(&format!("SAVEPOINT {}", name))
221             .map(|_| Savepoint {
222                 conn,
223                 name,
224                 depth,
225                 drop_behavior: DropBehavior::Rollback,
226                 committed: false,
227             })
228     }
229 
with_depth(conn: &Connection, depth: u32) -> Result<Savepoint<'_>>230     fn with_depth(conn: &Connection, depth: u32) -> Result<Savepoint<'_>> {
231         let name = format!("_rusqlite_sp_{}", depth);
232         Savepoint::with_depth_and_name(conn, depth, name)
233     }
234 
235     /// Begin a new savepoint. Can be nested.
new(conn: &mut Connection) -> Result<Savepoint<'_>>236     pub fn new(conn: &mut Connection) -> Result<Savepoint<'_>> {
237         Savepoint::with_depth(conn, 0)
238     }
239 
240     /// Begin a new savepoint with a user-provided savepoint name.
with_name<T: Into<String>>(conn: &mut Connection, name: T) -> Result<Savepoint<'_>>241     pub fn with_name<T: Into<String>>(conn: &mut Connection, name: T) -> Result<Savepoint<'_>> {
242         Savepoint::with_depth_and_name(conn, 0, name)
243     }
244 
245     /// Begin a nested savepoint.
savepoint(&mut self) -> Result<Savepoint<'_>>246     pub fn savepoint(&mut self) -> Result<Savepoint<'_>> {
247         Savepoint::with_depth(self.conn, self.depth + 1)
248     }
249 
250     /// Begin a nested savepoint with a user-provided savepoint name.
savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>>251     pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> {
252         Savepoint::with_depth_and_name(self.conn, self.depth + 1, name)
253     }
254 
255     /// Get the current setting for what happens to the savepoint when it is
256     /// dropped.
drop_behavior(&self) -> DropBehavior257     pub fn drop_behavior(&self) -> DropBehavior {
258         self.drop_behavior
259     }
260 
261     /// Configure the savepoint to perform the specified action when it is
262     /// dropped.
set_drop_behavior(&mut self, drop_behavior: DropBehavior)263     pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) {
264         self.drop_behavior = drop_behavior
265     }
266 
267     /// A convenience method which consumes and commits a savepoint.
commit(mut self) -> Result<()>268     pub fn commit(mut self) -> Result<()> {
269         self.commit_()
270     }
271 
commit_(&mut self) -> Result<()>272     fn commit_(&mut self) -> Result<()> {
273         self.conn.execute_batch(&format!("RELEASE {}", self.name))?;
274         self.committed = true;
275         Ok(())
276     }
277 
278     /// A convenience method which rolls back a savepoint.
279     ///
280     /// ## Note
281     ///
282     /// Unlike `Transaction`s, savepoints remain active after they have been
283     /// rolled back, and can be rolled back again or committed.
rollback(&mut self) -> Result<()>284     pub fn rollback(&mut self) -> Result<()> {
285         self.conn
286             .execute_batch(&format!("ROLLBACK TO {}", self.name))
287     }
288 
289     /// Consumes the savepoint, committing or rolling back according to the
290     /// current setting (see `drop_behavior`).
291     ///
292     /// Functionally equivalent to the `Drop` implementation, but allows
293     /// callers to see any errors that occur.
finish(mut self) -> Result<()>294     pub fn finish(mut self) -> Result<()> {
295         self.finish_()
296     }
297 
finish_(&mut self) -> Result<()>298     fn finish_(&mut self) -> Result<()> {
299         if self.committed {
300             return Ok(());
301         }
302         match self.drop_behavior() {
303             DropBehavior::Commit => self.commit_().or_else(|_| self.rollback()),
304             DropBehavior::Rollback => self.rollback(),
305             DropBehavior::Ignore => Ok(()),
306             DropBehavior::Panic => panic!("Savepoint dropped unexpectedly."),
307         }
308     }
309 }
310 
311 impl Deref for Savepoint<'_> {
312     type Target = Connection;
313 
deref(&self) -> &Connection314     fn deref(&self) -> &Connection {
315         self.conn
316     }
317 }
318 
319 #[allow(unused_must_use)]
320 impl Drop for Savepoint<'_> {
drop(&mut self)321     fn drop(&mut self) {
322         self.finish_();
323     }
324 }
325 
326 impl Connection {
327     /// Begin a new transaction with the default behavior (DEFERRED).
328     ///
329     /// The transaction defaults to rolling back when it is dropped. If you
330     /// want the transaction to commit, you must call `commit` or
331     /// `set_drop_behavior(DropBehavior::Commit)`.
332     ///
333     /// ## Example
334     ///
335     /// ```rust,no_run
336     /// # use rusqlite::{Connection, Result};
337     /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) }
338     /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) }
339     /// fn perform_queries(conn: &mut Connection) -> Result<()> {
340     ///     let tx = conn.transaction()?;
341     ///
342     ///     do_queries_part_1(&tx)?; // tx causes rollback if this fails
343     ///     do_queries_part_2(&tx)?; // tx causes rollback if this fails
344     ///
345     ///     tx.commit()
346     /// }
347     /// ```
348     ///
349     /// # Failure
350     ///
351     /// Will return `Err` if the underlying SQLite call fails.
transaction(&mut self) -> Result<Transaction<'_>>352     pub fn transaction(&mut self) -> Result<Transaction<'_>> {
353         Transaction::new(self, TransactionBehavior::Deferred)
354     }
355 
356     /// Begin a new transaction with a specified behavior.
357     ///
358     /// See `transaction`.
359     ///
360     /// # Failure
361     ///
362     /// Will return `Err` if the underlying SQLite call fails.
transaction_with_behavior( &mut self, behavior: TransactionBehavior, ) -> Result<Transaction<'_>>363     pub fn transaction_with_behavior(
364         &mut self,
365         behavior: TransactionBehavior,
366     ) -> Result<Transaction<'_>> {
367         Transaction::new(self, behavior)
368     }
369 
370     /// Begin a new savepoint with the default behavior (DEFERRED).
371     ///
372     /// The savepoint defaults to rolling back when it is dropped. If you want
373     /// the savepoint to commit, you must call `commit` or
374     /// `set_drop_behavior(DropBehavior::Commit)`.
375     ///
376     /// ## Example
377     ///
378     /// ```rust,no_run
379     /// # use rusqlite::{Connection, Result};
380     /// # fn do_queries_part_1(_conn: &Connection) -> Result<()> { Ok(()) }
381     /// # fn do_queries_part_2(_conn: &Connection) -> Result<()> { Ok(()) }
382     /// fn perform_queries(conn: &mut Connection) -> Result<()> {
383     ///     let sp = conn.savepoint()?;
384     ///
385     ///     do_queries_part_1(&sp)?; // sp causes rollback if this fails
386     ///     do_queries_part_2(&sp)?; // sp causes rollback if this fails
387     ///
388     ///     sp.commit()
389     /// }
390     /// ```
391     ///
392     /// # Failure
393     ///
394     /// Will return `Err` if the underlying SQLite call fails.
savepoint(&mut self) -> Result<Savepoint<'_>>395     pub fn savepoint(&mut self) -> Result<Savepoint<'_>> {
396         Savepoint::new(self)
397     }
398 
399     /// Begin a new savepoint with a specified name.
400     ///
401     /// See `savepoint`.
402     ///
403     /// # Failure
404     ///
405     /// Will return `Err` if the underlying SQLite call fails.
savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>>406     pub fn savepoint_with_name<T: Into<String>>(&mut self, name: T) -> Result<Savepoint<'_>> {
407         Savepoint::with_name(self, name)
408     }
409 }
410 
// Unit tests for transaction and savepoint semantics. Each test runs against
// a fresh in-memory database with a single `foo (x INTEGER)` table, and checks
// results through `SUM(x)`.
#[cfg(test)]
mod test {
    use super::DropBehavior;
    use crate::{Connection, NO_PARAMS};

    /// Opens an in-memory database containing an empty `foo (x INTEGER)` table.
    fn checked_memory_handle() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch("CREATE TABLE foo (x INTEGER)").unwrap();
        db
    }

    /// A transaction rolls back on drop by default. It commits on drop once
    /// its drop behavior is set to `Commit`.
    #[test]
    fn test_drop() {
        let mut db = checked_memory_handle();
        {
            let tx = db.transaction().unwrap();
            tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap();
            // default: rollback
        }
        {
            let mut tx = db.transaction().unwrap();
            tx.execute_batch("INSERT INTO foo VALUES(2)").unwrap();
            // Leaving this scope drops `tx`, which now commits.
            tx.set_drop_behavior(DropBehavior::Commit)
        }
        {
            let tx = db.transaction().unwrap();
            // Only the insert from the committed transaction (2) remains.
            assert_eq!(
                2i32,
                tx.query_row::<i32, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0))
                    .unwrap()
            );
        }
    }

    /// An explicit savepoint rollback discards only the work done since the
    /// savepoint, and the savepoint can still be committed afterwards.
    #[test]
    fn test_explicit_rollback_commit() {
        let mut db = checked_memory_handle();
        {
            let mut tx = db.transaction().unwrap();
            {
                let mut sp = tx.savepoint().unwrap();
                sp.execute_batch("INSERT INTO foo VALUES(1)").unwrap();
                // Discards the insert of 1; `sp` stays usable.
                sp.rollback().unwrap();
                sp.execute_batch("INSERT INTO foo VALUES(2)").unwrap();
                sp.commit().unwrap();
            }
            tx.commit().unwrap();
        }
        {
            let tx = db.transaction().unwrap();
            tx.execute_batch("INSERT INTO foo VALUES(4)").unwrap();
            tx.commit().unwrap();
        }
        {
            let tx = db.transaction().unwrap();
            // 2 (from the savepoint) + 4 (second transaction).
            assert_eq!(
                6i32,
                tx.query_row::<i32, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0))
                    .unwrap()
            );
        }
    }

    /// Nested savepoints: rolling back an outer savepoint on drop also erases
    /// work committed by savepoints nested inside it.
    #[test]
    fn test_savepoint() {
        let mut db = checked_memory_handle();
        {
            let mut tx = db.transaction().unwrap();
            tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap();
            assert_current_sum(1, &tx);
            tx.set_drop_behavior(DropBehavior::Commit);
            {
                let mut sp1 = tx.savepoint().unwrap();
                sp1.execute_batch("INSERT INTO foo VALUES(2)").unwrap();
                assert_current_sum(3, &sp1);
                // will rollback sp1 (default drop behavior)
                {
                    let mut sp2 = sp1.savepoint().unwrap();
                    sp2.execute_batch("INSERT INTO foo VALUES(4)").unwrap();
                    assert_current_sum(7, &sp2);
                    // will rollback sp2 (default drop behavior)
                    {
                        let sp3 = sp2.savepoint().unwrap();
                        sp3.execute_batch("INSERT INTO foo VALUES(8)").unwrap();
                        assert_current_sum(15, &sp3);
                        sp3.commit().unwrap();
                        // committed sp3, but will be erased by sp2 rollback
                    }
                    assert_current_sum(15, &sp2);
                }
                assert_current_sum(3, &sp1);
            }
            assert_current_sum(1, &tx);
        }
        assert_current_sum(1, &db);
    }

    /// A savepoint with `DropBehavior::Ignore` keeps its changes pending after
    /// drop, so they are kept when the enclosing savepoint commits.
    #[test]
    fn test_ignore_drop_behavior() {
        let mut db = checked_memory_handle();

        let mut tx = db.transaction().unwrap();
        {
            let mut sp1 = tx.savepoint().unwrap();
            insert(1, &sp1);
            sp1.rollback().unwrap();
            insert(2, &sp1);
            {
                let mut sp2 = sp1.savepoint().unwrap();
                sp2.set_drop_behavior(DropBehavior::Ignore);
                insert(4, &sp2);
            }
            assert_current_sum(6, &sp1);
            sp1.commit().unwrap();
        }
        assert_current_sum(6, &tx);
    }

    /// User-named savepoints, including nested savepoints that reuse the same
    /// name. SQLite resolves a reused name to the most recent savepoint.
    #[test]
    fn test_savepoint_names() {
        let mut db = checked_memory_handle();

        {
            let mut sp1 = db.savepoint_with_name("my_sp").unwrap();
            insert(1, &sp1);
            assert_current_sum(1, &sp1);
            {
                let mut sp2 = sp1.savepoint_with_name("my_sp").unwrap();
                sp2.set_drop_behavior(DropBehavior::Commit);
                insert(2, &sp2);
                assert_current_sum(3, &sp2);
                sp2.rollback().unwrap();
                assert_current_sum(1, &sp2);
                insert(4, &sp2);
            }
            assert_current_sum(5, &sp1);
            sp1.rollback().unwrap();
            {
                let mut sp2 = sp1.savepoint_with_name("my_sp").unwrap();
                sp2.set_drop_behavior(DropBehavior::Ignore);
                insert(8, &sp2);
            }
            assert_current_sum(8, &sp1);
            sp1.commit().unwrap();
        }
        assert_current_sum(8, &db);
    }

    /// Compile-time check that a transaction can be wrapped in `Rc`.
    #[test]
    fn test_rc() {
        use std::rc::Rc;
        let mut conn = Connection::open_in_memory().unwrap();
        let rc_txn = Rc::new(conn.transaction().unwrap());

        // This will compile only if Transaction is Debug
        Rc::try_unwrap(rc_txn).unwrap();
    }

    /// Inserts `x` into `foo`, panicking on failure.
    fn insert(x: i32, conn: &Connection) {
        conn.execute("INSERT INTO foo VALUES(?)", &[x]).unwrap();
    }

    /// Asserts that `SUM(x)` over `foo`, as seen by `conn`, equals `x`.
    fn assert_current_sum(x: i32, conn: &Connection) {
        let i = conn
            .query_row::<i32, _, _>("SELECT SUM(x) FROM foo", NO_PARAMS, |r| r.get(0))
            .unwrap();
        assert_eq!(x, i);
    }
}
580