1 use byteorder::*; 2 use std::io::Write; 3 4 use deserialize::{self, FromSql, FromSqlRow, Queryable}; 5 use expression::{AppearsOnTable, AsExpression, Expression, NonAggregate, SelectableExpression}; 6 use pg::Pg; 7 use query_builder::{AstPass, QueryFragment}; 8 use result::QueryResult; 9 use row::Row; 10 use serialize::{self, IsNull, Output, ToSql, WriteTuple}; 11 use sql_types::{HasSqlType, Record}; 12 13 macro_rules! tuple_impls { 14 ($( 15 $Tuple:tt { 16 $(($idx:tt) -> $T:ident, $ST:ident, $TT:ident,)+ 17 } 18 )+) => {$( 19 impl<$($T,)+ $($ST,)+> FromSql<Record<($($ST,)+)>, Pg> for ($($T,)+) 20 where 21 $($T: FromSql<$ST, Pg>,)+ 22 { 23 // Yes, we're relying on the order of evaluation of subexpressions 24 // but the only other option would be to use `mem::uninitialized` 25 // and `ptr::write`. 26 #[allow(clippy::eval_order_dependence)] 27 fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> { 28 let mut bytes = not_none!(bytes); 29 let num_elements = bytes.read_i32::<NetworkEndian>()?; 30 31 if num_elements != $Tuple { 32 return Err(format!( 33 "Expected a tuple of {} elements, got {}", 34 $Tuple, 35 num_elements, 36 ).into()); 37 } 38 39 let result = ($({ 40 // We could in theory validate the OID here, but that 41 // ignores cases like text vs varchar where the 42 // representation is the same and we don't care which we 43 // got. 44 let _oid = bytes.read_u32::<NetworkEndian>()?; 45 let num_bytes = bytes.read_i32::<NetworkEndian>()?; 46 47 if num_bytes == -1 { 48 $T::from_sql(None)? 49 } else { 50 let (elem_bytes, new_bytes) = bytes.split_at(num_bytes as usize); 51 bytes = new_bytes; 52 $T::from_sql(Some(elem_bytes))? 53 } 54 },)+); 55 56 if bytes.is_empty() { 57 Ok(result) 58 } else { 59 Err("Received too many bytes. This tuple likely contains \ 60 an element of the wrong SQL type.".into()) 61 } 62 } 63 } 64 65 impl<$($T,)+ $($ST,)+> FromSqlRow<Record<($($ST,)+)>, Pg> for ($($T,)+) 66 where 67 Self: FromSql<Record<($($ST,)+)>, Pg>, 68 { 69 const FIELDS_NEEDED: usize = 1; 70 71 fn build_from_row<RowT: Row<Pg>>(row: &mut RowT) -> deserialize::Result<Self> { 72 Self::from_sql(row.take()) 73 } 74 } 75 76 impl<$($T,)+ $($ST,)+> Queryable<Record<($($ST,)+)>, Pg> for ($($T,)+) 77 where 78 Self: FromSqlRow<Record<($($ST,)+)>, Pg>, 79 { 80 type Row = Self; 81 82 fn build(row: Self::Row) -> Self { 83 row 84 } 85 } 86 87 impl<$($T,)+ $($ST,)+> AsExpression<Record<($($ST,)+)>> for ($($T,)+) 88 where 89 $($T: AsExpression<$ST>,)+ 90 PgTuple<($($T::Expression,)+)>: Expression<SqlType = Record<($($ST,)+)>>, 91 { 92 type Expression = PgTuple<($($T::Expression,)+)>; 93 94 fn as_expression(self) -> Self::Expression { 95 PgTuple(($( 96 self.$idx.as_expression(), 97 )+)) 98 } 99 } 100 101 impl<$($T,)+ $($ST,)+> WriteTuple<($($ST,)+)> for ($($T,)+) 102 where 103 $($T: ToSql<$ST, Pg>,)+ 104 $(Pg: HasSqlType<$ST>),+ 105 { 106 fn write_tuple<_W: Write>(&self, out: &mut Output<_W, Pg>) -> serialize::Result { 107 let mut buffer = out.with_buffer(Vec::new()); 108 out.write_i32::<NetworkEndian>($Tuple)?; 109 110 $( 111 let oid = <Pg as HasSqlType<$ST>>::metadata(out.metadata_lookup()).oid; 112 out.write_u32::<NetworkEndian>(oid)?; 113 let is_null = self.$idx.to_sql(&mut buffer)?; 114 115 if let IsNull::No = is_null { 116 out.write_i32::<NetworkEndian>(buffer.len() as i32)?; 117 out.write_all(&buffer)?; 118 buffer.clear(); 119 } else { 120 out.write_i32::<NetworkEndian>(-1)?; 121 } 122 )+ 123 124 Ok(IsNull::No) 125 } 126 } 127 )+} 128 } 129 130 __diesel_for_each_tuple!(tuple_impls); 131 132 #[derive(Debug, Clone, Copy, QueryId)] 133 pub struct PgTuple<T>(T); 134 135 impl<T> QueryFragment<Pg> for PgTuple<T> 136 where 137 T: QueryFragment<Pg>, 138 { walk_ast(&self, mut out: AstPass<Pg>) -> QueryResult<()>139 fn walk_ast(&self, mut out: AstPass<Pg>) -> QueryResult<()> { 140 out.push_sql("("); 141 self.0.walk_ast(out.reborrow())?; 142 out.push_sql(")"); 143 Ok(()) 144 } 145 } 146 147 impl<T> Expression for PgTuple<T> 148 where 149 T: Expression, 150 { 151 type SqlType = Record<T::SqlType>; 152 } 153 154 impl<T, QS> SelectableExpression<QS> for PgTuple<T> 155 where 156 T: SelectableExpression<QS>, 157 Self: AppearsOnTable<QS>, 158 { 159 } 160 161 impl<T, QS> AppearsOnTable<QS> for PgTuple<T> 162 where 163 T: AppearsOnTable<QS>, 164 Self: Expression, 165 { 166 } 167 168 impl<T> NonAggregate for PgTuple<T> 169 where 170 T: NonAggregate, 171 Self: Expression, 172 { 173 } 174 175 #[cfg(test)] 176 mod tests { 177 use super::*; 178 use dsl::sql; 179 use prelude::*; 180 use sql_types::*; 181 use test_helpers::*; 182 183 #[test] record_deserializes_correctly()184 fn record_deserializes_correctly() { 185 let conn = pg_connection(); 186 187 let tup = 188 sql::<Record<(Integer, Text)>>("SELECT (1, 'hi')").get_result::<(i32, String)>(&conn); 189 assert_eq!(Ok((1, String::from("hi"))), tup); 190 191 let tup = sql::<Record<(Record<(Integer, Text)>, Integer)>>("SELECT ((2, 'bye'), 3)") 192 .get_result::<((i32, String), i32)>(&conn); 193 assert_eq!(Ok(((2, String::from("bye")), 3)), tup); 194 195 let tup = sql::< 196 Record<( 197 Record<(Nullable<Integer>, Nullable<Text>)>, 198 Nullable<Integer>, 199 )>, 200 >("SELECT ((4, NULL), NULL)") 201 .get_result::<((Option<i32>, Option<String>), Option<i32>)>(&conn); 202 assert_eq!(Ok(((Some(4), None), None)), tup); 203 } 204 205 #[test] record_kinda_sorta_not_really_serializes_correctly()206 fn record_kinda_sorta_not_really_serializes_correctly() { 207 let conn = pg_connection(); 208 209 let tup = sql::<Record<(Integer, Text)>>("(1, 'hi')"); 210 let res = ::select(tup.eq((1, "hi"))).get_result(&conn); 211 assert_eq!(Ok(true), res); 212 213 let tup = sql::<Record<(Record<(Integer, Text)>, Integer)>>("((2, 'bye'::text), 3)"); 214 let res = ::select(tup.eq(((2, "bye"), 3))).get_result(&conn); 215 assert_eq!(Ok(true), res); 216 217 let tup = sql::< 218 Record<( 219 Record<(Nullable<Integer>, Nullable<Text>)>, 220 Nullable<Integer>, 221 )>, 222 >("((4, NULL::text), NULL::int4)"); 223 let res = ::select(tup.is_not_distinct_from(((Some(4), None::<&str>), None::<i32>))) 224 .get_result(&conn); 225 assert_eq!(Ok(true), res); 226 } 227 228 #[test] serializing_named_composite_types()229 fn serializing_named_composite_types() { 230 #[derive(SqlType, QueryId, Debug, Clone, Copy)] 231 #[postgres(type_name = "my_type")] 232 struct MyType; 233 234 #[derive(Debug, AsExpression)] 235 #[sql_type = "MyType"] 236 struct MyStruct<'a>(i32, &'a str); 237 238 impl<'a> ToSql<MyType, Pg> for MyStruct<'a> { 239 fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result { 240 WriteTuple::<(Integer, Text)>::write_tuple(&(self.0, self.1), out) 241 } 242 } 243 244 let conn = pg_connection(); 245 246 ::sql_query("CREATE TYPE my_type AS (i int4, t text)") 247 .execute(&conn) 248 .unwrap(); 249 let sql = sql::<Bool>("(1, 'hi')::my_type = ").bind::<MyType, _>(MyStruct(1, "hi")); 250 let res = ::select(sql).get_result(&conn); 251 assert_eq!(Ok(true), res); 252 } 253 } 254