1 //! `feature = "hooks"` Commit, Data Change and Rollback Notification Callbacks
2 #![allow(non_camel_case_types)]
3
4 use std::os::raw::{c_char, c_int, c_void};
5 use std::panic::catch_unwind;
6 use std::ptr;
7
8 use crate::ffi;
9
10 use crate::{Connection, InnerConnection};
11
12 /// `feature = "hooks"` Action Codes
13 #[derive(Clone, Copy, Debug, PartialEq)]
14 #[repr(i32)]
15 #[non_exhaustive]
16 pub enum Action {
17 /// Unsupported / unexpected action
18 UNKNOWN = -1,
19 /// DELETE command
20 SQLITE_DELETE = ffi::SQLITE_DELETE,
21 /// INSERT command
22 SQLITE_INSERT = ffi::SQLITE_INSERT,
23 /// UPDATE command
24 SQLITE_UPDATE = ffi::SQLITE_UPDATE,
25 }
26
27 impl From<i32> for Action {
from(code: i32) -> Action28 fn from(code: i32) -> Action {
29 match code {
30 ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
31 ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
32 ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
33 _ => Action::UNKNOWN,
34 }
35 }
36 }
37
38 impl Connection {
39 /// `feature = "hooks"` Register a callback function to be invoked whenever
40 /// a transaction is committed.
41 ///
42 /// The callback returns `true` to rollback.
commit_hook<F>(&self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,43 pub fn commit_hook<F>(&self, hook: Option<F>)
44 where
45 F: FnMut() -> bool + Send + 'static,
46 {
47 self.db.borrow_mut().commit_hook(hook);
48 }
49
50 /// `feature = "hooks"` Register a callback function to be invoked whenever
51 /// a transaction is committed.
52 ///
53 /// The callback returns `true` to rollback.
rollback_hook<F>(&self, hook: Option<F>) where F: FnMut() + Send + 'static,54 pub fn rollback_hook<F>(&self, hook: Option<F>)
55 where
56 F: FnMut() + Send + 'static,
57 {
58 self.db.borrow_mut().rollback_hook(hook);
59 }
60
61 /// `feature = "hooks"` Register a callback function to be invoked whenever
62 /// a row is updated, inserted or deleted in a rowid table.
63 ///
64 /// The callback parameters are:
65 ///
66 /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or
67 /// SQLITE_DELETE),
68 /// - the name of the database ("main", "temp", ...),
69 /// - the name of the table that is updated,
70 /// - the ROWID of the row that is updated.
update_hook<F>(&self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,71 pub fn update_hook<F>(&self, hook: Option<F>)
72 where
73 F: FnMut(Action, &str, &str, i64) + Send + 'static,
74 {
75 self.db.borrow_mut().update_hook(hook);
76 }
77 }
78
79 impl InnerConnection {
remove_hooks(&mut self)80 pub fn remove_hooks(&mut self) {
81 self.update_hook(None::<fn(Action, &str, &str, i64)>);
82 self.commit_hook(None::<fn() -> bool>);
83 self.rollback_hook(None::<fn()>);
84 }
85
commit_hook<F>(&mut self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,86 fn commit_hook<F>(&mut self, hook: Option<F>)
87 where
88 F: FnMut() -> bool + Send + 'static,
89 {
90 unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
91 where
92 F: FnMut() -> bool,
93 {
94 let r = catch_unwind(|| {
95 let boxed_hook: *mut F = p_arg as *mut F;
96 (*boxed_hook)()
97 });
98 if let Ok(true) = r {
99 1
100 } else {
101 0
102 }
103 }
104
105 // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with
106 // `sqlite3_commit_hook`. so we keep the `xDestroy` function in
107 // `InnerConnection.free_boxed_hook`.
108 let free_commit_hook = if hook.is_some() {
109 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
110 } else {
111 None
112 };
113
114 let previous_hook = match hook {
115 Some(hook) => {
116 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
117 unsafe {
118 ffi::sqlite3_commit_hook(
119 self.db(),
120 Some(call_boxed_closure::<F>),
121 boxed_hook as *mut _,
122 )
123 }
124 }
125 _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) },
126 };
127 if !previous_hook.is_null() {
128 if let Some(free_boxed_hook) = self.free_commit_hook {
129 unsafe { free_boxed_hook(previous_hook) };
130 }
131 }
132 self.free_commit_hook = free_commit_hook;
133 }
134
rollback_hook<F>(&mut self, hook: Option<F>) where F: FnMut() + Send + 'static,135 fn rollback_hook<F>(&mut self, hook: Option<F>)
136 where
137 F: FnMut() + Send + 'static,
138 {
139 unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
140 where
141 F: FnMut(),
142 {
143 let _ = catch_unwind(|| {
144 let boxed_hook: *mut F = p_arg as *mut F;
145 (*boxed_hook)();
146 });
147 }
148
149 let free_rollback_hook = if hook.is_some() {
150 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
151 } else {
152 None
153 };
154
155 let previous_hook = match hook {
156 Some(hook) => {
157 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
158 unsafe {
159 ffi::sqlite3_rollback_hook(
160 self.db(),
161 Some(call_boxed_closure::<F>),
162 boxed_hook as *mut _,
163 )
164 }
165 }
166 _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) },
167 };
168 if !previous_hook.is_null() {
169 if let Some(free_boxed_hook) = self.free_rollback_hook {
170 unsafe { free_boxed_hook(previous_hook) };
171 }
172 }
173 self.free_rollback_hook = free_rollback_hook;
174 }
175
update_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,176 fn update_hook<F>(&mut self, hook: Option<F>)
177 where
178 F: FnMut(Action, &str, &str, i64) + Send + 'static,
179 {
180 unsafe extern "C" fn call_boxed_closure<F>(
181 p_arg: *mut c_void,
182 action_code: c_int,
183 db_str: *const c_char,
184 tbl_str: *const c_char,
185 row_id: i64,
186 ) where
187 F: FnMut(Action, &str, &str, i64),
188 {
189 use std::ffi::CStr;
190 use std::str;
191
192 let action = Action::from(action_code);
193 let db_name = {
194 let c_slice = CStr::from_ptr(db_str).to_bytes();
195 str::from_utf8(c_slice)
196 };
197 let tbl_name = {
198 let c_slice = CStr::from_ptr(tbl_str).to_bytes();
199 str::from_utf8(c_slice)
200 };
201
202 let _ = catch_unwind(|| {
203 let boxed_hook: *mut F = p_arg as *mut F;
204 (*boxed_hook)(
205 action,
206 db_name.expect("illegal db name"),
207 tbl_name.expect("illegal table name"),
208 row_id,
209 );
210 });
211 }
212
213 let free_update_hook = if hook.is_some() {
214 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
215 } else {
216 None
217 };
218
219 let previous_hook = match hook {
220 Some(hook) => {
221 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
222 unsafe {
223 ffi::sqlite3_update_hook(
224 self.db(),
225 Some(call_boxed_closure::<F>),
226 boxed_hook as *mut _,
227 )
228 }
229 }
230 _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) },
231 };
232 if !previous_hook.is_null() {
233 if let Some(free_boxed_hook) = self.free_update_hook {
234 unsafe { free_boxed_hook(previous_hook) };
235 }
236 }
237 self.free_update_hook = free_update_hook;
238 }
239 }
240
/// Reclaims and drops a hook closure that was leaked with `Box::into_raw`.
///
/// # Safety
///
/// `p` must have been produced by `Box::into_raw` on a `Box<F>` of this same
/// `F`, and must not be used again after this call.
unsafe fn free_boxed_hook<F>(p: *mut c_void) {
    let reclaimed: Box<F> = Box::from_raw(p.cast::<F>());
    drop(reclaimed);
}
244
#[cfg(test)]
mod test {
    use super::Action;
    use crate::Connection;
    use std::sync::atomic::{AtomicBool, Ordering};

