1 //! `feature = "functions"` Create or redefine SQL functions.
2 //!
3 //! # Example
4 //!
//! Adding a `regexp` function to a connection. The compiled regular
//! expression is cached with SQLite's
//! [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (via `Context::get_or_create_aux`), so it is not recompiled for
//! every row of a query. The unit tests for this module contain a similar
//! implementation.
10 //!
11 //! ```rust
12 //! use regex::Regex;
13 //! use rusqlite::functions::FunctionFlags;
14 //! use rusqlite::{Connection, Error, Result, NO_PARAMS};
15 //! use std::sync::Arc;
16 //! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17 //!
18 //! fn add_regexp_function(db: &Connection) -> Result<()> {
19 //! db.create_scalar_function(
20 //! "regexp",
21 //! 2,
22 //! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23 //! move |ctx| {
24 //! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25 //! let regexp: Arc<Regex> = ctx
26 //! .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
27 //! Ok(Regex::new(vr.as_str()?)?)
28 //! })?;
29 //! let is_match = {
30 //! let text = ctx
31 //! .get_raw(1)
32 //! .as_str()
33 //! .map_err(|e| Error::UserFunctionError(e.into()))?;
34 //!
35 //! regexp.is_match(text)
36 //! };
37 //!
38 //! Ok(is_match)
39 //! },
40 //! )
41 //! }
42 //!
43 //! fn main() -> Result<()> {
44 //! let db = Connection::open_in_memory()?;
45 //! add_regexp_function(&db)?;
46 //!
47 //! let is_match: bool = db.query_row(
48 //! "SELECT regexp('[aeiou]*', 'aaaaeeeiii')",
49 //! NO_PARAMS,
50 //! |row| row.get(0),
51 //! )?;
52 //!
53 //! assert!(is_match);
54 //! Ok(())
55 //! }
56 //! ```
57 use std::any::Any;
58 use std::os::raw::{c_int, c_void};
59 use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60 use std::ptr;
61 use std::slice;
62 use std::sync::Arc;
63
64 use crate::ffi;
65 use crate::ffi::sqlite3_context;
66 use crate::ffi::sqlite3_value;
67
68 use crate::context::set_result;
69 use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
70
71 use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72
report_error(ctx: *mut sqlite3_context, err: &Error)73 unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74 // Extended constraint error codes were added in SQLite 3.7.16. We don't have
75 // an explicit feature check for that, and this doesn't really warrant one.
76 // We'll use the extended code if we're on the bundled version (since it's
77 // at least 3.17.0) and the normal constraint error code if not.
78 #[cfg(feature = "modern_sqlite")]
79 fn constraint_error_code() -> i32 {
80 ffi::SQLITE_CONSTRAINT_FUNCTION
81 }
82 #[cfg(not(feature = "modern_sqlite"))]
83 fn constraint_error_code() -> i32 {
84 ffi::SQLITE_CONSTRAINT
85 }
86
87 match *err {
88 Error::SqliteFailure(ref err, ref s) => {
89 ffi::sqlite3_result_error_code(ctx, err.extended_code);
90 if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
91 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
92 }
93 }
94 _ => {
95 ffi::sqlite3_result_error_code(ctx, constraint_error_code());
96 if let Ok(cstr) = str_to_cstring(&err.to_string()) {
97 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
98 }
99 }
100 }
101 }
102
/// Destructor handed to SQLite for values that were turned into raw pointers
/// with `Box::into_raw(Box::new(value))`.
///
/// It rebuilds the `Box<T>` and drops it, running `T`'s destructor and
/// releasing the allocation.
///
/// # Safety
///
/// `p` must originate from `Box::<T>::into_raw` and must not be used again
/// after this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p as *mut T);
    std::mem::drop(owned);
}
106
107 /// `feature = "functions"` Context is a wrapper for the SQLite function
108 /// evaluation context.
109 pub struct Context<'a> {
110 ctx: *mut sqlite3_context,
111 args: &'a [*mut sqlite3_value],
112 }
113
114 impl Context<'_> {
115 /// Returns the number of arguments to the function.
len(&self) -> usize116 pub fn len(&self) -> usize {
117 self.args.len()
118 }
119
120 /// Returns `true` when there is no argument.
is_empty(&self) -> bool121 pub fn is_empty(&self) -> bool {
122 self.args.is_empty()
123 }
124
125 /// Returns the `idx`th argument as a `T`.
126 ///
127 /// # Failure
128 ///
129 /// Will panic if `idx` is greater than or equal to `self.len()`.
130 ///
131 /// Will return Err if the underlying SQLite type cannot be converted to a
132 /// `T`.
get<T: FromSql>(&self, idx: usize) -> Result<T>133 pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
134 let arg = self.args[idx];
135 let value = unsafe { ValueRef::from_value(arg) };
136 FromSql::column_result(value).map_err(|err| match err {
137 FromSqlError::InvalidType => {
138 Error::InvalidFunctionParameterType(idx, value.data_type())
139 }
140 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
141 FromSqlError::Other(err) => {
142 Error::FromSqlConversionFailure(idx, value.data_type(), err)
143 }
144 #[cfg(feature = "i128_blob")]
145 FromSqlError::InvalidI128Size(_) => {
146 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
147 }
148 #[cfg(feature = "uuid")]
149 FromSqlError::InvalidUuidSize(_) => {
150 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
151 }
152 })
153 }
154
155 /// Returns the `idx`th argument as a `ValueRef`.
156 ///
157 /// # Failure
158 ///
159 /// Will panic if `idx` is greater than or equal to `self.len()`.
get_raw(&self, idx: usize) -> ValueRef<'_>160 pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
161 let arg = self.args[idx];
162 unsafe { ValueRef::from_value(arg) }
163 }
164
165 /// Fetch or insert the the auxilliary data associated with a particular
166 /// parameter. This is intended to be an easier-to-use way of fetching it
167 /// compared to calling `get_aux` and `set_aux` separately.
168 ///
169 /// See https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
170 /// this feature, or the unit tests of this module for an example.
get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>> where T: Send + Sync + 'static, E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>, F: FnOnce(ValueRef<'_>) -> Result<T, E>,171 pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
172 where
173 T: Send + Sync + 'static,
174 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
175 F: FnOnce(ValueRef<'_>) -> Result<T, E>,
176 {
177 if let Some(v) = self.get_aux(arg)? {
178 Ok(v)
179 } else {
180 let vr = self.get_raw(arg as usize);
181 self.set_aux(
182 arg,
183 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
184 )
185 }
186 }
187
188 /// Sets the auxilliary data associated with a particular parameter. See
189 /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
190 /// this feature, or the unit tests of this module for an example.
set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>>191 pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
192 let orig: Arc<T> = Arc::new(value);
193 let inner: AuxInner = orig.clone();
194 let outer = Box::new(inner);
195 let raw: *mut AuxInner = Box::into_raw(outer);
196 unsafe {
197 ffi::sqlite3_set_auxdata(
198 self.ctx,
199 arg,
200 raw as *mut _,
201 Some(free_boxed_value::<AuxInner>),
202 )
203 };
204 Ok(orig)
205 }
206
207 /// Gets the auxilliary data that was associated with a given parameter via
208 /// `set_aux`. Returns `Ok(None)` if no data has been associated, and
209 /// Ok(Some(v)) if it has. Returns an error if the requested type does not
210 /// match.
get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>>211 pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
212 let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
213 if p.is_null() {
214 Ok(None)
215 } else {
216 let v: AuxInner = AuxInner::clone(unsafe { &*p });
217 v.downcast::<T>()
218 .map(Some)
219 .map_err(|_| Error::GetAuxWrongType)
220 }
221 }
222 }
223
/// Type-erased, reference-counted auxiliary value. `Context::set_aux` hands one
/// of these (boxed) to SQLite, and `Context::get_aux` downcasts it back.
type AuxInner = Arc<dyn Any + Sync + Send + 'static>;
225
226 /// `feature = "functions"` Aggregate is the callback interface for user-defined
227 /// aggregate function.
228 ///
229 /// `A` is the type of the aggregation context and `T` is the type of the final
230 /// result. Implementations should be stateless.
231 pub trait Aggregate<A, T>
232 where
233 A: RefUnwindSafe + UnwindSafe,
234 T: ToSql,
235 {
236 /// Initializes the aggregation context. Will be called prior to the first
237 /// call to `step()` to set up the context for an invocation of the
238 /// function. (Note: `init()` will not be called if there are no rows.)
init(&self) -> A239 fn init(&self) -> A;
240
241 /// "step" function called once for each row in an aggregate group. May be
242 /// called 0 times if there are no rows.
step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>243 fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
244
245 /// Computes and returns the final result. Will be called exactly once for
246 /// each invocation of the function. If `step()` was called at least
247 /// once, will be given `Some(A)` (the same `A` as was created by
248 /// `init` and given to `step`); if `step()` was not called (because
249 /// the function is running against 0 rows), will be given `None`.
finalize(&self, _: Option<A>) -> Result<T>250 fn finalize(&self, _: Option<A>) -> Result<T>;
251 }
252
/// `feature = "window"` WindowAggregate is the callback interface for
/// user-defined aggregate window functions.
///
/// It extends `Aggregate` with the two extra callbacks SQLite needs to slide a
/// window over the rows.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Reports the aggregate's current value. Unlike `finalize` (xFinal), it
    /// only borrows the accumulator and must not tear down any context. It
    /// receives `None` when no accumulator has been created yet.
    fn value(&self, acc: Option<&A>) -> Result<T>;

    /// Removes one row from the current window, undoing a previous `step`.
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
268
bitflags::bitflags! {
    #[doc = "Function Flags."]
    #[doc = "See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html) for details."]
    #[repr(C)]
    pub struct FunctionFlags: ::std::os::raw::c_int {
        // Text-encoding flags (SQLite's `eTextRep` values), taken from `ffi`.
        // See the SQLite page linked above for how SQLite interprets them.
        const SQLITE_UTF8 = ffi::SQLITE_UTF8;
        const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
        const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
        const SQLITE_UTF16 = ffi::SQLITE_UTF16;
        // Marks the function as returning the same value for the same input.
        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
        // The flags below are written as literals rather than `ffi::` constants.
        // The trailing comments give the SQLite version that introduced each one.
        // NOTE(review): presumably literals because the `ffi` bindings in use
        // may predate these flags — confirm. See the linked SQLite docs for
        // their semantics.
        const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0
        const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
        const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
    }
}
284
285 impl Default for FunctionFlags {
default() -> FunctionFlags286 fn default() -> FunctionFlags {
287 FunctionFlags::SQLITE_UTF8
288 }
289 }
290
291 impl Connection {
292 /// `feature = "functions"` Attach a user-defined scalar function to
293 /// this database connection.
294 ///
295 /// `fn_name` is the name the function will be accessible from SQL.
296 /// `n_arg` is the number of arguments to the function. Use `-1` for a
297 /// variable number. If the function always returns the same value
298 /// given the same input, `deterministic` should be `true`.
299 ///
300 /// The function will remain available until the connection is closed or
301 /// until it is explicitly removed via `remove_function`.
302 ///
303 /// # Example
304 ///
305 /// ```rust
306 /// # use rusqlite::{Connection, Result, NO_PARAMS};
307 /// # use rusqlite::functions::FunctionFlags;
308 /// fn scalar_function_example(db: Connection) -> Result<()> {
309 /// db.create_scalar_function(
310 /// "halve",
311 /// 1,
312 /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
313 /// |ctx| {
314 /// let value = ctx.get::<f64>(0)?;
315 /// Ok(value / 2f64)
316 /// },
317 /// )?;
318 ///
319 /// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?;
320 /// assert_eq!(six_halved, 3f64);
321 /// Ok(())
322 /// }
323 /// ```
324 ///
325 /// # Failure
326 ///
327 /// Will return Err if the function could not be attached to the connection.
create_scalar_function<F, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,328 pub fn create_scalar_function<F, T>(
329 &self,
330 fn_name: &str,
331 n_arg: c_int,
332 flags: FunctionFlags,
333 x_func: F,
334 ) -> Result<()>
335 where
336 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
337 T: ToSql,
338 {
339 self.db
340 .borrow_mut()
341 .create_scalar_function(fn_name, n_arg, flags, x_func)
342 }
343
344 /// `feature = "functions"` Attach a user-defined aggregate function to this
345 /// database connection.
346 ///
347 /// # Failure
348 ///
349 /// Will return Err if the function could not be attached to the connection.
create_aggregate_function<A, D, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,350 pub fn create_aggregate_function<A, D, T>(
351 &self,
352 fn_name: &str,
353 n_arg: c_int,
354 flags: FunctionFlags,
355 aggr: D,
356 ) -> Result<()>
357 where
358 A: RefUnwindSafe + UnwindSafe,
359 D: Aggregate<A, T>,
360 T: ToSql,
361 {
362 self.db
363 .borrow_mut()
364 .create_aggregate_function(fn_name, n_arg, flags, aggr)
365 }
366
367 /// `feature = "window"` Attach a user-defined aggregate window function to
368 /// this database connection.
369 ///
370 /// See https://sqlite.org/windowfunctions.html#udfwinfunc for more
371 /// information.
372 #[cfg(feature = "window")]
create_window_function<A, W, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T>, T: ToSql,373 pub fn create_window_function<A, W, T>(
374 &self,
375 fn_name: &str,
376 n_arg: c_int,
377 flags: FunctionFlags,
378 aggr: W,
379 ) -> Result<()>
380 where
381 A: RefUnwindSafe + UnwindSafe,
382 W: WindowAggregate<A, T>,
383 T: ToSql,
384 {
385 self.db
386 .borrow_mut()
387 .create_window_function(fn_name, n_arg, flags, aggr)
388 }
389
390 /// `feature = "functions"` Removes a user-defined function from this
391 /// database connection.
392 ///
393 /// `fn_name` and `n_arg` should match the name and number of arguments
394 /// given to `create_scalar_function` or `create_aggregate_function`.
395 ///
396 /// # Failure
397 ///
398 /// Will return Err if the function could not be removed.
remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()>399 pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
400 self.db.borrow_mut().remove_function(fn_name, n_arg)
401 }
402 }
403
404 impl InnerConnection {
create_scalar_function<F, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,405 fn create_scalar_function<F, T>(
406 &mut self,
407 fn_name: &str,
408 n_arg: c_int,
409 flags: FunctionFlags,
410 x_func: F,
411 ) -> Result<()>
412 where
413 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
414 T: ToSql,
415 {
416 unsafe extern "C" fn call_boxed_closure<F, T>(
417 ctx: *mut sqlite3_context,
418 argc: c_int,
419 argv: *mut *mut sqlite3_value,
420 ) where
421 F: FnMut(&Context<'_>) -> Result<T>,
422 T: ToSql,
423 {
424 let r = catch_unwind(|| {
425 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F;
426 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
427 let ctx = Context {
428 ctx,
429 args: slice::from_raw_parts(argv, argc as usize),
430 };
431 (*boxed_f)(&ctx)
432 });
433 let t = match r {
434 Err(_) => {
435 report_error(ctx, &Error::UnwindingPanic);
436 return;
437 }
438 Ok(r) => r,
439 };
440 let t = t.as_ref().map(|t| ToSql::to_sql(t));
441
442 match t {
443 Ok(Ok(ref value)) => set_result(ctx, value),
444 Ok(Err(err)) => report_error(ctx, &err),
445 Err(err) => report_error(ctx, err),
446 }
447 }
448
449 let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
450 let c_name = str_to_cstring(fn_name)?;
451 let r = unsafe {
452 ffi::sqlite3_create_function_v2(
453 self.db(),
454 c_name.as_ptr(),
455 n_arg,
456 flags.bits(),
457 boxed_f as *mut c_void,
458 Some(call_boxed_closure::<F, T>),
459 None,
460 None,
461 Some(free_boxed_value::<F>),
462 )
463 };
464 self.decode_result(r)
465 }
466
create_aggregate_function<A, D, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,467 fn create_aggregate_function<A, D, T>(
468 &mut self,
469 fn_name: &str,
470 n_arg: c_int,
471 flags: FunctionFlags,
472 aggr: D,
473 ) -> Result<()>
474 where
475 A: RefUnwindSafe + UnwindSafe,
476 D: Aggregate<A, T>,
477 T: ToSql,
478 {
479 let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
480 let c_name = str_to_cstring(fn_name)?;
481 let r = unsafe {
482 ffi::sqlite3_create_function_v2(
483 self.db(),
484 c_name.as_ptr(),
485 n_arg,
486 flags.bits(),
487 boxed_aggr as *mut c_void,
488 None,
489 Some(call_boxed_step::<A, D, T>),
490 Some(call_boxed_final::<A, D, T>),
491 Some(free_boxed_value::<D>),
492 )
493 };
494 self.decode_result(r)
495 }
496
497 #[cfg(feature = "window")]
create_window_function<A, W, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T>, T: ToSql,498 fn create_window_function<A, W, T>(
499 &mut self,
500 fn_name: &str,
501 n_arg: c_int,
502 flags: FunctionFlags,
503 aggr: W,
504 ) -> Result<()>
505 where
506 A: RefUnwindSafe + UnwindSafe,
507 W: WindowAggregate<A, T>,
508 T: ToSql,
509 {
510 let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
511 let c_name = str_to_cstring(fn_name)?;
512 let r = unsafe {
513 ffi::sqlite3_create_window_function(
514 self.db(),
515 c_name.as_ptr(),
516 n_arg,
517 flags.bits(),
518 boxed_aggr as *mut c_void,
519 Some(call_boxed_step::<A, W, T>),
520 Some(call_boxed_final::<A, W, T>),
521 Some(call_boxed_value::<A, W, T>),
522 Some(call_boxed_inverse::<A, W, T>),
523 Some(free_boxed_value::<W>),
524 )
525 };
526 self.decode_result(r)
527 }
528
remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()>529 fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
530 let c_name = str_to_cstring(fn_name)?;
531 let r = unsafe {
532 ffi::sqlite3_create_function_v2(
533 self.db(),
534 c_name.as_ptr(),
535 n_arg,
536 ffi::SQLITE_UTF8,
537 ptr::null_mut(),
538 None,
539 None,
540 None,
541 None,
542 )
543 };
544 self.decode_result(r)
545 }
546 }
547
aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A>548 unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
549 let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
550 if pac.is_null() {
551 return None;
552 }
553 Some(pac)
554 }
555
call_boxed_step<A, D, T>( ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value, ) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,556 unsafe extern "C" fn call_boxed_step<A, D, T>(
557 ctx: *mut sqlite3_context,
558 argc: c_int,
559 argv: *mut *mut sqlite3_value,
560 ) where
561 A: RefUnwindSafe + UnwindSafe,
562 D: Aggregate<A, T>,
563 T: ToSql,
564 {
565 let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
566 Some(pac) => pac,
567 None => {
568 ffi::sqlite3_result_error_nomem(ctx);
569 return;
570 }
571 };
572
573 let r = catch_unwind(|| {
574 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
575 assert!(
576 !boxed_aggr.is_null(),
577 "Internal error - null aggregate pointer"
578 );
579 if (*pac as *mut A).is_null() {
580 *pac = Box::into_raw(Box::new((*boxed_aggr).init()));
581 }
582 let mut ctx = Context {
583 ctx,
584 args: slice::from_raw_parts(argv, argc as usize),
585 };
586 (*boxed_aggr).step(&mut ctx, &mut **pac)
587 });
588 let r = match r {
589 Err(_) => {
590 report_error(ctx, &Error::UnwindingPanic);
591 return;
592 }
593 Ok(r) => r,
594 };
595 match r {
596 Ok(_) => {}
597 Err(err) => report_error(ctx, &err),
598 };
599 }
600
/// xInverse trampoline for window functions whose user data is a boxed
/// `W: WindowAggregate`. It forwards the departing row to `W::inverse`,
/// reporting allocation failure as NOMEM and panics/errors via `report_error`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let slot = match aggregate_context::<A>(ctx, ::std::mem::size_of::<*mut A>()) {
        Some(slot) => slot,
        None => {
            ffi::sqlite3_result_error_nomem(ctx);
            return;
        }
    };

    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        // NOTE(review): dereferences the accumulator unconditionally. This
        // relies on xStep having run before xInverse for this invocation.
        (*aggr).inverse(&mut fn_ctx, &mut **slot)
    });

    match outcome {
        Ok(Ok(())) => {}
        Ok(Err(err)) => report_error(ctx, &err),
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
    }
}
643
call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,644 unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
645 where
646 A: RefUnwindSafe + UnwindSafe,
647 D: Aggregate<A, T>,
648 T: ToSql,
649 {
650 // Within the xFinal callback, it is customary to set N=0 in calls to
651 // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
652 let a: Option<A> = match aggregate_context(ctx, 0) {
653 Some(pac) => {
654 if (*pac as *mut A).is_null() {
655 None
656 } else {
657 let a = Box::from_raw(*pac);
658 Some(*a)
659 }
660 }
661 None => None,
662 };
663
664 let r = catch_unwind(|| {
665 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
666 assert!(
667 !boxed_aggr.is_null(),
668 "Internal error - null aggregate pointer"
669 );
670 (*boxed_aggr).finalize(a)
671 });
672 let t = match r {
673 Err(_) => {
674 report_error(ctx, &Error::UnwindingPanic);
675 return;
676 }
677 Ok(r) => r,
678 };
679 let t = t.as_ref().map(|t| ToSql::to_sql(t));
680 match t {
681 Ok(Ok(ref value)) => set_result(ctx, value),
682 Ok(Err(err)) => report_error(ctx, &err),
683 Err(err) => report_error(ctx, err),
684 }
685 }
686
/// xValue trampoline for window functions whose user data is a boxed
/// `W: WindowAggregate`.
///
/// It lends the current accumulator (if any) to `W::value` without taking
/// ownership of it, and turns the outcome into the SQLite result.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // Within the xValue callback, it is customary to set N=0 in calls to
    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
    let acc: Option<&A> = match aggregate_context::<A>(ctx, 0) {
        Some(slot) if !(*slot).is_null() => Some(&**slot),
        _ => None,
    };

    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        (*aggr).value(acc)
    });

    match outcome {
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
        Ok(Err(err)) => report_error(ctx, &err),
        Ok(Ok(value)) => match value.to_sql() {
            Ok(output) => set_result(ctx, &output),
            Err(err) => report_error(ctx, &err),
        },
    }
}
730
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::f64::EPSILON;
    use std::os::raw::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags};
    use crate::{Connection, Error, Result, NO_PARAMS};

    /// Scalar function that returns its single numeric argument divided by two.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    #[test]
    fn test_function_half() {
        let db = Connection::open_in_memory().unwrap();
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )
        .unwrap();
        let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));

        assert!((3f64 - result.unwrap()).abs() < EPSILON);
    }

    #[test]
    fn test_remove_function() {
        let db = Connection::open_in_memory().unwrap();
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )
        .unwrap();
        let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
        assert!((3f64 - result.unwrap()).abs() < EPSILON);

        db.remove_function("half", 1).unwrap();
        let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
        assert!(result.is_err());
    }

    // This implementation of a regexp scalar function uses SQLite's auxiliary data
    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
    // expression multiple times within one query.
    fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    #[test]
    fn test_function_regexp_with_auxilliary() {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )
        .unwrap();
        db.create_scalar_function(
            "regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxilliary,
        )
        .unwrap();

        let result: Result<bool> =
            db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| {
                r.get(0)
            });

        assert!(result.unwrap());

        let result: Result<i64> = db.query_row(
            "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
            NO_PARAMS,
            |r| r.get(0),
        );

        assert_eq!(2, result.unwrap());
    }

    #[test]
    fn test_varargs_function() {
        let db = Connection::open_in_memory().unwrap();
        db.create_scalar_function(
            "my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();

                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )
        .unwrap();

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            let result: String = db.query_row(query, NO_PARAMS, |r| r.get(0)).unwrap();
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn test_get_aux_type_checking() {
        let db = Connection::open_in_memory().unwrap();
        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100);
            }
            Ok(true)
        })
        .unwrap();

        let res: bool = db
            .query_row(
                "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
                NO_PARAMS,
                |r| r.get(0),
            )
            .unwrap();
        // Doesn't actually matter, we'll assert in the function if there's a problem.
        assert!(res);
    }

    /// Aggregate summing integer arguments; yields NULL when there are no rows.
    struct Sum;
    /// Aggregate counting rows; yields 0 when there are no rows.
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self) -> i64 {
            0
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self) -> i64 {
            0
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    #[test]
    fn test_sum() {
        let db = Connection::open_in_memory().unwrap();
        db.create_aggregate_function(
            "my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )
        .unwrap();

        // sum should return NULL when given no columns (contrast with count below)
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: Option<i64> = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
        assert!(result.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
        assert_eq!(4, result);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db
            .query_row(dual_sum, NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?)))
            .unwrap();
        assert_eq!((4, 2), result);
    }

    #[test]
    fn test_count() {
        let db = Connection::open_in_memory().unwrap();
        db.create_aggregate_function(
            "my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )
        .unwrap();

        // count should return 0 when given no columns (contrast with sum above)
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: i64 = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
        assert_eq!(result, 0);

        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
        assert_eq!(2, result);
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    #[test]
    #[cfg(feature = "window")]
    fn test_window() {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory().unwrap();
        db.create_window_function(
            "sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )
        .unwrap();
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )
        .unwrap();

        let mut stmt = db
            .prepare(
                "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
            )
            .unwrap();

        let results: Vec<(String, i64)> = stmt
            .query(NO_PARAMS)
            .unwrap()
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()
            .unwrap();
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
    }
}
1037