1 //! `feature = "functions"` Create or redefine SQL functions.
2 //!
3 //! # Example
4 //!
//! Adding a `regexp` function to a connection in which compiled regular
//! expressions are cached using SQLite's
//! [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (through `Context::get_or_create_aux`), so that a constant
//! pattern need not be recompiled for every row. The unit tests for this
//! module contain a similar implementation.
10 //!
11 //! ```rust
12 //! use regex::Regex;
13 //! use rusqlite::functions::FunctionFlags;
14 //! use rusqlite::{Connection, Error, Result, NO_PARAMS};
15 //! use std::sync::Arc;
16 //! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17 //!
18 //! fn add_regexp_function(db: &Connection) -> Result<()> {
19 //!     db.create_scalar_function(
20 //!         "regexp",
21 //!         2,
22 //!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23 //!         move |ctx| {
24 //!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25 //!             let regexp: Arc<Regex> = ctx
26 //!                 .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
27 //!                     Ok(Regex::new(vr.as_str()?)?)
28 //!                 })?;
29 //!             let is_match = {
30 //!                 let text = ctx
31 //!                     .get_raw(1)
32 //!                     .as_str()
33 //!                     .map_err(|e| Error::UserFunctionError(e.into()))?;
34 //!
35 //!                 regexp.is_match(text)
36 //!             };
37 //!
38 //!             Ok(is_match)
39 //!         },
40 //!     )
41 //! }
42 //!
43 //! fn main() -> Result<()> {
44 //!     let db = Connection::open_in_memory()?;
45 //!     add_regexp_function(&db)?;
46 //!
47 //!     let is_match: bool = db.query_row(
48 //!         "SELECT regexp('[aeiou]*', 'aaaaeeeiii')",
49 //!         NO_PARAMS,
50 //!         |row| row.get(0),
51 //!     )?;
52 //!
53 //!     assert!(is_match);
54 //!     Ok(())
55 //! }
56 //! ```
57 use std::any::Any;
58 use std::os::raw::{c_int, c_void};
59 use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60 use std::ptr;
61 use std::slice;
62 use std::sync::Arc;
63 
64 use crate::ffi;
65 use crate::ffi::sqlite3_context;
66 use crate::ffi::sqlite3_value;
67 
68 use crate::context::set_result;
69 use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
70 
71 use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72 
report_error(ctx: *mut sqlite3_context, err: &Error)73 unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74     // Extended constraint error codes were added in SQLite 3.7.16. We don't have
75     // an explicit feature check for that, and this doesn't really warrant one.
76     // We'll use the extended code if we're on the bundled version (since it's
77     // at least 3.17.0) and the normal constraint error code if not.
78     #[cfg(feature = "modern_sqlite")]
79     fn constraint_error_code() -> i32 {
80         ffi::SQLITE_CONSTRAINT_FUNCTION
81     }
82     #[cfg(not(feature = "modern_sqlite"))]
83     fn constraint_error_code() -> i32 {
84         ffi::SQLITE_CONSTRAINT
85     }
86 
87     match *err {
88         Error::SqliteFailure(ref err, ref s) => {
89             ffi::sqlite3_result_error_code(ctx, err.extended_code);
90             if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
91                 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
92             }
93         }
94         _ => {
95             ffi::sqlite3_result_error_code(ctx, constraint_error_code());
96             if let Ok(cstr) = str_to_cstring(&err.to_string()) {
97                 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
98             }
99         }
100     }
101 }
102 
free_boxed_value<T>(p: *mut c_void)103 unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
104     drop(Box::from_raw(p as *mut T));
105 }
106 
/// `feature = "functions"` Context is a wrapper for the SQLite function
/// evaluation context.
///
/// Its fields are private, so values are only built inside this module: the
/// FFI callbacks construct one per call from the `ctx`/`argv` that SQLite
/// passed in.
pub struct Context<'a> {
    // Raw SQLite context of the current function invocation.
    ctx: *mut sqlite3_context,
    // The invocation's arguments (the callback's `argv` array viewed as a slice).
    args: &'a [*mut sqlite3_value],
}
113 
impl Context<'_> {
    /// Returns the number of arguments to the function.
    pub fn len(&self) -> usize {
        self.args.len()
    }

    /// Returns `true` when there is no argument.
    pub fn is_empty(&self) -> bool {
        self.args.is_empty()
    }

    /// Returns the `idx`th argument as a `T`.
    ///
    /// # Failure
    ///
    /// Will panic if `idx` is greater than or equal to `self.len()`.
    ///
    /// Will return Err if the underlying SQLite type cannot be converted to a
    /// `T`. The conversion error is mapped to an `Error` variant that records
    /// `idx` (and, for most variants, the argument's SQLite data type).
    pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
        let arg = self.args[idx];
        // SAFETY: `args` is built from the `argv` array SQLite passed to the
        // current callback, so `arg` is a live value for this call.
        let value = unsafe { ValueRef::from_value(arg) };
        FromSql::column_result(value).map_err(|err| match err {
            FromSqlError::InvalidType => {
                Error::InvalidFunctionParameterType(idx, value.data_type())
            }
            FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
            FromSqlError::Other(err) => {
                Error::FromSqlConversionFailure(idx, value.data_type(), err)
            }
            #[cfg(feature = "i128_blob")]
            FromSqlError::InvalidI128Size(_) => {
                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
            }
            #[cfg(feature = "uuid")]
            FromSqlError::InvalidUuidSize(_) => {
                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
            }
        })
    }

    /// Returns the `idx`th argument as a `ValueRef`.
    ///
    /// # Failure
    ///
    /// Will panic if `idx` is greater than or equal to `self.len()`.
    pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
        let arg = self.args[idx];
        // SAFETY: see `get`; `arg` is one of the current call's arguments.
        unsafe { ValueRef::from_value(arg) }
    }

    /// Fetches or inserts the auxiliary data associated with a particular
    /// parameter. This is intended to be an easier-to-use way of fetching it
    /// compared to calling `get_aux` and `set_aux` separately.
    ///
    /// When no data is cached yet, `func` is called with the raw value of
    /// argument `arg`, and its result is stored with `set_aux` and returned.
    ///
    /// # Failure
    ///
    /// Returns `Error::GetAuxWrongType` if data of a different type is already
    /// associated with `arg`. An error from `func` is returned wrapped in
    /// `Error::UserFunctionError`. Panics (inside `get_raw`) when nothing is
    /// cached and `arg` is negative or not less than `self.len()`.
    ///
    /// See https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
    /// this feature, or the unit tests of this module for an example.
    pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
    where
        T: Send + Sync + 'static,
        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
        F: FnOnce(ValueRef<'_>) -> Result<T, E>,
    {
        if let Some(v) = self.get_aux(arg)? {
            Ok(v)
        } else {
            // A negative `arg` wraps to a huge index here, so `get_raw` panics.
            let vr = self.get_raw(arg as usize);
            self.set_aux(
                arg,
                func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
            )
        }
    }

    /// Sets the auxiliary data associated with a particular parameter. See
    /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
    /// this feature, or the unit tests of this module for an example.
    ///
    /// The value is wrapped in an `Arc`. One clone, type-erased to `AuxInner`
    /// and boxed, is handed to SQLite with `free_boxed_value::<AuxInner>` as its
    /// destructor. The other clone is returned, so this currently always
    /// returns `Ok`.
    pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
        let orig: Arc<T> = Arc::new(value);
        let inner: AuxInner = orig.clone();
        let outer = Box::new(inner);
        // Ownership of the boxed `AuxInner` passes to SQLite here.
        let raw: *mut AuxInner = Box::into_raw(outer);
        unsafe {
            ffi::sqlite3_set_auxdata(
                self.ctx,
                arg,
                raw as *mut _,
                Some(free_boxed_value::<AuxInner>),
            )
        };
        Ok(orig)
    }

    /// Gets the auxiliary data that was associated with a given parameter via
    /// `set_aux`. Returns `Ok(None)` if no data has been associated, and
    /// `Ok(Some(v))` if it has. Returns `Error::GetAuxWrongType` if the
    /// requested type does not match.
    pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
        // NOTE(review): assumes every aux pointer on this context was stored by
        // `set_aux` (i.e. really points to an `AuxInner`).
        let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
        if p.is_null() {
            Ok(None)
        } else {
            // Clone the Arc so SQLite keeps its own reference untouched.
            let v: AuxInner = AuxInner::clone(unsafe { &*p });
            v.downcast::<T>()
                .map(Some)
                .map_err(|_| Error::GetAuxWrongType)
        }
    }
}
223 
/// Type-erased, shareable storage for SQLite function auxiliary data; the
/// requested concrete type is recovered with `Arc::downcast`.
type AuxInner = Arc<dyn Any + Sync + Send + 'static>;
225 
226 /// `feature = "functions"` Aggregate is the callback interface for user-defined
227 /// aggregate function.
228 ///
229 /// `A` is the type of the aggregation context and `T` is the type of the final
230 /// result. Implementations should be stateless.
231 pub trait Aggregate<A, T>
232 where
233     A: RefUnwindSafe + UnwindSafe,
234     T: ToSql,
235 {
236     /// Initializes the aggregation context. Will be called prior to the first
237     /// call to `step()` to set up the context for an invocation of the
238     /// function. (Note: `init()` will not be called if there are no rows.)
init(&self) -> A239     fn init(&self) -> A;
240 
241     /// "step" function called once for each row in an aggregate group. May be
242     /// called 0 times if there are no rows.
step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>243     fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
244 
245     /// Computes and returns the final result. Will be called exactly once for
246     /// each invocation of the function. If `step()` was called at least
247     /// once, will be given `Some(A)` (the same `A` as was created by
248     /// `init` and given to `step`); if `step()` was not called (because
249     /// the function is running against 0 rows), will be given `None`.
finalize(&self, _: Option<A>) -> Result<T>250     fn finalize(&self, _: Option<A>) -> Result<T>;
251 }
252 
/// `feature = "window"` WindowAggregate is the callback interface for
/// user-defined aggregate window function.
///
/// It extends `Aggregate` with the two extra callbacks SQLite needs for
/// window frames.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Reports the aggregate's current value. Unlike xFinal, this must leave
    /// the accumulator intact, which is why only a shared reference is given.
    fn value(&self, acc: Option<&A>) -> Result<T>;

    /// Removes the row described by `ctx` from the current window's
    /// accumulator `acc`.
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
268 
// `FunctionFlags` is passed as `flags.bits()` in the `eTextRep` argument of
// `sqlite3_create_function_v2` / `sqlite3_create_window_function` by the
// `InnerConnection` methods in this file.
bitflags::bitflags! {
    #[doc = "Function Flags."]
    #[doc = "See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html) for details."]
    #[repr(C)]
    pub struct FunctionFlags: ::std::os::raw::c_int {
        /// Preferred text encoding UTF-8 (`ffi::SQLITE_UTF8`).
        const SQLITE_UTF8     = ffi::SQLITE_UTF8;
        /// Preferred text encoding UTF-16 little-endian (`ffi::SQLITE_UTF16LE`).
        const SQLITE_UTF16LE  = ffi::SQLITE_UTF16LE;
        /// Preferred text encoding UTF-16 big-endian (`ffi::SQLITE_UTF16BE`).
        const SQLITE_UTF16BE  = ffi::SQLITE_UTF16BE;
        /// Preferred text encoding UTF-16 in native byte order (`ffi::SQLITE_UTF16`).
        const SQLITE_UTF16    = ffi::SQLITE_UTF16;
        /// Per the SQLite docs linked above: the function always returns the
        /// same result for the same inputs, which lets SQLite optimize calls.
        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
        // The remaining values are hard-coded instead of taken from `ffi`.
        // NOTE(review): presumably the bindings in use do not export them;
        // confirm. The trailing comment gives the first SQLite release that
        // understands each flag.
        /// Per the SQLite docs: the function may only be called from top-level
        /// SQL, not from views, triggers, or schema structures.
        const SQLITE_DIRECTONLY    = 0x0000_0008_0000; // 3.30.0
        /// Per the SQLite docs: the function may call `sqlite3_value_subtype()`.
        const SQLITE_SUBTYPE       = 0x0000_0010_0000; // 3.30.0
        /// Per the SQLite docs: the function is unlikely to cause problems
        /// even if misused.
        const SQLITE_INNOCUOUS     = 0x0000_0020_0000; // 3.31.0
    }
}
284 
285 impl Default for FunctionFlags {
default() -> FunctionFlags286     fn default() -> FunctionFlags {
287         FunctionFlags::SQLITE_UTF8
288     }
289 }
290 
291 impl Connection {
292     /// `feature = "functions"` Attach a user-defined scalar function to
293     /// this database connection.
294     ///
295     /// `fn_name` is the name the function will be accessible from SQL.
296     /// `n_arg` is the number of arguments to the function. Use `-1` for a
297     /// variable number. If the function always returns the same value
298     /// given the same input, `deterministic` should be `true`.
299     ///
300     /// The function will remain available until the connection is closed or
301     /// until it is explicitly removed via `remove_function`.
302     ///
303     /// # Example
304     ///
305     /// ```rust
306     /// # use rusqlite::{Connection, Result, NO_PARAMS};
307     /// # use rusqlite::functions::FunctionFlags;
308     /// fn scalar_function_example(db: Connection) -> Result<()> {
309     ///     db.create_scalar_function(
310     ///         "halve",
311     ///         1,
312     ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
313     ///         |ctx| {
314     ///             let value = ctx.get::<f64>(0)?;
315     ///             Ok(value / 2f64)
316     ///         },
317     ///     )?;
318     ///
319     ///     let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?;
320     ///     assert_eq!(six_halved, 3f64);
321     ///     Ok(())
322     /// }
323     /// ```
324     ///
325     /// # Failure
326     ///
327     /// Will return Err if the function could not be attached to the connection.
create_scalar_function<F, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,328     pub fn create_scalar_function<F, T>(
329         &self,
330         fn_name: &str,
331         n_arg: c_int,
332         flags: FunctionFlags,
333         x_func: F,
334     ) -> Result<()>
335     where
336         F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
337         T: ToSql,
338     {
339         self.db
340             .borrow_mut()
341             .create_scalar_function(fn_name, n_arg, flags, x_func)
342     }
343 
344     /// `feature = "functions"` Attach a user-defined aggregate function to this
345     /// database connection.
346     ///
347     /// # Failure
348     ///
349     /// Will return Err if the function could not be attached to the connection.
create_aggregate_function<A, D, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,350     pub fn create_aggregate_function<A, D, T>(
351         &self,
352         fn_name: &str,
353         n_arg: c_int,
354         flags: FunctionFlags,
355         aggr: D,
356     ) -> Result<()>
357     where
358         A: RefUnwindSafe + UnwindSafe,
359         D: Aggregate<A, T>,
360         T: ToSql,
361     {
362         self.db
363             .borrow_mut()
364             .create_aggregate_function(fn_name, n_arg, flags, aggr)
365     }
366 
367     /// `feature = "window"` Attach a user-defined aggregate window function to
368     /// this database connection.
369     ///
370     /// See https://sqlite.org/windowfunctions.html#udfwinfunc for more
371     /// information.
372     #[cfg(feature = "window")]
create_window_function<A, W, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T>, T: ToSql,373     pub fn create_window_function<A, W, T>(
374         &self,
375         fn_name: &str,
376         n_arg: c_int,
377         flags: FunctionFlags,
378         aggr: W,
379     ) -> Result<()>
380     where
381         A: RefUnwindSafe + UnwindSafe,
382         W: WindowAggregate<A, T>,
383         T: ToSql,
384     {
385         self.db
386             .borrow_mut()
387             .create_window_function(fn_name, n_arg, flags, aggr)
388     }
389 
390     /// `feature = "functions"` Removes a user-defined function from this
391     /// database connection.
392     ///
393     /// `fn_name` and `n_arg` should match the name and number of arguments
394     /// given to `create_scalar_function` or `create_aggregate_function`.
395     ///
396     /// # Failure
397     ///
398     /// Will return Err if the function could not be removed.
remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()>399     pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
400         self.db.borrow_mut().remove_function(fn_name, n_arg)
401     }
402 }
403 
404 impl InnerConnection {
create_scalar_function<F, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,405     fn create_scalar_function<F, T>(
406         &mut self,
407         fn_name: &str,
408         n_arg: c_int,
409         flags: FunctionFlags,
410         x_func: F,
411     ) -> Result<()>
412     where
413         F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
414         T: ToSql,
415     {
416         unsafe extern "C" fn call_boxed_closure<F, T>(
417             ctx: *mut sqlite3_context,
418             argc: c_int,
419             argv: *mut *mut sqlite3_value,
420         ) where
421             F: FnMut(&Context<'_>) -> Result<T>,
422             T: ToSql,
423         {
424             let r = catch_unwind(|| {
425                 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F;
426                 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
427                 let ctx = Context {
428                     ctx,
429                     args: slice::from_raw_parts(argv, argc as usize),
430                 };
431                 (*boxed_f)(&ctx)
432             });
433             let t = match r {
434                 Err(_) => {
435                     report_error(ctx, &Error::UnwindingPanic);
436                     return;
437                 }
438                 Ok(r) => r,
439             };
440             let t = t.as_ref().map(|t| ToSql::to_sql(t));
441 
442             match t {
443                 Ok(Ok(ref value)) => set_result(ctx, value),
444                 Ok(Err(err)) => report_error(ctx, &err),
445                 Err(err) => report_error(ctx, err),
446             }
447         }
448 
449         let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
450         let c_name = str_to_cstring(fn_name)?;
451         let r = unsafe {
452             ffi::sqlite3_create_function_v2(
453                 self.db(),
454                 c_name.as_ptr(),
455                 n_arg,
456                 flags.bits(),
457                 boxed_f as *mut c_void,
458                 Some(call_boxed_closure::<F, T>),
459                 None,
460                 None,
461                 Some(free_boxed_value::<F>),
462             )
463         };
464         self.decode_result(r)
465     }
466 
create_aggregate_function<A, D, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,467     fn create_aggregate_function<A, D, T>(
468         &mut self,
469         fn_name: &str,
470         n_arg: c_int,
471         flags: FunctionFlags,
472         aggr: D,
473     ) -> Result<()>
474     where
475         A: RefUnwindSafe + UnwindSafe,
476         D: Aggregate<A, T>,
477         T: ToSql,
478     {
479         let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
480         let c_name = str_to_cstring(fn_name)?;
481         let r = unsafe {
482             ffi::sqlite3_create_function_v2(
483                 self.db(),
484                 c_name.as_ptr(),
485                 n_arg,
486                 flags.bits(),
487                 boxed_aggr as *mut c_void,
488                 None,
489                 Some(call_boxed_step::<A, D, T>),
490                 Some(call_boxed_final::<A, D, T>),
491                 Some(free_boxed_value::<D>),
492             )
493         };
494         self.decode_result(r)
495     }
496 
497     #[cfg(feature = "window")]
create_window_function<A, W, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T>, T: ToSql,498     fn create_window_function<A, W, T>(
499         &mut self,
500         fn_name: &str,
501         n_arg: c_int,
502         flags: FunctionFlags,
503         aggr: W,
504     ) -> Result<()>
505     where
506         A: RefUnwindSafe + UnwindSafe,
507         W: WindowAggregate<A, T>,
508         T: ToSql,
509     {
510         let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
511         let c_name = str_to_cstring(fn_name)?;
512         let r = unsafe {
513             ffi::sqlite3_create_window_function(
514                 self.db(),
515                 c_name.as_ptr(),
516                 n_arg,
517                 flags.bits(),
518                 boxed_aggr as *mut c_void,
519                 Some(call_boxed_step::<A, W, T>),
520                 Some(call_boxed_final::<A, W, T>),
521                 Some(call_boxed_value::<A, W, T>),
522                 Some(call_boxed_inverse::<A, W, T>),
523                 Some(free_boxed_value::<W>),
524             )
525         };
526         self.decode_result(r)
527     }
528 
remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()>529     fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
530         let c_name = str_to_cstring(fn_name)?;
531         let r = unsafe {
532             ffi::sqlite3_create_function_v2(
533                 self.db(),
534                 c_name.as_ptr(),
535                 n_arg,
536                 ffi::SQLITE_UTF8,
537                 ptr::null_mut(),
538                 None,
539                 None,
540                 None,
541                 None,
542             )
543         };
544         self.decode_result(r)
545     }
546 }
547 
aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A>548 unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
549     let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
550     if pac.is_null() {
551         return None;
552     }
553     Some(pac)
554 }
555 
call_boxed_step<A, D, T>( ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value, ) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,556 unsafe extern "C" fn call_boxed_step<A, D, T>(
557     ctx: *mut sqlite3_context,
558     argc: c_int,
559     argv: *mut *mut sqlite3_value,
560 ) where
561     A: RefUnwindSafe + UnwindSafe,
562     D: Aggregate<A, T>,
563     T: ToSql,
564 {
565     let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
566         Some(pac) => pac,
567         None => {
568             ffi::sqlite3_result_error_nomem(ctx);
569             return;
570         }
571     };
572 
573     let r = catch_unwind(|| {
574         let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
575         assert!(
576             !boxed_aggr.is_null(),
577             "Internal error - null aggregate pointer"
578         );
579         if (*pac as *mut A).is_null() {
580             *pac = Box::into_raw(Box::new((*boxed_aggr).init()));
581         }
582         let mut ctx = Context {
583             ctx,
584             args: slice::from_raw_parts(argv, argc as usize),
585         };
586         (*boxed_aggr).step(&mut ctx, &mut **pac)
587     });
588     let r = match r {
589         Err(_) => {
590             report_error(ctx, &Error::UnwindingPanic);
591             return;
592         }
593         Ok(r) => r,
594     };
595     match r {
596         Ok(_) => {}
597         Err(err) => report_error(ctx, &err),
598     };
599 }
600 
/// xInverse trampoline for window functions: removes the current row from the
/// accumulator via `WindowAggregate::inverse`.
///
/// The accumulator is the one stored by `call_boxed_step`. If it is missing
/// (null slot), the call is reported as an error instead of dereferencing a
/// null pointer. Panics are caught and reported as `Error::UnwindingPanic`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
        Some(pac) => pac,
        None => {
            ffi::sqlite3_result_error_nomem(ctx);
            return;
        }
    };

    let r = catch_unwind(|| {
        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        // Only `call_boxed_step` populates the slot. Dereferencing a null
        // accumulator would be UB, so fail loudly (the panic is caught below
        // and reported to SQLite) instead.
        assert!(
            !(*pac).is_null(),
            "Internal error - null aggregate context"
        );
        let mut ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        (*boxed_aggr).inverse(&mut ctx, &mut **pac)
    });
    let r = match r {
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
        Ok(r) => r,
    };
    match r {
        Ok(_) => {}
        Err(err) => report_error(ctx, &err),
    };
}
643 
call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,644 unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
645 where
646     A: RefUnwindSafe + UnwindSafe,
647     D: Aggregate<A, T>,
648     T: ToSql,
649 {
650     // Within the xFinal callback, it is customary to set N=0 in calls to
651     // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
652     let a: Option<A> = match aggregate_context(ctx, 0) {
653         Some(pac) => {
654             if (*pac as *mut A).is_null() {
655                 None
656             } else {
657                 let a = Box::from_raw(*pac);
658                 Some(*a)
659             }
660         }
661         None => None,
662     };
663 
664     let r = catch_unwind(|| {
665         let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
666         assert!(
667             !boxed_aggr.is_null(),
668             "Internal error - null aggregate pointer"
669         );
670         (*boxed_aggr).finalize(a)
671     });
672     let t = match r {
673         Err(_) => {
674             report_error(ctx, &Error::UnwindingPanic);
675             return;
676         }
677         Ok(r) => r,
678     };
679     let t = t.as_ref().map(|t| ToSql::to_sql(t));
680     match t {
681         Ok(Ok(ref value)) => set_result(ctx, value),
682         Ok(Err(err)) => report_error(ctx, &err),
683         Err(err) => report_error(ctx, err),
684     }
685 }
686 
/// xValue trampoline for window functions.
///
/// Lends the current accumulator (if any) to `WindowAggregate::value` without
/// taking ownership, then reports the produced value or error to SQLite.
/// Panics are caught and reported as `Error::UnwindingPanic`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // N = 0: only look up an existing aggregate context, never allocate one.
    let current: Option<&A> = aggregate_context::<A>(ctx, 0).and_then(|slot| {
        let acc: *mut A = *slot;
        if acc.is_null() {
            None
        } else {
            Some(&*acc)
        }
    });

    let outcome = catch_unwind(|| {
        let aggr_ptr = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(
            !aggr_ptr.is_null(),
            "Internal error - null aggregate pointer"
        );
        (*aggr_ptr).value(current)
    });

    let produced: Result<T> = match outcome {
        Ok(v) => v,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };

    match produced.as_ref().map(|v| v.to_sql()) {
        Ok(Ok(ref output)) => set_result(ctx, output),
        Ok(Err(err)) => report_error(ctx, &err),
        Err(err) => report_error(ctx, err),
    }
}
730 
731 #[cfg(test)]
732 mod test {
733     use regex::Regex;
734     use std::f64::EPSILON;
735     use std::os::raw::c_double;
736 
737     #[cfg(feature = "window")]
738     use crate::functions::WindowAggregate;
739     use crate::functions::{Aggregate, Context, FunctionFlags};
740     use crate::{Connection, Error, Result, NO_PARAMS};
741 
half(ctx: &Context<'_>) -> Result<c_double>742     fn half(ctx: &Context<'_>) -> Result<c_double> {
743         assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
744         let value = ctx.get::<c_double>(0)?;
745         Ok(value / 2f64)
746     }
747 
748     #[test]
test_function_half()749     fn test_function_half() {
750         let db = Connection::open_in_memory().unwrap();
751         db.create_scalar_function(
752             "half",
753             1,
754             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
755             half,
756         )
757         .unwrap();
758         let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
759 
760         assert!((3f64 - result.unwrap()).abs() < EPSILON);
761     }
762 
763     #[test]
test_remove_function()764     fn test_remove_function() {
765         let db = Connection::open_in_memory().unwrap();
766         db.create_scalar_function(
767             "half",
768             1,
769             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
770             half,
771         )
772         .unwrap();
773         let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
774         assert!((3f64 - result.unwrap()).abs() < EPSILON);
775 
776         db.remove_function("half", 1).unwrap();
777         let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
778         assert!(result.is_err());
779     }
780 
781     // This implementation of a regexp scalar function uses SQLite's auxilliary data
782     // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
783     // expression multiple times within one query.
regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool>784     fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
785         assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
786         type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
787         let regexp: std::sync::Arc<Regex> = ctx
788             .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
789                 Ok(Regex::new(vr.as_str()?)?)
790             })?;
791 
792         let is_match = {
793             let text = ctx
794                 .get_raw(1)
795                 .as_str()
796                 .map_err(|e| Error::UserFunctionError(e.into()))?;
797 
798             regexp.is_match(text)
799         };
800 
801         Ok(is_match)
802     }
803 
804     #[test]
test_function_regexp_with_auxilliary()805     fn test_function_regexp_with_auxilliary() {
806         let db = Connection::open_in_memory().unwrap();
807         db.execute_batch(
808             "BEGIN;
809              CREATE TABLE foo (x string);
810              INSERT INTO foo VALUES ('lisa');
811              INSERT INTO foo VALUES ('lXsi');
812              INSERT INTO foo VALUES ('lisX');
813              END;",
814         )
815         .unwrap();
816         db.create_scalar_function(
817             "regexp",
818             2,
819             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
820             regexp_with_auxilliary,
821         )
822         .unwrap();
823 
824         let result: Result<bool> =
825             db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| {
826                 r.get(0)
827             });
828 
829         assert_eq!(true, result.unwrap());
830 
831         let result: Result<i64> = db.query_row(
832             "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
833             NO_PARAMS,
834             |r| r.get(0),
835         );
836 
837         assert_eq!(2, result.unwrap());
838     }
839 
840     #[test]
test_varargs_function()841     fn test_varargs_function() {
842         let db = Connection::open_in_memory().unwrap();
843         db.create_scalar_function(
844             "my_concat",
845             -1,
846             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
847             |ctx| {
848                 let mut ret = String::new();
849 
850                 for idx in 0..ctx.len() {
851                     let s = ctx.get::<String>(idx)?;
852                     ret.push_str(&s);
853                 }
854 
855                 Ok(ret)
856             },
857         )
858         .unwrap();
859 
860         for &(expected, query) in &[
861             ("", "SELECT my_concat()"),
862             ("onetwo", "SELECT my_concat('one', 'two')"),
863             ("abc", "SELECT my_concat('a', 'b', 'c')"),
864         ] {
865             let result: String = db.query_row(query, NO_PARAMS, |r| r.get(0)).unwrap();
866             assert_eq!(expected, result);
867         }
868     }
869 
870     #[test]
test_get_aux_type_checking()871     fn test_get_aux_type_checking() {
872         let db = Connection::open_in_memory().unwrap();
873         db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
874             if !ctx.get::<bool>(1)? {
875                 ctx.set_aux::<i64>(0, 100)?;
876             } else {
877                 assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
878                 assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100);
879             }
880             Ok(true)
881         })
882         .unwrap();
883 
884         let res: bool = db
885             .query_row(
886                 "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
887                 NO_PARAMS,
888                 |r| r.get(0),
889             )
890             .unwrap();
891         // Doesn't actually matter, we'll assert in the function if there's a problem.
892         assert!(res);
893     }
894 
895     struct Sum;
896     struct Count;
897 
898     impl Aggregate<i64, Option<i64>> for Sum {
init(&self) -> i64899         fn init(&self) -> i64 {
900             0
901         }
902 
step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()>903         fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
904             *sum += ctx.get::<i64>(0)?;
905             Ok(())
906         }
907 
finalize(&self, sum: Option<i64>) -> Result<Option<i64>>908         fn finalize(&self, sum: Option<i64>) -> Result<Option<i64>> {
909             Ok(sum)
910         }
911     }
912 
913     impl Aggregate<i64, i64> for Count {
init(&self) -> i64914         fn init(&self) -> i64 {
915             0
916         }
917 
step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()>918         fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
919             *sum += 1;
920             Ok(())
921         }
922 
finalize(&self, sum: Option<i64>) -> Result<i64>923         fn finalize(&self, sum: Option<i64>) -> Result<i64> {
924             Ok(sum.unwrap_or(0))
925         }
926     }
927 
928     #[test]
test_sum()929     fn test_sum() {
930         let db = Connection::open_in_memory().unwrap();
931         db.create_aggregate_function(
932             "my_sum",
933             1,
934             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
935             Sum,
936         )
937         .unwrap();
938 
939         // sum should return NULL when given no columns (contrast with count below)
940         let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
941         let result: Option<i64> = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
942         assert!(result.is_none());
943 
944         let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
945         let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
946         assert_eq!(4, result);
947 
948         let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
949                         2, 1)";
950         let result: (i64, i64) = db
951             .query_row(dual_sum, NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?)))
952             .unwrap();
953         assert_eq!((4, 2), result);
954     }
955 
956     #[test]
test_count()957     fn test_count() {
958         let db = Connection::open_in_memory().unwrap();
959         db.create_aggregate_function(
960             "my_count",
961             -1,
962             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
963             Count,
964         )
965         .unwrap();
966 
967         // count should return 0 when given no columns (contrast with sum above)
968         let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
969         let result: i64 = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
970         assert_eq!(result, 0);
971 
972         let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
973         let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
974         assert_eq!(2, result);
975     }
976 
977     #[cfg(feature = "window")]
978     impl WindowAggregate<i64, Option<i64>> for Sum {
inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()>979         fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
980             *sum -= ctx.get::<i64>(0)?;
981             Ok(())
982         }
983 
value(&self, sum: Option<&i64>) -> Result<Option<i64>>984         fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
985             Ok(sum.copied())
986         }
987     }
988 
989     #[test]
990     #[cfg(feature = "window")]
test_window()991     fn test_window() {
992         use fallible_iterator::FallibleIterator;
993 
994         let db = Connection::open_in_memory().unwrap();
995         db.create_window_function(
996             "sumint",
997             1,
998             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
999             Sum,
1000         )
1001         .unwrap();
1002         db.execute_batch(
1003             "CREATE TABLE t3(x, y);
1004              INSERT INTO t3 VALUES('a', 4),
1005                      ('b', 5),
1006                      ('c', 3),
1007                      ('d', 8),
1008                      ('e', 1);",
1009         )
1010         .unwrap();
1011 
1012         let mut stmt = db
1013             .prepare(
1014                 "SELECT x, sumint(y) OVER (
1015                    ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
1016                  ) AS sum_y
1017                  FROM t3 ORDER BY x;",
1018             )
1019             .unwrap();
1020 
1021         let results: Vec<(String, i64)> = stmt
1022             .query(NO_PARAMS)
1023             .unwrap()
1024             .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
1025             .collect()
1026             .unwrap();
1027         let expected = vec![
1028             ("a".to_owned(), 9),
1029             ("b".to_owned(), 12),
1030             ("c".to_owned(), 16),
1031             ("d".to_owned(), 12),
1032             ("e".to_owned(), 9),
1033         ];
1034         assert_eq!(expected, results);
1035     }
1036 }
1037