    // `AtomicBool::new` is a `const fn`, so each test keeps its flag in a
    // plain function-local `static`.

    /// A commit hook returning `false` is invoked and lets the commit proceed.
    #[test]
    fn test_commit_hook() {
        let db = Connection::open_in_memory().unwrap();

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    /// A plain `fn` works as a commit hook; returning `true` makes the
    /// commit fail.
    #[test]
    fn test_fn_commit_hook() {
        let db = Connection::open_in_memory().unwrap();

        fn hook() -> bool {
            true
        }

        db.commit_hook(Some(hook));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
    }

    /// The rollback hook is invoked by an explicit `ROLLBACK`.
    #[test]
    fn test_rollback_hook() {
        let db = Connection::open_in_memory().unwrap();

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")
            .unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    /// The update hook receives the action, database, table and rowid of an
    /// insert.
    #[test]
    fn test_update_hook() {
        let db = Connection::open_in_memory().unwrap();

        static CALLED: AtomicBool = AtomicBool::new(false);
        // A failing assertion inside the hook panics before `CALLED` is set,
        // and the hook's C entry point discards the panic. The failure
        // therefore shows up through the final `assert!` below.
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)").unwrap();
        db.execute_batch("INSERT INTO foo VALUES ('lisa')").unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }
}
315