1 //! `feature = "hooks"` Commit, Data Change and Rollback Notification Callbacks
2 #![allow(non_camel_case_types)]
3 
4 use std::os::raw::{c_char, c_int, c_void};
5 use std::panic::catch_unwind;
6 use std::ptr;
7 
8 use crate::ffi;
9 
10 use crate::{Connection, InnerConnection};
11 
12 /// `feature = "hooks"` Action Codes
13 #[derive(Clone, Copy, Debug, PartialEq)]
14 #[repr(i32)]
15 #[non_exhaustive]
16 pub enum Action {
17     UNKNOWN = -1,
18     SQLITE_DELETE = ffi::SQLITE_DELETE,
19     SQLITE_INSERT = ffi::SQLITE_INSERT,
20     SQLITE_UPDATE = ffi::SQLITE_UPDATE,
21 }
22 
23 impl From<i32> for Action {
from(code: i32) -> Action24     fn from(code: i32) -> Action {
25         match code {
26             ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
27             ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
28             ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
29             _ => Action::UNKNOWN,
30         }
31     }
32 }
33 
34 impl Connection {
35     /// `feature = "hooks"` Register a callback function to be invoked whenever
36     /// a transaction is committed.
37     ///
38     /// The callback returns `true` to rollback.
commit_hook<F>(&self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,39     pub fn commit_hook<F>(&self, hook: Option<F>)
40     where
41         F: FnMut() -> bool + Send + 'static,
42     {
43         self.db.borrow_mut().commit_hook(hook);
44     }
45 
46     /// `feature = "hooks"` Register a callback function to be invoked whenever
47     /// a transaction is committed.
48     ///
49     /// The callback returns `true` to rollback.
rollback_hook<F>(&self, hook: Option<F>) where F: FnMut() + Send + 'static,50     pub fn rollback_hook<F>(&self, hook: Option<F>)
51     where
52         F: FnMut() + Send + 'static,
53     {
54         self.db.borrow_mut().rollback_hook(hook);
55     }
56 
57     /// `feature = "hooks"` Register a callback function to be invoked whenever
58     /// a row is updated, inserted or deleted in a rowid table.
59     ///
60     /// The callback parameters are:
61     ///
62     /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or
63     /// SQLITE_DELETE),
64     /// - the name of the database ("main", "temp", ...),
65     /// - the name of the table that is updated,
66     /// - the ROWID of the row that is updated.
update_hook<F>(&self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,67     pub fn update_hook<F>(&self, hook: Option<F>)
68     where
69         F: FnMut(Action, &str, &str, i64) + Send + 'static,
70     {
71         self.db.borrow_mut().update_hook(hook);
72     }
73 }
74 
75 impl InnerConnection {
remove_hooks(&mut self)76     pub fn remove_hooks(&mut self) {
77         self.update_hook(None::<fn(Action, &str, &str, i64)>);
78         self.commit_hook(None::<fn() -> bool>);
79         self.rollback_hook(None::<fn()>);
80     }
81 
commit_hook<F>(&mut self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,82     fn commit_hook<F>(&mut self, hook: Option<F>)
83     where
84         F: FnMut() -> bool + Send + 'static,
85     {
86         unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
87         where
88             F: FnMut() -> bool,
89         {
90             let r = catch_unwind(|| {
91                 let boxed_hook: *mut F = p_arg as *mut F;
92                 (*boxed_hook)()
93             });
94             if let Ok(true) = r {
95                 1
96             } else {
97                 0
98             }
99         }
100 
101         // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with
102         // `sqlite3_commit_hook`. so we keep the `xDestroy` function in
103         // `InnerConnection.free_boxed_hook`.
104         let free_commit_hook = if hook.is_some() {
105             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
106         } else {
107             None
108         };
109 
110         let previous_hook = match hook {
111             Some(hook) => {
112                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
113                 unsafe {
114                     ffi::sqlite3_commit_hook(
115                         self.db(),
116                         Some(call_boxed_closure::<F>),
117                         boxed_hook as *mut _,
118                     )
119                 }
120             }
121             _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) },
122         };
123         if !previous_hook.is_null() {
124             if let Some(free_boxed_hook) = self.free_commit_hook {
125                 unsafe { free_boxed_hook(previous_hook) };
126             }
127         }
128         self.free_commit_hook = free_commit_hook;
129     }
130 
rollback_hook<F>(&mut self, hook: Option<F>) where F: FnMut() + Send + 'static,131     fn rollback_hook<F>(&mut self, hook: Option<F>)
132     where
133         F: FnMut() + Send + 'static,
134     {
135         unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
136         where
137             F: FnMut(),
138         {
139             let _ = catch_unwind(|| {
140                 let boxed_hook: *mut F = p_arg as *mut F;
141                 (*boxed_hook)();
142             });
143         }
144 
145         let free_rollback_hook = if hook.is_some() {
146             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
147         } else {
148             None
149         };
150 
151         let previous_hook = match hook {
152             Some(hook) => {
153                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
154                 unsafe {
155                     ffi::sqlite3_rollback_hook(
156                         self.db(),
157                         Some(call_boxed_closure::<F>),
158                         boxed_hook as *mut _,
159                     )
160                 }
161             }
162             _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) },
163         };
164         if !previous_hook.is_null() {
165             if let Some(free_boxed_hook) = self.free_rollback_hook {
166                 unsafe { free_boxed_hook(previous_hook) };
167             }
168         }
169         self.free_rollback_hook = free_rollback_hook;
170     }
171 
update_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,172     fn update_hook<F>(&mut self, hook: Option<F>)
173     where
174         F: FnMut(Action, &str, &str, i64) + Send + 'static,
175     {
176         unsafe extern "C" fn call_boxed_closure<F>(
177             p_arg: *mut c_void,
178             action_code: c_int,
179             db_str: *const c_char,
180             tbl_str: *const c_char,
181             row_id: i64,
182         ) where
183             F: FnMut(Action, &str, &str, i64),
184         {
185             use std::ffi::CStr;
186             use std::str;
187 
188             let action = Action::from(action_code);
189             let db_name = {
190                 let c_slice = CStr::from_ptr(db_str).to_bytes();
191                 str::from_utf8(c_slice)
192             };
193             let tbl_name = {
194                 let c_slice = CStr::from_ptr(tbl_str).to_bytes();
195                 str::from_utf8(c_slice)
196             };
197 
198             let _ = catch_unwind(|| {
199                 let boxed_hook: *mut F = p_arg as *mut F;
200                 (*boxed_hook)(
201                     action,
202                     db_name.expect("illegal db name"),
203                     tbl_name.expect("illegal table name"),
204                     row_id,
205                 );
206             });
207         }
208 
209         let free_update_hook = if hook.is_some() {
210             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
211         } else {
212             None
213         };
214 
215         let previous_hook = match hook {
216             Some(hook) => {
217                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
218                 unsafe {
219                     ffi::sqlite3_update_hook(
220                         self.db(),
221                         Some(call_boxed_closure::<F>),
222                         boxed_hook as *mut _,
223                     )
224                 }
225             }
226             _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) },
227         };
228         if !previous_hook.is_null() {
229             if let Some(free_boxed_hook) = self.free_update_hook {
230                 unsafe { free_boxed_hook(previous_hook) };
231             }
232         }
233         self.free_update_hook = free_update_hook;
234     }
235 }
236 
free_boxed_hook<F>(p: *mut c_void)237 unsafe fn free_boxed_hook<F>(p: *mut c_void) {
238     drop(Box::from_raw(p as *mut F));
239 }
240 
#[cfg(test)]
mod test {
    use super::Action;
    use crate::Connection;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // `AtomicBool::new` is a `const fn`, so each flag below is a plain
    // function-local `static`; no lazy initialisation is required.

