1 extern crate libsqlite3_sys as ffi;
2 
3 mod functions;
4 #[doc(hidden)]
5 pub mod raw;
6 mod serialized_value;
7 mod sqlite_value;
8 mod statement_iterator;
9 mod stmt;
10 
11 pub use self::sqlite_value::SqliteValue;
12 
13 use std::os::raw as libc;
14 
15 use self::raw::RawConnection;
16 use self::statement_iterator::*;
17 use self::stmt::{Statement, StatementUse};
18 use connection::*;
19 use deserialize::{Queryable, QueryableByName};
20 use query_builder::bind_collector::RawBytesBindCollector;
21 use query_builder::*;
22 use result::*;
23 use serialize::ToSql;
24 use sql_types::HasSqlType;
25 use sqlite::Sqlite;
26 
/// Connections for the SQLite backend. Unlike other backends, "connection URLs"
/// for SQLite are file paths, [URIs](https://sqlite.org/uri.html), or special
/// identifiers like `:memory:`.
// `Debug` is deliberately not implemented for this type; the lint is silenced
// rather than satisfied.
#[allow(missing_debug_implementations)]
pub struct SqliteConnection {
    // Cache of prepared statements. Entries are produced by
    // `cached_prepared_statement`, which prepares them against `raw_connection`.
    //
    // Rust drops struct fields in declaration order, so this cache is dropped
    // *before* `raw_connection`. NOTE(review): this presumably ensures every
    // cached statement is finalized before the database handle is closed —
    // keep this field declared above `raw_connection`.
    statement_cache: StatementCache<Sqlite, Statement>,
    // Handle to the underlying SQLite database that all SQL runs against.
    raw_connection: RawConnection,
    // Manager used to begin, commit and roll back transactions (see
    // `transaction_sql`); exposed through `Connection::transaction_manager`.
    transaction_manager: AnsiTransactionManager,
}
36 
// SAFETY: Moving a `SqliteConnection` to another thread is sound only under the
// invariant that the `RawConnection` and `Statement` values it owns (and any
// references to them) are never leaked out of the connection. If a reference to
// one of those were still held on a different thread after the move, this impl
// would not be thread safe. This impl asserts `Send` only.
unsafe impl Send for SqliteConnection {}
41 
42 impl SimpleConnection for SqliteConnection {
batch_execute(&self, query: &str) -> QueryResult<()>43     fn batch_execute(&self, query: &str) -> QueryResult<()> {
44         self.raw_connection.exec(query)
45     }
46 }
47 
48 impl Connection for SqliteConnection {
49     type Backend = Sqlite;
50     type TransactionManager = AnsiTransactionManager;
51 
establish(database_url: &str) -> ConnectionResult<Self>52     fn establish(database_url: &str) -> ConnectionResult<Self> {
53         use result::ConnectionError::CouldntSetupConfiguration;
54 
55         let raw_connection = RawConnection::establish(database_url)?;
56         let conn = Self {
57             statement_cache: StatementCache::new(),
58             raw_connection,
59             transaction_manager: AnsiTransactionManager::new(),
60         };
61         conn.register_diesel_sql_functions()
62             .map_err(CouldntSetupConfiguration)?;
63         Ok(conn)
64     }
65 
66     #[doc(hidden)]
execute(&self, query: &str) -> QueryResult<usize>67     fn execute(&self, query: &str) -> QueryResult<usize> {
68         self.batch_execute(query)?;
69         Ok(self.raw_connection.rows_affected_by_last_query())
70     }
71 
72     #[doc(hidden)]
query_by_index<T, U>(&self, source: T) -> QueryResult<Vec<U>> where T: AsQuery, T::Query: QueryFragment<Self::Backend> + QueryId, Self::Backend: HasSqlType<T::SqlType>, U: Queryable<T::SqlType, Self::Backend>,73     fn query_by_index<T, U>(&self, source: T) -> QueryResult<Vec<U>>
74     where
75         T: AsQuery,
76         T::Query: QueryFragment<Self::Backend> + QueryId,
77         Self::Backend: HasSqlType<T::SqlType>,
78         U: Queryable<T::SqlType, Self::Backend>,
79     {
80         let mut statement = self.prepare_query(&source.as_query())?;
81         let statement_use = StatementUse::new(&mut statement);
82         let iter = StatementIterator::new(statement_use);
83         iter.collect()
84     }
85 
86     #[doc(hidden)]
query_by_name<T, U>(&self, source: &T) -> QueryResult<Vec<U>> where T: QueryFragment<Self::Backend> + QueryId, U: QueryableByName<Self::Backend>,87     fn query_by_name<T, U>(&self, source: &T) -> QueryResult<Vec<U>>
88     where
89         T: QueryFragment<Self::Backend> + QueryId,
90         U: QueryableByName<Self::Backend>,
91     {
92         let mut statement = self.prepare_query(source)?;
93         let statement_use = StatementUse::new(&mut statement);
94         let iter = NamedStatementIterator::new(statement_use)?;
95         iter.collect()
96     }
97 
98     #[doc(hidden)]
execute_returning_count<T>(&self, source: &T) -> QueryResult<usize> where T: QueryFragment<Self::Backend> + QueryId,99     fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize>
100     where
101         T: QueryFragment<Self::Backend> + QueryId,
102     {
103         let mut statement = self.prepare_query(source)?;
104         let mut statement_use = StatementUse::new(&mut statement);
105         statement_use.run()?;
106         Ok(self.raw_connection.rows_affected_by_last_query())
107     }
108 
109     #[doc(hidden)]
transaction_manager(&self) -> &Self::TransactionManager110     fn transaction_manager(&self) -> &Self::TransactionManager {
111         &self.transaction_manager
112     }
113 }
114 
115 impl SqliteConnection {
116     /// Run a transaction with `BEGIN IMMEDIATE`
117     ///
118     /// This method will return an error if a transaction is already open.
119     ///
120     /// # Example
121     ///
122     /// ```rust
123     /// # #[macro_use] extern crate diesel;
124     /// # include!("../../doctest_setup.rs");
125     /// #
126     /// # fn main() {
127     /// #     run_test().unwrap();
128     /// # }
129     /// #
130     /// # fn run_test() -> QueryResult<()> {
131     /// #     let conn = SqliteConnection::establish(":memory:").unwrap();
132     /// conn.immediate_transaction(|| {
133     ///     // Do stuff in a transaction
134     ///     Ok(())
135     /// })
136     /// # }
137     /// ```
immediate_transaction<T, E, F>(&self, f: F) -> Result<T, E> where F: FnOnce() -> Result<T, E>, E: From<Error>,138     pub fn immediate_transaction<T, E, F>(&self, f: F) -> Result<T, E>
139     where
140         F: FnOnce() -> Result<T, E>,
141         E: From<Error>,
142     {
143         self.transaction_sql(f, "BEGIN IMMEDIATE")
144     }
145 
146     /// Run a transaction with `BEGIN EXCLUSIVE`
147     ///
148     /// This method will return an error if a transaction is already open.
149     ///
150     /// # Example
151     ///
152     /// ```rust
153     /// # #[macro_use] extern crate diesel;
154     /// # include!("../../doctest_setup.rs");
155     /// #
156     /// # fn main() {
157     /// #     run_test().unwrap();
158     /// # }
159     /// #
160     /// # fn run_test() -> QueryResult<()> {
161     /// #     let conn = SqliteConnection::establish(":memory:").unwrap();
162     /// conn.exclusive_transaction(|| {
163     ///     // Do stuff in a transaction
164     ///     Ok(())
165     /// })
166     /// # }
167     /// ```
exclusive_transaction<T, E, F>(&self, f: F) -> Result<T, E> where F: FnOnce() -> Result<T, E>, E: From<Error>,168     pub fn exclusive_transaction<T, E, F>(&self, f: F) -> Result<T, E>
169     where
170         F: FnOnce() -> Result<T, E>,
171         E: From<Error>,
172     {
173         self.transaction_sql(f, "BEGIN EXCLUSIVE")
174     }
175 
transaction_sql<T, E, F>(&self, f: F, sql: &str) -> Result<T, E> where F: FnOnce() -> Result<T, E>, E: From<Error>,176     fn transaction_sql<T, E, F>(&self, f: F, sql: &str) -> Result<T, E>
177     where
178         F: FnOnce() -> Result<T, E>,
179         E: From<Error>,
180     {
181         let transaction_manager = self.transaction_manager();
182 
183         transaction_manager.begin_transaction_sql(self, sql)?;
184         match f() {
185             Ok(value) => {
186                 transaction_manager.commit_transaction(self)?;
187                 Ok(value)
188             }
189             Err(e) => {
190                 transaction_manager.rollback_transaction(self)?;
191                 Err(e)
192             }
193         }
194     }
195 
prepare_query<T: QueryFragment<Sqlite> + QueryId>( &self, source: &T, ) -> QueryResult<MaybeCached<Statement>>196     fn prepare_query<T: QueryFragment<Sqlite> + QueryId>(
197         &self,
198         source: &T,
199     ) -> QueryResult<MaybeCached<Statement>> {
200         let mut statement = self.cached_prepared_statement(source)?;
201 
202         let mut bind_collector = RawBytesBindCollector::<Sqlite>::new();
203         source.collect_binds(&mut bind_collector, &())?;
204         let metadata = bind_collector.metadata;
205         let binds = bind_collector.binds;
206         for (tpe, value) in metadata.into_iter().zip(binds) {
207             statement.bind(tpe, value)?;
208         }
209 
210         Ok(statement)
211     }
212 
cached_prepared_statement<T: QueryFragment<Sqlite> + QueryId>( &self, source: &T, ) -> QueryResult<MaybeCached<Statement>>213     fn cached_prepared_statement<T: QueryFragment<Sqlite> + QueryId>(
214         &self,
215         source: &T,
216     ) -> QueryResult<MaybeCached<Statement>> {
217         self.statement_cache.cached_statement(source, &[], |sql| {
218             Statement::prepare(&self.raw_connection, sql)
219         })
220     }
221 
222     #[doc(hidden)]
register_sql_function<ArgsSqlType, RetSqlType, Args, Ret, F>( &self, fn_name: &str, deterministic: bool, mut f: F, ) -> QueryResult<()> where F: FnMut(Args) -> Ret + Send + 'static, Args: Queryable<ArgsSqlType, Sqlite>, Ret: ToSql<RetSqlType, Sqlite>, Sqlite: HasSqlType<RetSqlType>,223     pub fn register_sql_function<ArgsSqlType, RetSqlType, Args, Ret, F>(
224         &self,
225         fn_name: &str,
226         deterministic: bool,
227         mut f: F,
228     ) -> QueryResult<()>
229     where
230         F: FnMut(Args) -> Ret + Send + 'static,
231         Args: Queryable<ArgsSqlType, Sqlite>,
232         Ret: ToSql<RetSqlType, Sqlite>,
233         Sqlite: HasSqlType<RetSqlType>,
234     {
235         functions::register(
236             &self.raw_connection,
237             fn_name,
238             deterministic,
239             move |_, args| f(args),
240         )
241     }
242 
register_diesel_sql_functions(&self) -> QueryResult<()>243     fn register_diesel_sql_functions(&self) -> QueryResult<()> {
244         use sql_types::{Integer, Text};
245 
246         functions::register::<Text, Integer, _, _, _>(
247             &self.raw_connection,
248             "diesel_manage_updated_at",
249             false,
250             |conn, table_name: String| {
251                 conn.exec(&format!(
252                     include_str!("diesel_manage_updated_at.sql"),
253                     table_name = table_name
254                 ))
255                 .expect("Failed to create trigger");
256                 0 // have to return *something*
257             },
258         )
259     }
260 }
261 
/// Returns the static message text that `ffi::code_to_str` associates with the
/// SQLite result code `err_code`.
// NOTE(review): not called in this file; presumably used by sibling modules
// (e.g. `raw`, `stmt`) when building error values — confirm.
fn error_message(err_code: libc::c_int) -> &'static str {
    ffi::code_to_str(err_code)
}
265 
#[cfg(test)]
mod tests {
    // Tests for statement caching and for registering custom SQL functions,
    // each run against a fresh in-memory SQLite connection.
    use super::*;
    use dsl::sql;
    use prelude::*;
    use sql_types::Integer;

