1 use byteorder::*;
2 use std::io::Write;
3 
4 use deserialize::{self, FromSql, FromSqlRow, Queryable};
5 use expression::{AppearsOnTable, AsExpression, Expression, NonAggregate, SelectableExpression};
6 use pg::Pg;
7 use query_builder::{AstPass, QueryFragment};
8 use result::QueryResult;
9 use row::Row;
10 use serialize::{self, IsNull, Output, ToSql, WriteTuple};
11 use sql_types::{HasSqlType, Record};
12 
/// Implements PostgreSQL `Record` (de)serialization and expression support for
/// every tuple arity supplied by `__diesel_for_each_tuple!`.
///
/// For each tuple `(T0, T1, ...)` with SQL types `(ST0, ST1, ...)` this emits
/// `FromSql`, `FromSqlRow`, `Queryable`, `AsExpression` and `WriteTuple` impls.
/// The binary layout read and written here is: an `i32` element count, then
/// for every element a `u32` type OID, an `i32` byte length (`-1` for NULL)
/// and that many payload bytes.
macro_rules! tuple_impls {
    ($(
        $Tuple:tt {
            $(($idx:tt) -> $T:ident, $ST:ident, $TT:ident,)+
        }
    )+) => {$(
        impl<$($T,)+ $($ST,)+> FromSql<Record<($($ST,)+)>, Pg> for ($($T,)+)
        where
            $($T: FromSql<$ST, Pg>,)+
        {
            // Yes, we're relying on the order of evaluation of subexpressions
            // (tuple fields are evaluated left to right) but the only other
            // option would be to use `mem::uninitialized` and `ptr::write`.
            #[allow(clippy::eval_order_dependence)]
            fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
                let mut bytes = not_none!(bytes);
                let num_elements = bytes.read_i32::<NetworkEndian>()?;

                if num_elements != $Tuple {
                    return Err(format!(
                        "Expected a tuple of {} elements, got {}",
                        $Tuple,
                        num_elements,
                    ).into());
                }

                let result = ($({
                    // We could in theory validate the OID here, but that
                    // ignores cases like text vs varchar where the
                    // representation is the same and we don't care which we
                    // got.
                    let _oid = bytes.read_u32::<NetworkEndian>()?;
                    let num_bytes = bytes.read_i32::<NetworkEndian>()?;

                    if num_bytes == -1 {
                        // A length of -1 marks a NULL element.
                        $T::from_sql(None)?
                    } else if num_bytes < 0 || (num_bytes as usize) > bytes.len() {
                        // Reject malformed lengths instead of letting
                        // `split_at` panic on an out-of-range index.
                        return Err(format!(
                            "Invalid record element length {}, only {} bytes remaining",
                            num_bytes,
                            bytes.len(),
                        ).into())
                    } else {
                        let (elem_bytes, new_bytes) = bytes.split_at(num_bytes as usize);
                        bytes = new_bytes;
                        $T::from_sql(Some(elem_bytes))?
                    }
                },)+);

                if bytes.is_empty() {
                    Ok(result)
                } else {
                    Err("Received too many bytes. This tuple likely contains \
                        an element of the wrong SQL type.".into())
                }
            }
        }

        impl<$($T,)+ $($ST,)+> FromSqlRow<Record<($($ST,)+)>, Pg> for ($($T,)+)
        where
            Self: FromSql<Record<($($ST,)+)>, Pg>,
        {
            // The whole record arrives as a single column.
            const FIELDS_NEEDED: usize = 1;

            fn build_from_row<RowT: Row<Pg>>(row: &mut RowT) -> deserialize::Result<Self> {
                Self::from_sql(row.take())
            }
        }

        impl<$($T,)+ $($ST,)+> Queryable<Record<($($ST,)+)>, Pg> for ($($T,)+)
        where
            Self: FromSqlRow<Record<($($ST,)+)>, Pg>,
        {
            type Row = Self;

            fn build(row: Self::Row) -> Self {
                row
            }
        }

        impl<$($T,)+ $($ST,)+> AsExpression<Record<($($ST,)+)>> for ($($T,)+)
        where
            $($T: AsExpression<$ST>,)+
            PgTuple<($($T::Expression,)+)>: Expression<SqlType = Record<($($ST,)+)>>,
        {
            type Expression = PgTuple<($($T::Expression,)+)>;

            fn as_expression(self) -> Self::Expression {
                PgTuple(($(
                    self.$idx.as_expression(),
                )+))
            }
        }

        impl<$($T,)+ $($ST,)+> WriteTuple<($($ST,)+)> for ($($T,)+)
        where
            $($T: ToSql<$ST, Pg>,)+
            $(Pg: HasSqlType<$ST>),+
        {
            fn write_tuple<_W: Write>(&self, out: &mut Output<_W, Pg>) -> serialize::Result {
                // Scratch buffer used to measure each element before writing
                // its length prefix.
                let mut buffer = out.with_buffer(Vec::new());
                out.write_i32::<NetworkEndian>($Tuple)?;

                $(
                    let oid = <Pg as HasSqlType<$ST>>::metadata(out.metadata_lookup()).oid;
                    out.write_u32::<NetworkEndian>(oid)?;
                    let is_null = self.$idx.to_sql(&mut buffer)?;

                    if let IsNull::No = is_null {
                        if buffer.len() > i32::max_value() as usize {
                            return Err(format!(
                                "Record element of {} bytes exceeds the maximum encodable length",
                                buffer.len(),
                            ).into());
                        }
                        out.write_i32::<NetworkEndian>(buffer.len() as i32)?;
                        out.write_all(&buffer)?;
                    } else {
                        out.write_i32::<NetworkEndian>(-1)?;
                    }
                    // Always reset the scratch buffer so bytes written by one
                    // element can never leak into the next element's payload.
                    buffer.clear();
                )+

                Ok(IsNull::No)
            }
        }
    )+}
}
129 
130 __diesel_for_each_tuple!(tuple_impls);
131 
132 #[derive(Debug, Clone, Copy, QueryId)]
133 pub struct PgTuple<T>(T);
134 
135 impl<T> QueryFragment<Pg> for PgTuple<T>
136 where
137     T: QueryFragment<Pg>,
138 {
walk_ast(&self, mut out: AstPass<Pg>) -> QueryResult<()>139     fn walk_ast(&self, mut out: AstPass<Pg>) -> QueryResult<()> {
140         out.push_sql("(");
141         self.0.walk_ast(out.reborrow())?;
142         out.push_sql(")");
143         Ok(())
144     }
145 }
146 
147 impl<T> Expression for PgTuple<T>
148 where
149     T: Expression,
150 {
151     type SqlType = Record<T::SqlType>;
152 }
153 
154 impl<T, QS> SelectableExpression<QS> for PgTuple<T>
155 where
156     T: SelectableExpression<QS>,
157     Self: AppearsOnTable<QS>,
158 {
159 }
160 
161 impl<T, QS> AppearsOnTable<QS> for PgTuple<T>
162 where
163     T: AppearsOnTable<QS>,
164     Self: Expression,
165 {
166 }
167 
168 impl<T> NonAggregate for PgTuple<T>
169 where
170     T: NonAggregate,
171     Self: Expression,
172 {
173 }
174 
#[cfg(test)]
mod tests {
    use super::*;
    use dsl::sql;
    use prelude::*;
    use sql_types::*;
    use test_helpers::*;

