1 //! `feature = "functions"` Create or redefine SQL functions.
2 //!
3 //! # Example
4 //!
//! Adding a `regexp` function to a connection. The compiled regular
//! expression is cached with SQLite's
//! [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (via `Context::get_or_create_aux`) to avoid recompiling the
//! regular expression on every call. The unit tests for this module contain a
//! similar implementation.
10 //!
11 //! ```rust
12 //! use regex::Regex;
13 //! use rusqlite::functions::FunctionFlags;
14 //! use rusqlite::{Connection, Error, Result, NO_PARAMS};
15 //! use std::sync::Arc;
16 //! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17 //!
18 //! fn add_regexp_function(db: &Connection) -> Result<()> {
19 //!     db.create_scalar_function(
20 //!         "regexp",
21 //!         2,
22 //!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23 //!         move |ctx| {
24 //!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25 //!             let regexp: Arc<Regex> = ctx
26 //!                 .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
27 //!                     Ok(Regex::new(vr.as_str()?)?)
28 //!                 })?;
29 //!             let is_match = {
30 //!                 let text = ctx
31 //!                     .get_raw(1)
32 //!                     .as_str()
33 //!                     .map_err(|e| Error::UserFunctionError(e.into()))?;
34 //!
35 //!                 regexp.is_match(text)
36 //!             };
37 //!
38 //!             Ok(is_match)
39 //!         },
40 //!     )
41 //! }
42 //!
43 //! fn main() -> Result<()> {
44 //!     let db = Connection::open_in_memory()?;
45 //!     add_regexp_function(&db)?;
46 //!
47 //!     let is_match: bool = db.query_row(
48 //!         "SELECT regexp('[aeiou]*', 'aaaaeeeiii')",
49 //!         NO_PARAMS,
50 //!         |row| row.get(0),
51 //!     )?;
52 //!
53 //!     assert!(is_match);
54 //!     Ok(())
55 //! }
56 //! ```
57 use std::any::Any;
58 use std::os::raw::{c_int, c_void};
59 use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60 use std::ptr;
61 use std::slice;
62 use std::sync::Arc;
63 
64 use crate::ffi;
65 use crate::ffi::sqlite3_context;
66 use crate::ffi::sqlite3_value;
67 
68 use crate::context::set_result;
69 use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
70 
71 use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72 
report_error(ctx: *mut sqlite3_context, err: &Error)73 unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74     // Extended constraint error codes were added in SQLite 3.7.16. We don't have
75     // an explicit feature check for that, and this doesn't really warrant one.
76     // We'll use the extended code if we're on the bundled version (since it's
77     // at least 3.17.0) and the normal constraint error code if not.
78     #[cfg(feature = "modern_sqlite")]
79     fn constraint_error_code() -> i32 {
80         ffi::SQLITE_CONSTRAINT_FUNCTION
81     }
82     #[cfg(not(feature = "modern_sqlite"))]
83     fn constraint_error_code() -> i32 {
84         ffi::SQLITE_CONSTRAINT
85     }
86 
87     match *err {
88         Error::SqliteFailure(ref err, ref s) => {
89             ffi::sqlite3_result_error_code(ctx, err.extended_code);
90             if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
91                 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
92             }
93         }
94         _ => {
95             ffi::sqlite3_result_error_code(ctx, constraint_error_code());
96             if let Ok(cstr) = str_to_cstring(&err.to_string()) {
97                 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
98             }
99         }
100     }
101 }
102 
free_boxed_value<T>(p: *mut c_void)103 unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
104     drop(Box::from_raw(p as *mut T));
105 }
106 
/// `feature = "functions"` Context is a wrapper for the SQLite function
/// evaluation context.
///
/// The callbacks in this module build a `Context` for a single invocation of
/// a user-defined function. It borrows the argument array that SQLite passed
/// to that invocation.
pub struct Context<'a> {
    /// Raw SQLite context of the current function invocation.
    ctx: *mut sqlite3_context,
    /// The invocation's arguments, i.e. `argv[..argc]` as handed over by SQLite.
    args: &'a [*mut sqlite3_value],
}
113 
114 impl Context<'_> {
115     /// Returns the number of arguments to the function.
len(&self) -> usize116     pub fn len(&self) -> usize {
117         self.args.len()
118     }
119 
120     /// Returns `true` when there is no argument.
is_empty(&self) -> bool121     pub fn is_empty(&self) -> bool {
122         self.args.is_empty()
123     }
124 
125     /// Returns the `idx`th argument as a `T`.
126     ///
127     /// # Failure
128     ///
129     /// Will panic if `idx` is greater than or equal to `self.len()`.
130     ///
131     /// Will return Err if the underlying SQLite type cannot be converted to a
132     /// `T`.
get<T: FromSql>(&self, idx: usize) -> Result<T>133     pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
134         let arg = self.args[idx];
135         let value = unsafe { ValueRef::from_value(arg) };
136         FromSql::column_result(value).map_err(|err| match err {
137             FromSqlError::InvalidType => {
138                 Error::InvalidFunctionParameterType(idx, value.data_type())
139             }
140             FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
141             FromSqlError::Other(err) => {
142                 Error::FromSqlConversionFailure(idx, value.data_type(), err)
143             }
144             #[cfg(feature = "i128_blob")]
145             FromSqlError::InvalidI128Size(_) => {
146                 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
147             }
148             #[cfg(feature = "uuid")]
149             FromSqlError::InvalidUuidSize(_) => {
150                 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
151             }
152         })
153     }
154 
155     /// Returns the `idx`th argument as a `ValueRef`.
156     ///
157     /// # Failure
158     ///
159     /// Will panic if `idx` is greater than or equal to `self.len()`.
get_raw(&self, idx: usize) -> ValueRef<'_>160     pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
161         let arg = self.args[idx];
162         unsafe { ValueRef::from_value(arg) }
163     }
164 
165     /// Fetch or insert the the auxilliary data associated with a particular
166     /// parameter. This is intended to be an easier-to-use way of fetching it
167     /// compared to calling `get_aux` and `set_aux` separately.
168     ///
169     /// See https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
170     /// this feature, or the unit tests of this module for an example.
get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>> where T: Send + Sync + 'static, E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>, F: FnOnce(ValueRef<'_>) -> Result<T, E>,171     pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
172     where
173         T: Send + Sync + 'static,
174         E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
175         F: FnOnce(ValueRef<'_>) -> Result<T, E>,
176     {
177         if let Some(v) = self.get_aux(arg)? {
178             Ok(v)
179         } else {
180             let vr = self.get_raw(arg as usize);
181             self.set_aux(
182                 arg,
183                 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
184             )
185         }
186     }
187 
188     /// Sets the auxilliary data associated with a particular parameter. See
189     /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
190     /// this feature, or the unit tests of this module for an example.
set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>>191     pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
192         let orig: Arc<T> = Arc::new(value);
193         let inner: AuxInner = orig.clone();
194         let outer = Box::new(inner);
195         let raw: *mut AuxInner = Box::into_raw(outer);
196         unsafe {
197             ffi::sqlite3_set_auxdata(
198                 self.ctx,
199                 arg,
200                 raw as *mut _,
201                 Some(free_boxed_value::<AuxInner>),
202             )
203         };
204         Ok(orig)
205     }
206 
207     /// Gets the auxilliary data that was associated with a given parameter via
208     /// `set_aux`. Returns `Ok(None)` if no data has been associated, and
209     /// Ok(Some(v)) if it has. Returns an error if the requested type does not
210     /// match.
get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>>211     pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
212         let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
213         if p.is_null() {
214             Ok(None)
215         } else {
216             let v: AuxInner = AuxInner::clone(unsafe { &*p });
217             v.downcast::<T>()
218                 .map(Some)
219                 .map_err(|_| Error::GetAuxWrongType)
220         }
221     }
222 }
223 
/// Type-erased handle stored as SQLite auxiliary data. `Context::set_aux`
/// boxes one of these per call, and `Context::get_aux` downcasts it back to
/// the requested `Arc<T>`.
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
225 
226 /// `feature = "functions"` Aggregate is the callback interface for user-defined
227 /// aggregate function.
228 ///
229 /// `A` is the type of the aggregation context and `T` is the type of the final
230 /// result. Implementations should be stateless.
231 pub trait Aggregate<A, T>
232 where
233     A: RefUnwindSafe + UnwindSafe,
234     T: ToSql,
235 {
236     /// Initializes the aggregation context. Will be called prior to the first
237     /// call to `step()` to set up the context for an invocation of the
238     /// function. (Note: `init()` will not be called if there are no rows.)
init(&self) -> A239     fn init(&self) -> A;
240 
241     /// "step" function called once for each row in an aggregate group. May be
242     /// called 0 times if there are no rows.
step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>243     fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
244 
245     /// Computes and returns the final result. Will be called exactly once for
246     /// each invocation of the function. If `step()` was called at least
247     /// once, will be given `Some(A)` (the same `A` as was created by
248     /// `init` and given to `step`); if `step()` was not called (because
249     /// the function is running against 0 rows), will be given `None`.
finalize(&self, _: Option<A>) -> Result<T>250     fn finalize(&self, _: Option<A>) -> Result<T>;
251 }
252 
/// `feature = "window"` WindowAggregate is the callback interface for
/// user-defined aggregate window function.
///
/// It extends `Aggregate` with the two extra callbacks SQLite needs for
/// window frames: reading the current value and removing a row.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Returns the aggregate's current value. `acc` is `None` if no
    /// accumulator has been created yet. Unlike xFinal (`finalize`), the
    /// implementation must not consume or delete the context.
    fn value(&self, acc: Option<&A>) -> Result<T>;

    /// Removes one row from the current window's accumulator.
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
268 
269 bitflags::bitflags! {
270     /// Function Flags.
271     /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
272     /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
273     #[repr(C)]
274     pub struct FunctionFlags: ::std::os::raw::c_int {
275         /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
276         const SQLITE_UTF8     = ffi::SQLITE_UTF8;
277         /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
278         const SQLITE_UTF16LE  = ffi::SQLITE_UTF16LE;
279         /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
280         const SQLITE_UTF16BE  = ffi::SQLITE_UTF16BE;
281         /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
282         const SQLITE_UTF16    = ffi::SQLITE_UTF16;
283         /// Means that the function always gives the same output when the input parameters are the same.
284         const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
285         /// Means that the function may only be invoked from top-level SQL.
286         const SQLITE_DIRECTONLY    = 0x0000_0008_0000; // 3.30.0
287         /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the sub-types of its arguments.
288         const SQLITE_SUBTYPE       = 0x0000_0010_0000; // 3.30.0
289         /// Means that the function is unlikely to cause problems even if misused.
290         const SQLITE_INNOCUOUS     = 0x0000_0020_0000; // 3.31.0
291     }
292 }
293 
294 impl Default for FunctionFlags {
default() -> FunctionFlags295     fn default() -> FunctionFlags {
296         FunctionFlags::SQLITE_UTF8
297     }
298 }
299 
300 impl Connection {
301     /// `feature = "functions"` Attach a user-defined scalar function to
302     /// this database connection.
303     ///
304     /// `fn_name` is the name the function will be accessible from SQL.
305     /// `n_arg` is the number of arguments to the function. Use `-1` for a
306     /// variable number. If the function always returns the same value
307     /// given the same input, `deterministic` should be `true`.
308     ///
309     /// The function will remain available until the connection is closed or
310     /// until it is explicitly removed via `remove_function`.
311     ///
312     /// # Example
313     ///
314     /// ```rust
315     /// # use rusqlite::{Connection, Result, NO_PARAMS};
316     /// # use rusqlite::functions::FunctionFlags;
317     /// fn scalar_function_example(db: Connection) -> Result<()> {
318     ///     db.create_scalar_function(
319     ///         "halve",
320     ///         1,
321     ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
322     ///         |ctx| {
323     ///             let value = ctx.get::<f64>(0)?;
324     ///             Ok(value / 2f64)
325     ///         },
326     ///     )?;
327     ///
328     ///     let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?;
329     ///     assert_eq!(six_halved, 3f64);
330     ///     Ok(())
331     /// }
332     /// ```
333     ///
334     /// # Failure
335     ///
336     /// Will return Err if the function could not be attached to the connection.
create_scalar_function<F, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,337     pub fn create_scalar_function<F, T>(
338         &self,
339         fn_name: &str,
340         n_arg: c_int,
341         flags: FunctionFlags,
342         x_func: F,
343     ) -> Result<()>
344     where
345         F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
346         T: ToSql,
347     {
348         self.db
349             .borrow_mut()
350             .create_scalar_function(fn_name, n_arg, flags, x_func)
351     }
352 
353     /// `feature = "functions"` Attach a user-defined aggregate function to this
354     /// database connection.
355     ///
356     /// # Failure
357     ///
358     /// Will return Err if the function could not be attached to the connection.
create_aggregate_function<A, D, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,359     pub fn create_aggregate_function<A, D, T>(
360         &self,
361         fn_name: &str,
362         n_arg: c_int,
363         flags: FunctionFlags,
364         aggr: D,
365     ) -> Result<()>
366     where
367         A: RefUnwindSafe + UnwindSafe,
368         D: Aggregate<A, T>,
369         T: ToSql,
370     {
371         self.db
372             .borrow_mut()
373             .create_aggregate_function(fn_name, n_arg, flags, aggr)
374     }
375 
376     /// `feature = "window"` Attach a user-defined aggregate window function to
377     /// this database connection.
378     ///
379     /// See https://sqlite.org/windowfunctions.html#udfwinfunc for more
380     /// information.
381     #[cfg(feature = "window")]
create_window_function<A, W, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T>, T: ToSql,382     pub fn create_window_function<A, W, T>(
383         &self,
384         fn_name: &str,
385         n_arg: c_int,
386         flags: FunctionFlags,
387         aggr: W,
388     ) -> Result<()>
389     where
390         A: RefUnwindSafe + UnwindSafe,
391         W: WindowAggregate<A, T>,
392         T: ToSql,
393     {
394         self.db
395             .borrow_mut()
396             .create_window_function(fn_name, n_arg, flags, aggr)
397     }
398 
399     /// `feature = "functions"` Removes a user-defined function from this
400     /// database connection.
401     ///
402     /// `fn_name` and `n_arg` should match the name and number of arguments
403     /// given to `create_scalar_function` or `create_aggregate_function`.
404     ///
405     /// # Failure
406     ///
407     /// Will return Err if the function could not be removed.
remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()>408     pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
409         self.db.borrow_mut().remove_function(fn_name, n_arg)
410     }
411 }
412 
413 impl InnerConnection {
create_scalar_function<F, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,414     fn create_scalar_function<F, T>(
415         &mut self,
416         fn_name: &str,
417         n_arg: c_int,
418         flags: FunctionFlags,
419         x_func: F,
420     ) -> Result<()>
421     where
422         F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
423         T: ToSql,
424     {
425         unsafe extern "C" fn call_boxed_closure<F, T>(
426             ctx: *mut sqlite3_context,
427             argc: c_int,
428             argv: *mut *mut sqlite3_value,
429         ) where
430             F: FnMut(&Context<'_>) -> Result<T>,
431             T: ToSql,
432         {
433             let r = catch_unwind(|| {
434                 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F;
435                 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
436                 let ctx = Context {
437                     ctx,
438                     args: slice::from_raw_parts(argv, argc as usize),
439                 };
440                 (*boxed_f)(&ctx)
441             });
442             let t = match r {
443                 Err(_) => {
444                     report_error(ctx, &Error::UnwindingPanic);
445                     return;
446                 }
447                 Ok(r) => r,
448             };
449             let t = t.as_ref().map(|t| ToSql::to_sql(t));
450 
451             match t {
452                 Ok(Ok(ref value)) => set_result(ctx, value),
453                 Ok(Err(err)) => report_error(ctx, &err),
454                 Err(err) => report_error(ctx, err),
455             }
456         }
457 
458         let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
459         let c_name = str_to_cstring(fn_name)?;
460         let r = unsafe {
461             ffi::sqlite3_create_function_v2(
462                 self.db(),
463                 c_name.as_ptr(),
464                 n_arg,
465                 flags.bits(),
466                 boxed_f as *mut c_void,
467                 Some(call_boxed_closure::<F, T>),
468                 None,
469                 None,
470                 Some(free_boxed_value::<F>),
471             )
472         };
473         self.decode_result(r)
474     }
475 
create_aggregate_function<A, D, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,476     fn create_aggregate_function<A, D, T>(
477         &mut self,
478         fn_name: &str,
479         n_arg: c_int,
480         flags: FunctionFlags,
481         aggr: D,
482     ) -> Result<()>
483     where
484         A: RefUnwindSafe + UnwindSafe,
485         D: Aggregate<A, T>,
486         T: ToSql,
487     {
488         let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
489         let c_name = str_to_cstring(fn_name)?;
490         let r = unsafe {
491             ffi::sqlite3_create_function_v2(
492                 self.db(),
493                 c_name.as_ptr(),
494                 n_arg,
495                 flags.bits(),
496                 boxed_aggr as *mut c_void,
497                 None,
498                 Some(call_boxed_step::<A, D, T>),
499                 Some(call_boxed_final::<A, D, T>),
500                 Some(free_boxed_value::<D>),
501             )
502         };
503         self.decode_result(r)
504     }
505 
506     #[cfg(feature = "window")]
create_window_function<A, W, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T>, T: ToSql,507     fn create_window_function<A, W, T>(
508         &mut self,
509         fn_name: &str,
510         n_arg: c_int,
511         flags: FunctionFlags,
512         aggr: W,
513     ) -> Result<()>
514     where
515         A: RefUnwindSafe + UnwindSafe,
516         W: WindowAggregate<A, T>,
517         T: ToSql,
518     {
519         let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
520         let c_name = str_to_cstring(fn_name)?;
521         let r = unsafe {
522             ffi::sqlite3_create_window_function(
523                 self.db(),
524                 c_name.as_ptr(),
525                 n_arg,
526                 flags.bits(),
527                 boxed_aggr as *mut c_void,
528                 Some(call_boxed_step::<A, W, T>),
529                 Some(call_boxed_final::<A, W, T>),
530                 Some(call_boxed_value::<A, W, T>),
531                 Some(call_boxed_inverse::<A, W, T>),
532                 Some(free_boxed_value::<W>),
533             )
534         };
535         self.decode_result(r)
536     }
537 
remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()>538     fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
539         let c_name = str_to_cstring(fn_name)?;
540         let r = unsafe {
541             ffi::sqlite3_create_function_v2(
542                 self.db(),
543                 c_name.as_ptr(),
544                 n_arg,
545                 ffi::SQLITE_UTF8,
546                 ptr::null_mut(),
547                 None,
548                 None,
549                 None,
550                 None,
551             )
552         };
553         self.decode_result(r)
554     }
555 }
556 
aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A>557 unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
558     let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
559     if pac.is_null() {
560         return None;
561     }
562     Some(pac)
563 }
564 
call_boxed_step<A, D, T>( ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value, ) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,565 unsafe extern "C" fn call_boxed_step<A, D, T>(
566     ctx: *mut sqlite3_context,
567     argc: c_int,
568     argv: *mut *mut sqlite3_value,
569 ) where
570     A: RefUnwindSafe + UnwindSafe,
571     D: Aggregate<A, T>,
572     T: ToSql,
573 {
574     let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
575         Some(pac) => pac,
576         None => {
577             ffi::sqlite3_result_error_nomem(ctx);
578             return;
579         }
580     };
581 
582     let r = catch_unwind(|| {
583         let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
584         assert!(
585             !boxed_aggr.is_null(),
586             "Internal error - null aggregate pointer"
587         );
588         if (*pac as *mut A).is_null() {
589             *pac = Box::into_raw(Box::new((*boxed_aggr).init()));
590         }
591         let mut ctx = Context {
592             ctx,
593             args: slice::from_raw_parts(argv, argc as usize),
594         };
595         (*boxed_aggr).step(&mut ctx, &mut **pac)
596     });
597     let r = match r {
598         Err(_) => {
599             report_error(ctx, &Error::UnwindingPanic);
600             return;
601         }
602         Ok(r) => r,
603     };
604     match r {
605         Ok(_) => {}
606         Err(err) => report_error(ctx, &err),
607     };
608 }
609 
/// xInverse trampoline for window functions. It forwards the row that is
/// leaving the window to `WindowAggregate::inverse`, using the accumulator
/// stored by `call_boxed_step`. A panic is reported as
/// `Error::UnwindingPanic`, and an error from `inverse` through `report_error`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let slot: *mut *mut A = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
        Some(slot) => slot,
        None => {
            ffi::sqlite3_result_error_nomem(ctx);
            return;
        }
    };

    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        // NOTE(review): assumes SQLite only calls xInverse after xStep has
        // populated the slot. No null check is made here.
        (*aggr).inverse(&mut fn_ctx, &mut **slot)
    });

    match outcome {
        Ok(Ok(_)) => {}
        Ok(Err(err)) => report_error(ctx, &err),
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
    }
}
652 
call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,653 unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
654 where
655     A: RefUnwindSafe + UnwindSafe,
656     D: Aggregate<A, T>,
657     T: ToSql,
658 {
659     // Within the xFinal callback, it is customary to set N=0 in calls to
660     // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
661     let a: Option<A> = match aggregate_context(ctx, 0) {
662         Some(pac) => {
663             if (*pac as *mut A).is_null() {
664                 None
665             } else {
666                 let a = Box::from_raw(*pac);
667                 Some(*a)
668             }
669         }
670         None => None,
671     };
672 
673     let r = catch_unwind(|| {
674         let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
675         assert!(
676             !boxed_aggr.is_null(),
677             "Internal error - null aggregate pointer"
678         );
679         (*boxed_aggr).finalize(a)
680     });
681     let t = match r {
682         Err(_) => {
683             report_error(ctx, &Error::UnwindingPanic);
684             return;
685         }
686         Ok(r) => r,
687     };
688     let t = t.as_ref().map(|t| ToSql::to_sql(t));
689     match t {
690         Ok(Ok(ref value)) => set_result(ctx, value),
691         Ok(Err(err)) => report_error(ctx, &err),
692         Err(err) => report_error(ctx, err),
693     }
694 }
695 
/// xValue trampoline for window functions.
///
/// It lends the current accumulator (if any) to `WindowAggregate::value`
/// without taking ownership of it, and stores the converted result (or
/// error) in `ctx`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // Within the xValue callback, it is customary to set N=0 in calls to
    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
    let acc: Option<&A> = match aggregate_context::<A>(ctx, 0) {
        Some(slot) if !(*slot).is_null() => Some(&**slot),
        _ => None,
    };

    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        (*aggr).value(acc)
    });

    let result = match outcome {
        Ok(result) => result,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };

    match result {
        Ok(ref value) => match ToSql::to_sql(value) {
            Ok(ref output) => set_result(ctx, output),
            Err(err) => report_error(ctx, &err),
        },
        Err(ref err) => report_error(ctx, err),
    }
}
739 
740 #[cfg(test)]
741 mod test {
742     use regex::Regex;
743     use std::f64::EPSILON;
744     use std::os::raw::c_double;
745 
746     #[cfg(feature = "window")]
747     use crate::functions::WindowAggregate;
748     use crate::functions::{Aggregate, Context, FunctionFlags};
749     use crate::{Connection, Error, Result, NO_PARAMS};
750 
half(ctx: &Context<'_>) -> Result<c_double>751     fn half(ctx: &Context<'_>) -> Result<c_double> {
752         assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
753         let value = ctx.get::<c_double>(0)?;
754         Ok(value / 2f64)
755     }
756 
757     #[test]
test_function_half()758     fn test_function_half() {
759         let db = Connection::open_in_memory().unwrap();
760         db.create_scalar_function(
761             "half",
762             1,
763             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
764             half,
765         )
766         .unwrap();
767         let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
768 
769         assert!((3f64 - result.unwrap()).abs() < EPSILON);
770     }
771 
772     #[test]
test_remove_function()773     fn test_remove_function() {
774         let db = Connection::open_in_memory().unwrap();
775         db.create_scalar_function(
776             "half",
777             1,
778             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
779             half,
780         )
781         .unwrap();
782         let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
783         assert!((3f64 - result.unwrap()).abs() < EPSILON);
784 
785         db.remove_function("half", 1).unwrap();
786         let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
787         assert!(result.is_err());
788     }
789 
790     // This implementation of a regexp scalar function uses SQLite's auxilliary data
791     // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
792     // expression multiple times within one query.
regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool>793     fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
794         assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
795         type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
796         let regexp: std::sync::Arc<Regex> = ctx
797             .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
798                 Ok(Regex::new(vr.as_str()?)?)
799             })?;
800 
801         let is_match = {
802             let text = ctx
803                 .get_raw(1)
804                 .as_str()
805                 .map_err(|e| Error::UserFunctionError(e.into()))?;
806 
807             regexp.is_match(text)
808         };
809 
810         Ok(is_match)
811     }
812 
813     #[test]
test_function_regexp_with_auxilliary()814     fn test_function_regexp_with_auxilliary() {
815         let db = Connection::open_in_memory().unwrap();
816         db.execute_batch(
817             "BEGIN;
818              CREATE TABLE foo (x string);
819              INSERT INTO foo VALUES ('lisa');
820              INSERT INTO foo VALUES ('lXsi');
821              INSERT INTO foo VALUES ('lisX');
822              END;",
823         )
824         .unwrap();
825         db.create_scalar_function(
826             "regexp",
827             2,
828             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
829             regexp_with_auxilliary,
830         )
831         .unwrap();
832 
833         let result: Result<bool> =
834             db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| {
835                 r.get(0)
836             });
837 
838         assert_eq!(true, result.unwrap());
839 
840         let result: Result<i64> = db.query_row(
841             "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
842             NO_PARAMS,
843             |r| r.get(0),
844         );
845 
846         assert_eq!(2, result.unwrap());
847     }
848 
849     #[test]
test_varargs_function()850     fn test_varargs_function() {
851         let db = Connection::open_in_memory().unwrap();
852         db.create_scalar_function(
853             "my_concat",
854             -1,
855             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
856             |ctx| {
857                 let mut ret = String::new();
858 
859                 for idx in 0..ctx.len() {
860                     let s = ctx.get::<String>(idx)?;
861                     ret.push_str(&s);
862                 }
863 
864                 Ok(ret)
865             },
866         )
867         .unwrap();
868 
869         for &(expected, query) in &[
870             ("", "SELECT my_concat()"),
871             ("onetwo", "SELECT my_concat('one', 'two')"),
872             ("abc", "SELECT my_concat('a', 'b', 'c')"),
873         ] {
874             let result: String = db.query_row(query, NO_PARAMS, |r| r.get(0)).unwrap();
875             assert_eq!(expected, result);
876         }
877     }
878 
879     #[test]
test_get_aux_type_checking()880     fn test_get_aux_type_checking() {
881         let db = Connection::open_in_memory().unwrap();
882         db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
883             if !ctx.get::<bool>(1)? {
884                 ctx.set_aux::<i64>(0, 100)?;
885             } else {
886                 assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
887                 assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100);
888             }
889             Ok(true)
890         })
891         .unwrap();
892 
893         let res: bool = db
894             .query_row(
895                 "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
896                 NO_PARAMS,
897                 |r| r.get(0),
898             )
899             .unwrap();
900         // Doesn't actually matter, we'll assert in the function if there's a problem.
901         assert!(res);
902     }
903 
    /// Test aggregate (also used as a window function) that adds up its single
    /// integer argument; `finalize` passes the accumulator through, so an
    /// aggregate over no rows yields NULL.
    struct Sum;
    /// Test aggregate that counts the rows it is stepped over, ignoring the
    /// argument values; an aggregate over no rows yields 0.
    struct Count;