    // A query made only of bound values is prepared once: running it twice
    // leaves exactly one entry in the statement cache.
    #[test]
    fn prepared_statements_are_cached_when_run() {
        let connection = SqliteConnection::establish(":memory:").unwrap();
        let query = ::select(1.into_sql::<Integer>());

        assert_eq!(Ok(1), query.get_result(&connection));
        assert_eq!(Ok(1), query.get_result(&connection));
        assert_eq!(1, connection.statement_cache.len());
    }

    // A query consisting of a raw `sql` literal is executed but not cached.
    #[test]
    fn sql_literal_nodes_are_not_cached() {
        let connection = SqliteConnection::establish(":memory:").unwrap();
        let query = ::select(sql::<Integer>("1"));

        assert_eq!(Ok(1), query.get_result(&connection));
        assert_eq!(0, connection.statement_cache.len());
    }

    // A raw `sql` literal nested inside a larger expression also prevents the
    // whole query from being cached.
    #[test]
    fn queries_containing_sql_literal_nodes_are_not_cached() {
        let connection = SqliteConnection::establish(":memory:").unwrap();
        let one_as_expr = 1.into_sql::<Integer>();
        let query = ::select(one_as_expr.eq(sql::<Integer>("1")));

        assert_eq!(Ok(true), query.get_result(&connection));
        assert_eq!(0, connection.statement_cache.len());
    }

