1 //! `feature = "hooks"` Commit, Data Change and Rollback Notification Callbacks
2 #![allow(non_camel_case_types)]
3 
4 use std::os::raw::{c_char, c_int, c_void};
5 use std::panic::catch_unwind;
6 use std::ptr;
7 
8 use crate::ffi;
9 
10 use crate::{Connection, InnerConnection};
11 
12 /// `feature = "hooks"` Action Codes
13 #[derive(Clone, Copy, Debug, PartialEq)]
14 #[repr(i32)]
15 #[non_exhaustive]
16 pub enum Action {
17     /// Unsupported / unexpected action
18     UNKNOWN = -1,
19     /// DELETE command
20     SQLITE_DELETE = ffi::SQLITE_DELETE,
21     /// INSERT command
22     SQLITE_INSERT = ffi::SQLITE_INSERT,
23     /// UPDATE command
24     SQLITE_UPDATE = ffi::SQLITE_UPDATE,
25 }
26 
27 impl From<i32> for Action {
from(code: i32) -> Action28     fn from(code: i32) -> Action {
29         match code {
30             ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
31             ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
32             ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
33             _ => Action::UNKNOWN,
34         }
35     }
36 }
37 
38 impl Connection {
39     /// `feature = "hooks"` Register a callback function to be invoked whenever
40     /// a transaction is committed.
41     ///
42     /// The callback returns `true` to rollback.
commit_hook<F>(&self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,43     pub fn commit_hook<F>(&self, hook: Option<F>)
44     where
45         F: FnMut() -> bool + Send + 'static,
46     {
47         self.db.borrow_mut().commit_hook(hook);
48     }
49 
50     /// `feature = "hooks"` Register a callback function to be invoked whenever
51     /// a transaction is committed.
52     ///
53     /// The callback returns `true` to rollback.
rollback_hook<F>(&self, hook: Option<F>) where F: FnMut() + Send + 'static,54     pub fn rollback_hook<F>(&self, hook: Option<F>)
55     where
56         F: FnMut() + Send + 'static,
57     {
58         self.db.borrow_mut().rollback_hook(hook);
59     }
60 
61     /// `feature = "hooks"` Register a callback function to be invoked whenever
62     /// a row is updated, inserted or deleted in a rowid table.
63     ///
64     /// The callback parameters are:
65     ///
66     /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or
67     /// SQLITE_DELETE),
68     /// - the name of the database ("main", "temp", ...),
69     /// - the name of the table that is updated,
70     /// - the ROWID of the row that is updated.
update_hook<F>(&self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,71     pub fn update_hook<F>(&self, hook: Option<F>)
72     where
73         F: FnMut(Action, &str, &str, i64) + Send + 'static,
74     {
75         self.db.borrow_mut().update_hook(hook);
76     }
77 }
78 
79 impl InnerConnection {
remove_hooks(&mut self)80     pub fn remove_hooks(&mut self) {
81         self.update_hook(None::<fn(Action, &str, &str, i64)>);
82         self.commit_hook(None::<fn() -> bool>);
83         self.rollback_hook(None::<fn()>);
84     }
85 
commit_hook<F>(&mut self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,86     fn commit_hook<F>(&mut self, hook: Option<F>)
87     where
88         F: FnMut() -> bool + Send + 'static,
89     {
90         unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
91         where
92             F: FnMut() -> bool,
93         {
94             let r = catch_unwind(|| {
95                 let boxed_hook: *mut F = p_arg as *mut F;
96                 (*boxed_hook)()
97             });
98             if let Ok(true) = r {
99                 1
100             } else {
101                 0
102             }
103         }
104 
105         // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with
106         // `sqlite3_commit_hook`. so we keep the `xDestroy` function in
107         // `InnerConnection.free_boxed_hook`.
108         let free_commit_hook = if hook.is_some() {
109             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
110         } else {
111             None
112         };
113 
114         let previous_hook = match hook {
115             Some(hook) => {
116                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
117                 unsafe {
118                     ffi::sqlite3_commit_hook(
119                         self.db(),
120                         Some(call_boxed_closure::<F>),
121                         boxed_hook as *mut _,
122                     )
123                 }
124             }
125             _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) },
126         };
127         if !previous_hook.is_null() {
128             if let Some(free_boxed_hook) = self.free_commit_hook {
129                 unsafe { free_boxed_hook(previous_hook) };
130             }
131         }
132         self.free_commit_hook = free_commit_hook;
133     }
134 
rollback_hook<F>(&mut self, hook: Option<F>) where F: FnMut() + Send + 'static,135     fn rollback_hook<F>(&mut self, hook: Option<F>)
136     where
137         F: FnMut() + Send + 'static,
138     {
139         unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
140         where
141             F: FnMut(),
142         {
143             let _ = catch_unwind(|| {
144                 let boxed_hook: *mut F = p_arg as *mut F;
145                 (*boxed_hook)();
146             });
147         }
148 
149         let free_rollback_hook = if hook.is_some() {
150             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
151         } else {
152             None
153         };
154 
155         let previous_hook = match hook {
156             Some(hook) => {
157                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
158                 unsafe {
159                     ffi::sqlite3_rollback_hook(
160                         self.db(),
161                         Some(call_boxed_closure::<F>),
162                         boxed_hook as *mut _,
163                     )
164                 }
165             }
166             _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) },
167         };
168         if !previous_hook.is_null() {
169             if let Some(free_boxed_hook) = self.free_rollback_hook {
170                 unsafe { free_boxed_hook(previous_hook) };
171             }
172         }
173         self.free_rollback_hook = free_rollback_hook;
174     }
175 
update_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,176     fn update_hook<F>(&mut self, hook: Option<F>)
177     where
178         F: FnMut(Action, &str, &str, i64) + Send + 'static,
179     {
180         unsafe extern "C" fn call_boxed_closure<F>(
181             p_arg: *mut c_void,
182             action_code: c_int,
183             db_str: *const c_char,
184             tbl_str: *const c_char,
185             row_id: i64,
186         ) where
187             F: FnMut(Action, &str, &str, i64),
188         {
189             use std::ffi::CStr;
190             use std::str;
191 
192             let action = Action::from(action_code);
193             let db_name = {
194                 let c_slice = CStr::from_ptr(db_str).to_bytes();
195                 str::from_utf8(c_slice)
196             };
197             let tbl_name = {
198                 let c_slice = CStr::from_ptr(tbl_str).to_bytes();
199                 str::from_utf8(c_slice)
200             };
201 
202             let _ = catch_unwind(|| {
203                 let boxed_hook: *mut F = p_arg as *mut F;
204                 (*boxed_hook)(
205                     action,
206                     db_name.expect("illegal db name"),
207                     tbl_name.expect("illegal table name"),
208                     row_id,
209                 );
210             });
211         }
212 
213         let free_update_hook = if hook.is_some() {
214             Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
215         } else {
216             None
217         };
218 
219         let previous_hook = match hook {
220             Some(hook) => {
221                 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
222                 unsafe {
223                     ffi::sqlite3_update_hook(
224                         self.db(),
225                         Some(call_boxed_closure::<F>),
226                         boxed_hook as *mut _,
227                     )
228                 }
229             }
230             _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) },
231         };
232         if !previous_hook.is_null() {
233             if let Some(free_boxed_hook) = self.free_update_hook {
234                 unsafe { free_boxed_hook(previous_hook) };
235             }
236         }
237         self.free_update_hook = free_update_hook;
238     }
239 }
240 
free_boxed_hook<F>(p: *mut c_void)241 unsafe fn free_boxed_hook<F>(p: *mut c_void) {
242     drop(Box::from_raw(p as *mut F));
243 }
244 
#[cfg(test)]
mod test {
    use super::Action;
    use crate::Connection;
    use std::sync::atomic::{AtomicBool, Ordering};

    #[test]
    fn test_commit_hook() {
        let db = Connection::open_in_memory().unwrap();

        // `AtomicBool::new` is `const`, so a plain `static` is enough here.
        static CALLED: AtomicBool = AtomicBool::new(false);
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    #[test]
    fn test_fn_commit_hook() {
        let db = Connection::open_in_memory().unwrap();

        fn hook() -> bool {
            true
        }

        db.commit_hook(Some(hook));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
    }

    #[test]
    fn test_remove_commit_hook() {
        let db = Connection::open_in_memory().unwrap();

        // A hook that would veto every commit, removed before any transaction.
        db.commit_hook(Some(|| true));
        db.commit_hook(None::<fn() -> bool>);
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap();
    }

    #[test]
    fn test_rollback_hook() {
        let db = Connection::open_in_memory().unwrap();

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")
            .unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    #[test]
    fn test_update_hook() {
        let db = Connection::open_in_memory().unwrap();

        static CALLED: AtomicBool = AtomicBool::new(false);
        // Assertions inside the hook panic before `CALLED` is set, so a
        // mismatch surfaces through the final assertion below.
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)").unwrap();
        db.execute_batch("INSERT INTO foo VALUES ('lisa')").unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }
}
315