    #[test]
    fn test_commit_hook() {
        let db = Connection::open_in_memory().unwrap();

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    #[test]
    fn test_fn_commit_hook() {
        let db = Connection::open_in_memory().unwrap();

        // Returning `true` turns the COMMIT into a ROLLBACK, so the batch fails.
        fn hook() -> bool {
            true
        }

        db.commit_hook(Some(hook));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
    }

    #[test]
    fn test_rollback_hook() {
        let db = Connection::open_in_memory().unwrap();

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")
            .unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    #[test]
    fn test_update_hook() {
        let db = Connection::open_in_memory().unwrap();

        // Assertion failures inside the hook are swallowed by `catch_unwind`,
        // so the flag is only set once every assertion has passed.
        static CALLED: AtomicBool = AtomicBool::new(false);
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)").unwrap();
        db.execute_batch("INSERT INTO foo VALUES ('lisa')").unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    #[test]
    fn test_clearing_hook_drops_previous_closure() {
        let db = Connection::open_in_memory().unwrap();

        let marker = Arc::new(());
        let captured = Arc::clone(&marker);
        db.commit_hook(Some(move || {
            let _keep_alive = &captured;
            false
        }));
        assert_eq!(2, Arc::strong_count(&marker));

        // Clearing the hook must release the boxed closure and its captures.
        db.commit_hook(None::<fn() -> bool>);
        assert_eq!(1, Arc::strong_count(&marker));
    }
}
311