1 //! `feature = "hooks"` Commit, Data Change and Rollback Notification Callbacks
2 #![allow(non_camel_case_types)]
3
4 use std::os::raw::{c_char, c_int, c_void};
5 use std::panic::catch_unwind;
6 use std::ptr;
7
8 use crate::ffi;
9
10 use crate::{Connection, InnerConnection};
11
12 /// `feature = "hooks"` Action Codes
13 #[derive(Clone, Copy, Debug, PartialEq)]
14 #[repr(i32)]
15 #[non_exhaustive]
16 pub enum Action {
17 UNKNOWN = -1,
18 SQLITE_DELETE = ffi::SQLITE_DELETE,
19 SQLITE_INSERT = ffi::SQLITE_INSERT,
20 SQLITE_UPDATE = ffi::SQLITE_UPDATE,
21 }
22
23 impl From<i32> for Action {
from(code: i32) -> Action24 fn from(code: i32) -> Action {
25 match code {
26 ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
27 ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
28 ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
29 _ => Action::UNKNOWN,
30 }
31 }
32 }
33
34 impl Connection {
35 /// `feature = "hooks"` Register a callback function to be invoked whenever
36 /// a transaction is committed.
37 ///
38 /// The callback returns `true` to rollback.
commit_hook<F>(&self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,39 pub fn commit_hook<F>(&self, hook: Option<F>)
40 where
41 F: FnMut() -> bool + Send + 'static,
42 {
43 self.db.borrow_mut().commit_hook(hook);
44 }
45
46 /// `feature = "hooks"` Register a callback function to be invoked whenever
47 /// a transaction is committed.
48 ///
49 /// The callback returns `true` to rollback.
rollback_hook<F>(&self, hook: Option<F>) where F: FnMut() + Send + 'static,50 pub fn rollback_hook<F>(&self, hook: Option<F>)
51 where
52 F: FnMut() + Send + 'static,
53 {
54 self.db.borrow_mut().rollback_hook(hook);
55 }
56
57 /// `feature = "hooks"` Register a callback function to be invoked whenever
58 /// a row is updated, inserted or deleted in a rowid table.
59 ///
60 /// The callback parameters are:
61 ///
62 /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or
63 /// SQLITE_DELETE),
64 /// - the name of the database ("main", "temp", ...),
65 /// - the name of the table that is updated,
66 /// - the ROWID of the row that is updated.
update_hook<F>(&self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,67 pub fn update_hook<F>(&self, hook: Option<F>)
68 where
69 F: FnMut(Action, &str, &str, i64) + Send + 'static,
70 {
71 self.db.borrow_mut().update_hook(hook);
72 }
73 }
74
75 impl InnerConnection {
remove_hooks(&mut self)76 pub fn remove_hooks(&mut self) {
77 self.update_hook(None::<fn(Action, &str, &str, i64)>);
78 self.commit_hook(None::<fn() -> bool>);
79 self.rollback_hook(None::<fn()>);
80 }
81
commit_hook<F>(&mut self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static,82 fn commit_hook<F>(&mut self, hook: Option<F>)
83 where
84 F: FnMut() -> bool + Send + 'static,
85 {
86 unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
87 where
88 F: FnMut() -> bool,
89 {
90 let r = catch_unwind(|| {
91 let boxed_hook: *mut F = p_arg as *mut F;
92 (*boxed_hook)()
93 });
94 if let Ok(true) = r {
95 1
96 } else {
97 0
98 }
99 }
100
101 // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with
102 // `sqlite3_commit_hook`. so we keep the `xDestroy` function in
103 // `InnerConnection.free_boxed_hook`.
104 let free_commit_hook = if hook.is_some() {
105 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
106 } else {
107 None
108 };
109
110 let previous_hook = match hook {
111 Some(hook) => {
112 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
113 unsafe {
114 ffi::sqlite3_commit_hook(
115 self.db(),
116 Some(call_boxed_closure::<F>),
117 boxed_hook as *mut _,
118 )
119 }
120 }
121 _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) },
122 };
123 if !previous_hook.is_null() {
124 if let Some(free_boxed_hook) = self.free_commit_hook {
125 unsafe { free_boxed_hook(previous_hook) };
126 }
127 }
128 self.free_commit_hook = free_commit_hook;
129 }
130
rollback_hook<F>(&mut self, hook: Option<F>) where F: FnMut() + Send + 'static,131 fn rollback_hook<F>(&mut self, hook: Option<F>)
132 where
133 F: FnMut() + Send + 'static,
134 {
135 unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
136 where
137 F: FnMut(),
138 {
139 let _ = catch_unwind(|| {
140 let boxed_hook: *mut F = p_arg as *mut F;
141 (*boxed_hook)();
142 });
143 }
144
145 let free_rollback_hook = if hook.is_some() {
146 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
147 } else {
148 None
149 };
150
151 let previous_hook = match hook {
152 Some(hook) => {
153 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
154 unsafe {
155 ffi::sqlite3_rollback_hook(
156 self.db(),
157 Some(call_boxed_closure::<F>),
158 boxed_hook as *mut _,
159 )
160 }
161 }
162 _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) },
163 };
164 if !previous_hook.is_null() {
165 if let Some(free_boxed_hook) = self.free_rollback_hook {
166 unsafe { free_boxed_hook(previous_hook) };
167 }
168 }
169 self.free_rollback_hook = free_rollback_hook;
170 }
171
update_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static,172 fn update_hook<F>(&mut self, hook: Option<F>)
173 where
174 F: FnMut(Action, &str, &str, i64) + Send + 'static,
175 {
176 unsafe extern "C" fn call_boxed_closure<F>(
177 p_arg: *mut c_void,
178 action_code: c_int,
179 db_str: *const c_char,
180 tbl_str: *const c_char,
181 row_id: i64,
182 ) where
183 F: FnMut(Action, &str, &str, i64),
184 {
185 use std::ffi::CStr;
186 use std::str;
187
188 let action = Action::from(action_code);
189 let db_name = {
190 let c_slice = CStr::from_ptr(db_str).to_bytes();
191 str::from_utf8(c_slice)
192 };
193 let tbl_name = {
194 let c_slice = CStr::from_ptr(tbl_str).to_bytes();
195 str::from_utf8(c_slice)
196 };
197
198 let _ = catch_unwind(|| {
199 let boxed_hook: *mut F = p_arg as *mut F;
200 (*boxed_hook)(
201 action,
202 db_name.expect("illegal db name"),
203 tbl_name.expect("illegal table name"),
204 row_id,
205 );
206 });
207 }
208
209 let free_update_hook = if hook.is_some() {
210 Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
211 } else {
212 None
213 };
214
215 let previous_hook = match hook {
216 Some(hook) => {
217 let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
218 unsafe {
219 ffi::sqlite3_update_hook(
220 self.db(),
221 Some(call_boxed_closure::<F>),
222 boxed_hook as *mut _,
223 )
224 }
225 }
226 _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) },
227 };
228 if !previous_hook.is_null() {
229 if let Some(free_boxed_hook) = self.free_update_hook {
230 unsafe { free_boxed_hook(previous_hook) };
231 }
232 }
233 self.free_update_hook = free_update_hook;
234 }
235 }
236
/// Reclaims a hook closure that was leaked with `Box::into_raw` and drops it.
///
/// # Safety
///
/// `p` must have been produced by `Box::into_raw(Box::<F>::new(..))` with this
/// exact `F`, and must not be used (or freed) again after this call.
unsafe fn free_boxed_hook<F>(p: *mut c_void) {
    let owned: Box<F> = Box::from_raw(p as *mut F);
    drop(owned);
}
240
#[cfg(test)]
mod test {
    use super::Action;
    use crate::Connection;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // `AtomicBool::new` is a `const fn`, so plain `static`s are enough here;
    // no lazy initialisation (`lazy_static!`) is needed.

