1 //! `feature = "functions"` Create or redefine SQL functions.
2 //!
3 //! # Example
4 //!
//! Adding a `regexp` function to a connection in which compiled regular
//! expressions are cached with SQLite's
//! [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface, to avoid recompiling the regular expression on every call.
//! The unit tests for this module contain a similar implementation.
10 //!
11 //! ```rust
12 //! use regex::Regex;
13 //! use rusqlite::functions::FunctionFlags;
14 //! use rusqlite::{Connection, Error, Result, NO_PARAMS};
15 //! use std::sync::Arc;
16 //! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17 //!
18 //! fn add_regexp_function(db: &Connection) -> Result<()> {
19 //! db.create_scalar_function(
20 //! "regexp",
21 //! 2,
22 //! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23 //! move |ctx| {
24 //! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25 //! let regexp: Arc<Regex> = ctx
26 //! .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
27 //! Ok(Regex::new(vr.as_str()?)?)
28 //! })?;
29 //! let is_match = {
30 //! let text = ctx
31 //! .get_raw(1)
32 //! .as_str()
33 //! .map_err(|e| Error::UserFunctionError(e.into()))?;
34 //!
35 //! regexp.is_match(text)
36 //! };
37 //!
38 //! Ok(is_match)
39 //! },
40 //! )
41 //! }
42 //!
43 //! fn main() -> Result<()> {
44 //! let db = Connection::open_in_memory()?;
45 //! add_regexp_function(&db)?;
46 //!
47 //! let is_match: bool = db.query_row(
48 //! "SELECT regexp('[aeiou]*', 'aaaaeeeiii')",
49 //! NO_PARAMS,
50 //! |row| row.get(0),
51 //! )?;
52 //!
53 //! assert!(is_match);
54 //! Ok(())
55 //! }
56 //! ```
57 use std::any::Any;
58 use std::os::raw::{c_int, c_void};
59 use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60 use std::ptr;
61 use std::slice;
62 use std::sync::Arc;
63
64 use crate::ffi;
65 use crate::ffi::sqlite3_context;
66 use crate::ffi::sqlite3_value;
67
68 use crate::context::set_result;
69 use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
70
71 use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72
report_error(ctx: *mut sqlite3_context, err: &Error)73 unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74 // Extended constraint error codes were added in SQLite 3.7.16. We don't have
75 // an explicit feature check for that, and this doesn't really warrant one.
76 // We'll use the extended code if we're on the bundled version (since it's
77 // at least 3.17.0) and the normal constraint error code if not.
78 #[cfg(feature = "modern_sqlite")]
79 fn constraint_error_code() -> i32 {
80 ffi::SQLITE_CONSTRAINT_FUNCTION
81 }
82 #[cfg(not(feature = "modern_sqlite"))]
83 fn constraint_error_code() -> i32 {
84 ffi::SQLITE_CONSTRAINT
85 }
86
87 match *err {
88 Error::SqliteFailure(ref err, ref s) => {
89 ffi::sqlite3_result_error_code(ctx, err.extended_code);
90 if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
91 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
92 }
93 }
94 _ => {
95 ffi::sqlite3_result_error_code(ctx, constraint_error_code());
96 if let Ok(cstr) = str_to_cstring(&err.to_string()) {
97 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
98 }
99 }
100 }
101 }
102
free_boxed_value<T>(p: *mut c_void)103 unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
104 drop(Box::from_raw(p as *mut T));
105 }
106
/// `feature = "functions"` Context is a wrapper for the SQLite function
/// evaluation context.
///
/// A `Context` is built by this module's callback trampolines for one call
/// into a user-defined function. It gives access to that call's arguments
/// and to SQLite's per-argument auxiliary data.
pub struct Context<'a> {
    /// Raw SQLite context of the function call in progress.
    ctx: *mut sqlite3_context,
    /// Raw argument values for this call (`argv[0..argc]` as passed by SQLite).
    args: &'a [*mut sqlite3_value],
}
113
114 impl Context<'_> {
115 /// Returns the number of arguments to the function.
len(&self) -> usize116 pub fn len(&self) -> usize {
117 self.args.len()
118 }
119
120 /// Returns `true` when there is no argument.
is_empty(&self) -> bool121 pub fn is_empty(&self) -> bool {
122 self.args.is_empty()
123 }
124
125 /// Returns the `idx`th argument as a `T`.
126 ///
127 /// # Failure
128 ///
129 /// Will panic if `idx` is greater than or equal to `self.len()`.
130 ///
131 /// Will return Err if the underlying SQLite type cannot be converted to a
132 /// `T`.
get<T: FromSql>(&self, idx: usize) -> Result<T>133 pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
134 let arg = self.args[idx];
135 let value = unsafe { ValueRef::from_value(arg) };
136 FromSql::column_result(value).map_err(|err| match err {
137 FromSqlError::InvalidType => {
138 Error::InvalidFunctionParameterType(idx, value.data_type())
139 }
140 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
141 FromSqlError::Other(err) => {
142 Error::FromSqlConversionFailure(idx, value.data_type(), err)
143 }
144 #[cfg(feature = "i128_blob")]
145 FromSqlError::InvalidI128Size(_) => {
146 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
147 }
148 #[cfg(feature = "uuid")]
149 FromSqlError::InvalidUuidSize(_) => {
150 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
151 }
152 })
153 }
154
155 /// Returns the `idx`th argument as a `ValueRef`.
156 ///
157 /// # Failure
158 ///
159 /// Will panic if `idx` is greater than or equal to `self.len()`.
get_raw(&self, idx: usize) -> ValueRef<'_>160 pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
161 let arg = self.args[idx];
162 unsafe { ValueRef::from_value(arg) }
163 }
164
165 /// Fetch or insert the the auxilliary data associated with a particular
166 /// parameter. This is intended to be an easier-to-use way of fetching it
167 /// compared to calling `get_aux` and `set_aux` separately.
168 ///
169 /// See https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
170 /// this feature, or the unit tests of this module for an example.
get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>> where T: Send + Sync + 'static, E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>, F: FnOnce(ValueRef<'_>) -> Result<T, E>,171 pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
172 where
173 T: Send + Sync + 'static,
174 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
175 F: FnOnce(ValueRef<'_>) -> Result<T, E>,
176 {
177 if let Some(v) = self.get_aux(arg)? {
178 Ok(v)
179 } else {
180 let vr = self.get_raw(arg as usize);
181 self.set_aux(
182 arg,
183 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
184 )
185 }
186 }
187
188 /// Sets the auxilliary data associated with a particular parameter. See
189 /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
190 /// this feature, or the unit tests of this module for an example.
set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>>191 pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
192 let orig: Arc<T> = Arc::new(value);
193 let inner: AuxInner = orig.clone();
194 let outer = Box::new(inner);
195 let raw: *mut AuxInner = Box::into_raw(outer);
196 unsafe {
197 ffi::sqlite3_set_auxdata(
198 self.ctx,
199 arg,
200 raw as *mut _,
201 Some(free_boxed_value::<AuxInner>),
202 )
203 };
204 Ok(orig)
205 }
206
207 /// Gets the auxilliary data that was associated with a given parameter via
208 /// `set_aux`. Returns `Ok(None)` if no data has been associated, and
209 /// Ok(Some(v)) if it has. Returns an error if the requested type does not
210 /// match.
get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>>211 pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
212 let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
213 if p.is_null() {
214 Ok(None)
215 } else {
216 let v: AuxInner = AuxInner::clone(unsafe { &*p });
217 v.downcast::<T>()
218 .map(Some)
219 .map_err(|_| Error::GetAuxWrongType)
220 }
221 }
222 }
223
/// Type-erased, shareable payload. `Context::set_aux` stores it and
/// `Context::get_aux` downcasts it back to the concrete type.
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
225
226 /// `feature = "functions"` Aggregate is the callback interface for user-defined
227 /// aggregate function.
228 ///
229 /// `A` is the type of the aggregation context and `T` is the type of the final
230 /// result. Implementations should be stateless.
231 pub trait Aggregate<A, T>
232 where
233 A: RefUnwindSafe + UnwindSafe,
234 T: ToSql,
235 {
236 /// Initializes the aggregation context. Will be called prior to the first
237 /// call to `step()` to set up the context for an invocation of the
238 /// function. (Note: `init()` will not be called if there are no rows.)
init(&self) -> A239 fn init(&self) -> A;
240
241 /// "step" function called once for each row in an aggregate group. May be
242 /// called 0 times if there are no rows.
step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>243 fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
244
245 /// Computes and returns the final result. Will be called exactly once for
246 /// each invocation of the function. If `step()` was called at least
247 /// once, will be given `Some(A)` (the same `A` as was created by
248 /// `init` and given to `step`); if `step()` was not called (because
249 /// the function is running against 0 rows), will be given `None`.
finalize(&self, _: Option<A>) -> Result<T>250 fn finalize(&self, _: Option<A>) -> Result<T>;
251 }
252
/// `feature = "window"` Callback interface for a user-defined aggregate
/// window function. It extends [`Aggregate`] with the two operations a
/// sliding window frame needs.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Returns the aggregate's current value (SQLite's `xValue`). Unlike
    /// `finalize` (`xFinal`), it only borrows the accumulator, which must
    /// stay usable for later calls. It receives `None` when no accumulator
    /// exists yet.
    fn value(&self, acc: Option<&A>) -> Result<T>;

    /// Removes one row from the current window (SQLite's `xInverse`).
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
268
bitflags::bitflags! {
    /// Function Flags.
    /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
    /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
    ///
    /// Flags are combined with `|`. Their raw `bits()` are passed to SQLite
    /// unchanged when a function is registered.
    #[repr(C)]
    pub struct FunctionFlags: ::std::os::raw::c_int {
        /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF8 = ffi::SQLITE_UTF8;
        /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
        /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
        /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16 = ffi::SQLITE_UTF16;
        /// Means that the function always gives the same output when the input parameters are the same.
        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
        // The three flags below are literal values rather than `ffi::`
        // constants; each trailing comment names the SQLite release that
        // introduced the flag. NOTE(review): presumably so this compiles
        // against bindings that predate them — confirm.
        /// Means that the function may only be invoked from top-level SQL.
        const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0
        /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the sub-types of its arguments.
        const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
        /// Means that the function is unlikely to cause problems even if misused.
        const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
    }
}
293
294 impl Default for FunctionFlags {
default() -> FunctionFlags295 fn default() -> FunctionFlags {
296 FunctionFlags::SQLITE_UTF8
297 }
298 }
299
300 impl Connection {
301 /// `feature = "functions"` Attach a user-defined scalar function to
302 /// this database connection.
303 ///
304 /// `fn_name` is the name the function will be accessible from SQL.
305 /// `n_arg` is the number of arguments to the function. Use `-1` for a
306 /// variable number. If the function always returns the same value
307 /// given the same input, `deterministic` should be `true`.
308 ///
309 /// The function will remain available until the connection is closed or
310 /// until it is explicitly removed via `remove_function`.
311 ///
312 /// # Example
313 ///
314 /// ```rust
315 /// # use rusqlite::{Connection, Result, NO_PARAMS};
316 /// # use rusqlite::functions::FunctionFlags;
317 /// fn scalar_function_example(db: Connection) -> Result<()> {
318 /// db.create_scalar_function(
319 /// "halve",
320 /// 1,
321 /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
322 /// |ctx| {
323 /// let value = ctx.get::<f64>(0)?;
324 /// Ok(value / 2f64)
325 /// },
326 /// )?;
327 ///
328 /// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?;
329 /// assert_eq!(six_halved, 3f64);
330 /// Ok(())
331 /// }
332 /// ```
333 ///
334 /// # Failure
335 ///
336 /// Will return Err if the function could not be attached to the connection.
create_scalar_function<F, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,337 pub fn create_scalar_function<F, T>(
338 &self,
339 fn_name: &str,
340 n_arg: c_int,
341 flags: FunctionFlags,
342 x_func: F,
343 ) -> Result<()>
344 where
345 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
346 T: ToSql,
347 {
348 self.db
349 .borrow_mut()
350 .create_scalar_function(fn_name, n_arg, flags, x_func)
351 }
352
353 /// `feature = "functions"` Attach a user-defined aggregate function to this
354 /// database connection.
355 ///
356 /// # Failure
357 ///
358 /// Will return Err if the function could not be attached to the connection.
create_aggregate_function<A, D, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,359 pub fn create_aggregate_function<A, D, T>(
360 &self,
361 fn_name: &str,
362 n_arg: c_int,
363 flags: FunctionFlags,
364 aggr: D,
365 ) -> Result<()>
366 where
367 A: RefUnwindSafe + UnwindSafe,
368 D: Aggregate<A, T>,
369 T: ToSql,
370 {
371 self.db
372 .borrow_mut()
373 .create_aggregate_function(fn_name, n_arg, flags, aggr)
374 }
375
376 /// `feature = "window"` Attach a user-defined aggregate window function to
377 /// this database connection.
378 ///
379 /// See https://sqlite.org/windowfunctions.html#udfwinfunc for more
380 /// information.
381 #[cfg(feature = "window")]
create_window_function<A, W, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T>, T: ToSql,382 pub fn create_window_function<A, W, T>(
383 &self,
384 fn_name: &str,
385 n_arg: c_int,
386 flags: FunctionFlags,
387 aggr: W,
388 ) -> Result<()>
389 where
390 A: RefUnwindSafe + UnwindSafe,
391 W: WindowAggregate<A, T>,
392 T: ToSql,
393 {
394 self.db
395 .borrow_mut()
396 .create_window_function(fn_name, n_arg, flags, aggr)
397 }
398
399 /// `feature = "functions"` Removes a user-defined function from this
400 /// database connection.
401 ///
402 /// `fn_name` and `n_arg` should match the name and number of arguments
403 /// given to `create_scalar_function` or `create_aggregate_function`.
404 ///
405 /// # Failure
406 ///
407 /// Will return Err if the function could not be removed.
remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()>408 pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
409 self.db.borrow_mut().remove_function(fn_name, n_arg)
410 }
411 }
412
413 impl InnerConnection {
create_scalar_function<F, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,414 fn create_scalar_function<F, T>(
415 &mut self,
416 fn_name: &str,
417 n_arg: c_int,
418 flags: FunctionFlags,
419 x_func: F,
420 ) -> Result<()>
421 where
422 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
423 T: ToSql,
424 {
425 unsafe extern "C" fn call_boxed_closure<F, T>(
426 ctx: *mut sqlite3_context,
427 argc: c_int,
428 argv: *mut *mut sqlite3_value,
429 ) where
430 F: FnMut(&Context<'_>) -> Result<T>,
431 T: ToSql,
432 {
433 let r = catch_unwind(|| {
434 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F;
435 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
436 let ctx = Context {
437 ctx,
438 args: slice::from_raw_parts(argv, argc as usize),
439 };
440 (*boxed_f)(&ctx)
441 });
442 let t = match r {
443 Err(_) => {
444 report_error(ctx, &Error::UnwindingPanic);
445 return;
446 }
447 Ok(r) => r,
448 };
449 let t = t.as_ref().map(|t| ToSql::to_sql(t));
450
451 match t {
452 Ok(Ok(ref value)) => set_result(ctx, value),
453 Ok(Err(err)) => report_error(ctx, &err),
454 Err(err) => report_error(ctx, err),
455 }
456 }
457
458 let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
459 let c_name = str_to_cstring(fn_name)?;
460 let r = unsafe {
461 ffi::sqlite3_create_function_v2(
462 self.db(),
463 c_name.as_ptr(),
464 n_arg,
465 flags.bits(),
466 boxed_f as *mut c_void,
467 Some(call_boxed_closure::<F, T>),
468 None,
469 None,
470 Some(free_boxed_value::<F>),
471 )
472 };
473 self.decode_result(r)
474 }
475
create_aggregate_function<A, D, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,476 fn create_aggregate_function<A, D, T>(
477 &mut self,
478 fn_name: &str,
479 n_arg: c_int,
480 flags: FunctionFlags,
481 aggr: D,
482 ) -> Result<()>
483 where
484 A: RefUnwindSafe + UnwindSafe,
485 D: Aggregate<A, T>,
486 T: ToSql,
487 {
488 let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
489 let c_name = str_to_cstring(fn_name)?;
490 let r = unsafe {
491 ffi::sqlite3_create_function_v2(
492 self.db(),
493 c_name.as_ptr(),
494 n_arg,
495 flags.bits(),
496 boxed_aggr as *mut c_void,
497 None,
498 Some(call_boxed_step::<A, D, T>),
499 Some(call_boxed_final::<A, D, T>),
500 Some(free_boxed_value::<D>),
501 )
502 };
503 self.decode_result(r)
504 }
505
506 #[cfg(feature = "window")]
create_window_function<A, W, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T>, T: ToSql,507 fn create_window_function<A, W, T>(
508 &mut self,
509 fn_name: &str,
510 n_arg: c_int,
511 flags: FunctionFlags,
512 aggr: W,
513 ) -> Result<()>
514 where
515 A: RefUnwindSafe + UnwindSafe,
516 W: WindowAggregate<A, T>,
517 T: ToSql,
518 {
519 let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
520 let c_name = str_to_cstring(fn_name)?;
521 let r = unsafe {
522 ffi::sqlite3_create_window_function(
523 self.db(),
524 c_name.as_ptr(),
525 n_arg,
526 flags.bits(),
527 boxed_aggr as *mut c_void,
528 Some(call_boxed_step::<A, W, T>),
529 Some(call_boxed_final::<A, W, T>),
530 Some(call_boxed_value::<A, W, T>),
531 Some(call_boxed_inverse::<A, W, T>),
532 Some(free_boxed_value::<W>),
533 )
534 };
535 self.decode_result(r)
536 }
537
remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()>538 fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
539 let c_name = str_to_cstring(fn_name)?;
540 let r = unsafe {
541 ffi::sqlite3_create_function_v2(
542 self.db(),
543 c_name.as_ptr(),
544 n_arg,
545 ffi::SQLITE_UTF8,
546 ptr::null_mut(),
547 None,
548 None,
549 None,
550 None,
551 )
552 };
553 self.decode_result(r)
554 }
555 }
556
aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A>557 unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
558 let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
559 if pac.is_null() {
560 return None;
561 }
562 Some(pac)
563 }
564
call_boxed_step<A, D, T>( ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value, ) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,565 unsafe extern "C" fn call_boxed_step<A, D, T>(
566 ctx: *mut sqlite3_context,
567 argc: c_int,
568 argv: *mut *mut sqlite3_value,
569 ) where
570 A: RefUnwindSafe + UnwindSafe,
571 D: Aggregate<A, T>,
572 T: ToSql,
573 {
574 let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
575 Some(pac) => pac,
576 None => {
577 ffi::sqlite3_result_error_nomem(ctx);
578 return;
579 }
580 };
581
582 let r = catch_unwind(|| {
583 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
584 assert!(
585 !boxed_aggr.is_null(),
586 "Internal error - null aggregate pointer"
587 );
588 if (*pac as *mut A).is_null() {
589 *pac = Box::into_raw(Box::new((*boxed_aggr).init()));
590 }
591 let mut ctx = Context {
592 ctx,
593 args: slice::from_raw_parts(argv, argc as usize),
594 };
595 (*boxed_aggr).step(&mut ctx, &mut **pac)
596 });
597 let r = match r {
598 Err(_) => {
599 report_error(ctx, &Error::UnwindingPanic);
600 return;
601 }
602 Ok(r) => r,
603 };
604 match r {
605 Ok(_) => {}
606 Err(err) => report_error(ctx, &err),
607 };
608 }
609
#[cfg(feature = "window")]
/// `xInverse` trampoline for window functions. It forwards the row that is
/// leaving the window frame to `WindowAggregate::inverse`, together with the
/// accumulator kept in SQLite's aggregate context. Panics and errors are
/// reported back to SQLite.
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let pac = if let Some(slot) = aggregate_context::<A>(ctx, ::std::mem::size_of::<*mut A>()) {
        slot
    } else {
        ffi::sqlite3_result_error_nomem(ctx);
        return;
    };

    let outcome = catch_unwind(|| {
        let boxed_aggr = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        // NOTE(review): unlike `call_boxed_step`, `*pac` is not null-checked
        // here. This assumes SQLite only calls xInverse after an xStep on the
        // same context has created the accumulator — confirm.
        (*boxed_aggr).inverse(&mut fn_ctx, &mut **pac)
    });

    match outcome {
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
        Ok(Err(err)) => report_error(ctx, &err),
        Ok(Ok(())) => {}
    }
}
652
call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,653 unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
654 where
655 A: RefUnwindSafe + UnwindSafe,
656 D: Aggregate<A, T>,
657 T: ToSql,
658 {
659 // Within the xFinal callback, it is customary to set N=0 in calls to
660 // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
661 let a: Option<A> = match aggregate_context(ctx, 0) {
662 Some(pac) => {
663 if (*pac as *mut A).is_null() {
664 None
665 } else {
666 let a = Box::from_raw(*pac);
667 Some(*a)
668 }
669 }
670 None => None,
671 };
672
673 let r = catch_unwind(|| {
674 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
675 assert!(
676 !boxed_aggr.is_null(),
677 "Internal error - null aggregate pointer"
678 );
679 (*boxed_aggr).finalize(a)
680 });
681 let t = match r {
682 Err(_) => {
683 report_error(ctx, &Error::UnwindingPanic);
684 return;
685 }
686 Ok(r) => r,
687 };
688 let t = t.as_ref().map(|t| ToSql::to_sql(t));
689 match t {
690 Ok(Ok(ref value)) => set_result(ctx, value),
691 Ok(Err(err)) => report_error(ctx, &err),
692 Err(err) => report_error(ctx, err),
693 }
694 }
695
#[cfg(feature = "window")]
/// `xValue` trampoline for window functions. It passes a shared reference to
/// the current accumulator, if one exists, to `WindowAggregate::value` and
/// stores the returned value as the function result. The accumulator is only
/// borrowed, so later calls can keep using it.
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // Passing N=0 to sqlite3_aggregate_context is customary in xValue: it
    // avoids a pointless allocation when no context exists yet.
    let acc: Option<&A> = aggregate_context::<A>(ctx, 0)
        .map(|slot| *slot)
        .filter(|boxed| !boxed.is_null())
        .map(|boxed| &*boxed);

    let outcome = catch_unwind(|| {
        let boxed_aggr = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        (*boxed_aggr).value(acc)
    });

    let result = match outcome {
        Ok(result) => result,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };
    match result {
        Err(ref err) => report_error(ctx, err),
        Ok(ref value) => match ToSql::to_sql(value) {
            Ok(ref output) => set_result(ctx, output),
            Err(ref err) => report_error(ctx, err),
        },
    }
}
739
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::f64::EPSILON;
    use std::os::raw::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags};
    use crate::{Connection, Error, Result, NO_PARAMS};

    // Scalar test function: returns half of its single numeric argument.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    #[test]
    fn test_function_half() {
        let db = Connection::open_in_memory().unwrap();
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )
        .unwrap();
        let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));

        assert!((3f64 - result.unwrap()).abs() < EPSILON);
    }

    #[test]
    fn test_remove_function() {
        let db = Connection::open_in_memory().unwrap();
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )
        .unwrap();
        let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
        assert!((3f64 - result.unwrap()).abs() < EPSILON);

        db.remove_function("half", 1).unwrap();
        let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
        assert!(result.is_err());
    }

    // This implementation of a regexp scalar function uses SQLite's auxiliary data
    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
    // expression multiple times within one query.
    fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    #[test]
    fn test_function_regexp_with_auxilliary() {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )
        .unwrap();
        db.create_scalar_function(
            "regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxilliary,
        )
        .unwrap();

        let result: Result<bool> =
            db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| {
                r.get(0)
            });

        assert!(result.unwrap());

        let result: Result<i64> = db.query_row(
            "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
            NO_PARAMS,
            |r| r.get(0),
        );

        assert_eq!(2, result.unwrap());
    }

    #[test]
    fn test_varargs_function() {
        let db = Connection::open_in_memory().unwrap();
        db.create_scalar_function(
            "my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();

                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )
        .unwrap();

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            let result: String = db.query_row(query, NO_PARAMS, |r| r.get(0)).unwrap();
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn test_get_aux_type_checking() {
        let db = Connection::open_in_memory().unwrap();
        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100);
            }
            Ok(true)
        })
        .unwrap();

        let res: bool = db
            .query_row(
                "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
                NO_PARAMS,
                |r| r.get(0),
            )
            .unwrap();
        // Doesn't actually matter, we'll assert in the function if there's a problem.
        assert!(res);
    }

    struct Sum;
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self) -> i64 {
            0
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self) -> i64 {
            0
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    #[test]
    fn test_sum() {
        let db = Connection::open_in_memory().unwrap();
        db.create_aggregate_function(
            "my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )
        .unwrap();

        // sum should return NULL when given no columns (contrast with count below)
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: Option<i64> = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
        assert!(result.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
        assert_eq!(4, result);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db
            .query_row(dual_sum, NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?)))
            .unwrap();
        assert_eq!((4, 2), result);
    }

    #[test]
    fn test_count() {
        let db = Connection::open_in_memory().unwrap();
        db.create_aggregate_function(
            "my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )
        .unwrap();

        // count should return 0 when given no columns (contrast with sum above)
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: i64 = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
        assert_eq!(result, 0);

        let two_rows = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.query_row(two_rows, NO_PARAMS, |r| r.get(0)).unwrap();
        assert_eq!(2, result);
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    #[test]
    #[cfg(feature = "window")]
    fn test_window() {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory().unwrap();
        db.create_window_function(
            "sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )
        .unwrap();
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )
        .unwrap();

        let mut stmt = db
            .prepare(
                "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
            )
            .unwrap();

        let results: Vec<(String, i64)> = stmt
            .query(NO_PARAMS)
            .unwrap()
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()
            .unwrap();
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
    }
}
1046