    // `eq_any` over a `Vec` is not cached. NOTE(review): presumably because the
    // generated SQL depends on the vector's length.
    #[test]
    fn queries_containing_in_with_vec_are_not_cached() {
        let connection = SqliteConnection::establish(":memory:").unwrap();
        let one_as_expr = 1.into_sql::<Integer>();
        let query = ::select(one_as_expr.eq_any(vec![1, 2, 3]));

        assert_eq!(Ok(true), query.get_result(&connection));
        assert_eq!(0, connection.statement_cache.len());
    }

    // `eq_any` over a subselect is cacheable and yields one cached statement.
    #[test]
    fn queries_containing_in_with_subselect_are_cached() {
        let connection = SqliteConnection::establish(":memory:").unwrap();
        let one_as_expr = 1.into_sql::<Integer>();
        let query = ::select(one_as_expr.eq_any(::select(one_as_expr)));

        assert_eq!(Ok(true), query.get_result(&connection));
        assert_eq!(1, connection.statement_cache.len());
    }

    use sql_types::Text;
    sql_function!(fn fun_case(x: Text) -> Text);

    // Registers a Rust closure for `fun_case` (via the macro-generated
    // `register_impl`) that lowercases even indices and uppercases odd ones.
    #[test]
    fn register_custom_function() {
        let connection = SqliteConnection::establish(":memory:").unwrap();
        fun_case::register_impl(&connection, |x: String| {
            x.chars()
                .enumerate()
                .map(|(i, c)| {
                    if i % 2 == 0 {
                        c.to_lowercase().to_string()
                    } else {
                        c.to_uppercase().to_string()
                    }
                })
                .collect::<String>()
        })
        .unwrap();

        let mapped_string = ::select(fun_case("foobar"))
            .get_result::<String>(&connection)
            .unwrap();
        assert_eq!("fOoBaR", mapped_string);
    }

    sql_function!(fn my_add(x: Integer, y: Integer) -> Integer);

    // A function with several SQL arguments receives them as separate closure
    // parameters.
    #[test]
    fn register_multiarg_function() {
        let connection = SqliteConnection::establish(":memory:").unwrap();
        my_add::register_impl(&connection, |x: i32, y: i32| x + y).unwrap();

        let added = ::select(my_add(1, 2)).get_result::<i32>(&connection);
        assert_eq!(Ok(3), added);
    }

    sql_function!(fn add_counter(x: Integer) -> Integer);

    // A nondeterministic function is invoked once per call site: the captured
    // counter grows across the three calls, giving 1+1, 1+2 and 1+3.
    #[test]
    fn register_nondeterministic_function() {
        let connection = SqliteConnection::establish(":memory:").unwrap();
        let mut y = 0;
        add_counter::register_nondeterministic_impl(&connection, move |x: i32| {
            y += 1;
            x + y
        })
        .unwrap();

        let added = ::select((add_counter(1), add_counter(1), add_counter(1)))
            .get_result::<(i32, i32, i32)>(&connection);
        assert_eq!(Ok((2, 3, 4)), added);
    }
}
376