    #[test]
    fn test_commit_hook() {
        static CALLED: AtomicBool = AtomicBool::new(false);

        let db = Connection::open_in_memory().unwrap();
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    #[test]
    fn test_fn_commit_hook() {
        let db = Connection::open_in_memory().unwrap();

        fn hook() -> bool {
            true
        }

        db.commit_hook(Some(hook));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
    }

    #[test]
    fn test_remove_commit_hook() {
        let db = Connection::open_in_memory().unwrap();
        // A hook that would veto every commit...
        db.commit_hook(Some(|| true));
        // ...is unregistered, so the commit must now succeed.
        db.commit_hook(None::<fn() -> bool>);
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap();
    }

    #[test]
    fn test_rollback_hook() {
        static CALLED: AtomicBool = AtomicBool::new(false);

        let db = Connection::open_in_memory().unwrap();
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")
            .unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }

    #[test]
    fn test_replaced_hook_is_dropped() {
        let db = Connection::open_in_memory().unwrap();
        let token = Arc::new(());
        let held = Arc::clone(&token);
        db.rollback_hook(Some(move || {
            assert!(Arc::strong_count(&held) >= 1);
        }));
        // The registered closure owns one reference.
        assert_eq!(2, Arc::strong_count(&token));
        // Unregistering must free the boxed closure and release that reference.
        db.rollback_hook(None::<fn()>);
        assert_eq!(1, Arc::strong_count(&token));
    }

    #[test]
    fn test_update_hook() {
        static CALLED: AtomicBool = AtomicBool::new(false);

        let db = Connection::open_in_memory().unwrap();
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)").unwrap();
        db.execute_batch("INSERT INTO foo VALUES ('lisa')").unwrap();
        assert!(CALLED.load(Ordering::Relaxed));
    }
}
311