    /// Anonymous records returned by the server decode into Rust tuples,
    /// including nested records and NULL fields.
    #[test]
    fn record_deserializes_correctly() {
        let connection = pg_connection();

        let flat = sql::<Record<(Integer, Text)>>("SELECT (1, 'hi')")
            .get_result::<(i32, String)>(&connection);
        assert_eq!(Ok((1, String::from("hi"))), flat);

        let nested = sql::<Record<(Record<(Integer, Text)>, Integer)>>("SELECT ((2, 'bye'), 3)")
            .get_result::<((i32, String), i32)>(&connection);
        assert_eq!(Ok(((2, String::from("bye")), 3)), nested);

        let with_nulls = sql::<
            Record<(
                Record<(Nullable<Integer>, Nullable<Text>)>,
                Nullable<Integer>,
            )>,
        >("SELECT ((4, NULL), NULL)")
        .get_result::<((Option<i32>, Option<String>), Option<i32>)>(&connection);
        assert_eq!(Ok(((Some(4), None), None)), with_nulls);
    }

    /// Rust tuples can be compared against anonymous records as expressions.
    #[test]
    fn record_kinda_sorta_not_really_serializes_correctly() {
        let connection = pg_connection();

        let flat = sql::<Record<(Integer, Text)>>("(1, 'hi')");
        let flat_matches = ::select(flat.eq((1, "hi"))).get_result(&connection);
        assert_eq!(Ok(true), flat_matches);

        let nested = sql::<Record<(Record<(Integer, Text)>, Integer)>>("((2, 'bye'::text), 3)");
        let nested_matches = ::select(nested.eq(((2, "bye"), 3))).get_result(&connection);
        assert_eq!(Ok(true), nested_matches);

        let with_nulls = sql::<
            Record<(
                Record<(Nullable<Integer>, Nullable<Text>)>,
                Nullable<Integer>,
            )>,
        >("((4, NULL::text), NULL::int4)");
        let nulls_match = ::select(
            with_nulls.is_not_distinct_from(((Some(4), None::<&str>), None::<i32>)),
        )
        .get_result(&connection);
        assert_eq!(Ok(true), nulls_match);
    }

    /// `WriteTuple` can back the `ToSql` impl of a named composite type.
    #[test]
    fn serializing_named_composite_types() {
        #[derive(SqlType, QueryId, Debug, Clone, Copy)]
        #[postgres(type_name = "my_type")]
        struct MyType;

        #[derive(Debug, AsExpression)]
        #[sql_type = "MyType"]
        struct MyStruct<'a>(i32, &'a str);

        impl<'a> ToSql<MyType, Pg> for MyStruct<'a> {
            fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
                let fields = (self.0, self.1);
                WriteTuple::<(Integer, Text)>::write_tuple(&fields, out)
            }
        }

        let connection = pg_connection();

        ::sql_query("CREATE TYPE my_type AS (i int4, t text)")
            .execute(&connection)
            .unwrap();
        let query = sql::<Bool>("(1, 'hi')::my_type = ").bind::<MyType, _>(MyStruct(1, "hi"));
        let matches = ::select(query).get_result(&connection);
        assert_eq!(Ok(true), matches);
    }
}
254