906 
907     impl Aggregate<i64, Option<i64>> for Sum {
init(&self) -> i64908         fn init(&self) -> i64 {
909             0
910         }
911 
step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()>912         fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
913             *sum += ctx.get::<i64>(0)?;
914             Ok(())
915         }
916 
finalize(&self, sum: Option<i64>) -> Result<Option<i64>>917         fn finalize(&self, sum: Option<i64>) -> Result<Option<i64>> {
918             Ok(sum)
919         }
920     }
921 
922     impl Aggregate<i64, i64> for Count {
init(&self) -> i64923         fn init(&self) -> i64 {
924             0
925         }
926 
step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()>927         fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
928             *sum += 1;
929             Ok(())
930         }
931 
finalize(&self, sum: Option<i64>) -> Result<i64>932         fn finalize(&self, sum: Option<i64>) -> Result<i64> {
933             Ok(sum.unwrap_or(0))
934         }
935     }
936 
937     #[test]
test_sum()938     fn test_sum() {
939         let db = Connection::open_in_memory().unwrap();
940         db.create_aggregate_function(
941             "my_sum",
942             1,
943             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
944             Sum,
945         )
946         .unwrap();
947 
948         // sum should return NULL when given no columns (contrast with count below)
949         let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
950         let result: Option<i64> = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
951         assert!(result.is_none());
952 
953         let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
954         let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
955         assert_eq!(4, result);
956 
957         let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
958                         2, 1)";
959         let result: (i64, i64) = db
960             .query_row(dual_sum, NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?)))
961             .unwrap();
962         assert_eq!((4, 2), result);
963     }
964 
965     #[test]
test_count()966     fn test_count() {
967         let db = Connection::open_in_memory().unwrap();
968         db.create_aggregate_function(
969             "my_count",
970             -1,
971             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
972             Count,
973         )
974         .unwrap();
975 
976         // count should return 0 when given no columns (contrast with sum above)
977         let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
978         let result: i64 = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
979         assert_eq!(result, 0);
980 
981         let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
982         let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
983         assert_eq!(2, result);
984     }
985 
986     #[cfg(feature = "window")]
987     impl WindowAggregate<i64, Option<i64>> for Sum {
inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()>988         fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
989             *sum -= ctx.get::<i64>(0)?;
990             Ok(())
991         }
992 
value(&self, sum: Option<&i64>) -> Result<Option<i64>>993         fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
994             Ok(sum.copied())
995         }
996     }
997 
998     #[test]
999     #[cfg(feature = "window")]
test_window()1000     fn test_window() {
1001         use fallible_iterator::FallibleIterator;
1002 
1003         let db = Connection::open_in_memory().unwrap();
1004         db.create_window_function(
1005             "sumint",
1006             1,
1007             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
1008             Sum,
1009         )
1010         .unwrap();
1011         db.execute_batch(
1012             "CREATE TABLE t3(x, y);
1013              INSERT INTO t3 VALUES('a', 4),
1014                      ('b', 5),
1015                      ('c', 3),
1016                      ('d', 8),
1017                      ('e', 1);",
1018         )
1019         .unwrap();
1020 
1021         let mut stmt = db
1022             .prepare(
1023                 "SELECT x, sumint(y) OVER (
1024                    ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
1025                  ) AS sum_y
1026                  FROM t3 ORDER BY x;",
1027             )
1028             .unwrap();
1029 
1030         let results: Vec<(String, i64)> = stmt
1031             .query(NO_PARAMS)
1032             .unwrap()
1033             .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
1034             .collect()
1035             .unwrap();
1036         let expected = vec![
1037             ("a".to_owned(), 9),
1038             ("b".to_owned(), 12),
1039             ("c".to_owned(), 16),
1040             ("d".to_owned(), 12),
1041             ("e".to_owned(), 9),
1042         ];
1043         assert_eq!(expected, results);
1044     }
1045 }
1046