1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 //! Defines scalars used to construct groups, ex. in GROUP BY clauses.
19 
20 use ordered_float::OrderedFloat;
21 use std::convert::{From, TryFrom};
22 
23 use crate::error::{DataFusionError, Result};
24 use crate::scalar::ScalarValue;
25 
26 /// Enumeration of types that can be used in a GROUP BY expression
27 #[derive(Debug, PartialEq, Eq, Hash, Clone)]
28 pub(crate) enum GroupByScalar {
29     Float32(OrderedFloat<f32>),
30     Float64(OrderedFloat<f64>),
31     UInt8(u8),
32     UInt16(u16),
33     UInt32(u32),
34     UInt64(u64),
35     Int8(i8),
36     Int16(i16),
37     Int32(i32),
38     Int64(i64),
39     Utf8(Box<String>),
40     Boolean(bool),
41     TimeMicrosecond(i64),
42     TimeNanosecond(i64),
43     Date32(i32),
44 }
45 
46 impl TryFrom<&ScalarValue> for GroupByScalar {
47     type Error = DataFusionError;
48 
try_from(scalar_value: &ScalarValue) -> Result<Self>49     fn try_from(scalar_value: &ScalarValue) -> Result<Self> {
50         Ok(match scalar_value {
51             ScalarValue::Float32(Some(v)) => {
52                 GroupByScalar::Float32(OrderedFloat::from(*v))
53             }
54             ScalarValue::Float64(Some(v)) => {
55                 GroupByScalar::Float64(OrderedFloat::from(*v))
56             }
57             ScalarValue::Boolean(Some(v)) => GroupByScalar::Boolean(*v),
58             ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v),
59             ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v),
60             ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v),
61             ScalarValue::Int64(Some(v)) => GroupByScalar::Int64(*v),
62             ScalarValue::UInt8(Some(v)) => GroupByScalar::UInt8(*v),
63             ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v),
64             ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
65             ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
66             ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
67             ScalarValue::Float32(None)
68             | ScalarValue::Float64(None)
69             | ScalarValue::Boolean(None)
70             | ScalarValue::Int8(None)
71             | ScalarValue::Int16(None)
72             | ScalarValue::Int32(None)
73             | ScalarValue::Int64(None)
74             | ScalarValue::UInt8(None)
75             | ScalarValue::UInt16(None)
76             | ScalarValue::UInt32(None)
77             | ScalarValue::UInt64(None)
78             | ScalarValue::Utf8(None) => {
79                 return Err(DataFusionError::Internal(format!(
80                     "Cannot convert a ScalarValue holding NULL ({:?})",
81                     scalar_value
82                 )));
83             }
84             v => {
85                 return Err(DataFusionError::Internal(format!(
86                     "Cannot convert a ScalarValue with associated DataType {:?}",
87                     v.get_datatype()
88                 )))
89             }
90         })
91     }
92 }
93 
94 impl From<&GroupByScalar> for ScalarValue {
from(group_by_scalar: &GroupByScalar) -> Self95     fn from(group_by_scalar: &GroupByScalar) -> Self {
96         match group_by_scalar {
97             GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())),
98             GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())),
99             GroupByScalar::Boolean(v) => ScalarValue::Boolean(Some(*v)),
100             GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)),
101             GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)),
102             GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)),
103             GroupByScalar::Int64(v) => ScalarValue::Int64(Some(*v)),
104             GroupByScalar::UInt8(v) => ScalarValue::UInt8(Some(*v)),
105             GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)),
106             GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)),
107             GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)),
108             GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())),
109             GroupByScalar::TimeMicrosecond(v) => ScalarValue::TimeMicrosecond(Some(*v)),
110             GroupByScalar::TimeNanosecond(v) => ScalarValue::TimeNanosecond(Some(*v)),
111             GroupByScalar::Date32(v) => ScalarValue::Date32(Some(*v)),
112         }
113     }
114 }
115 
#[cfg(test)]
mod tests {
    use super::*;

    use crate::error::DataFusionError;

    // Converts the same `ScalarValue` twice and asserts that the two
    // resulting `GroupByScalar`s compare equal.
    macro_rules! scalar_eq_test {
        ($TYPE:expr, $VALUE:expr) => {{
            let scalar_value = $TYPE($VALUE);
            let a = GroupByScalar::try_from(&scalar_value).unwrap();

            let scalar_value = $TYPE($VALUE);
            let b = GroupByScalar::try_from(&scalar_value).unwrap();

            assert_eq!(a, b);
        }};
    }

    #[test]
    fn test_scalar_eq_non_std() {
        // Only the float variants are exercised: `f32` / `f64` have no native
        // `Eq` / `Hash`, so equality goes through the `OrderedFloat` wrapper.
        scalar_eq_test!(ScalarValue::Float32, Some(1.0));
        scalar_eq_test!(ScalarValue::Float64, Some(1.0));
    }

    // Converts two different `ScalarValue`s of the same variant and asserts
    // that the resulting `GroupByScalar`s compare unequal.
    macro_rules! scalar_ne_test {
        ($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{
            let scalar_value = $TYPE($LVALUE);
            let a = GroupByScalar::try_from(&scalar_value).unwrap();

            let scalar_value = $TYPE($RVALUE);
            let b = GroupByScalar::try_from(&scalar_value).unwrap();

            assert_ne!(a, b);
        }};
    }

    #[test]
    fn test_scalar_ne_non_std() {
        // Only the float variants are exercised: `f32` / `f64` have no native
        // `Eq` / `Hash`, so inequality goes through the `OrderedFloat` wrapper.
        scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0));
        scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0));
    }

    #[test]
    fn from_scalar_holding_none() {
        // A supported variant that holds NULL must be rejected with the
        // dedicated "holding NULL" message.
        let scalar_value = ScalarValue::Int8(None);
        let result = GroupByScalar::try_from(&scalar_value);

        match result {
            Err(DataFusionError::Internal(error_message)) => assert_eq!(
                error_message,
                String::from("Cannot convert a ScalarValue holding NULL (Int8(NULL))")
            ),
            _ => panic!("Unexpected result"),
        }
    }

    #[test]
    fn from_scalar_unsupported() {
        // Use any ScalarValue type not supported by GroupByScalar.
        let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
        let result = GroupByScalar::try_from(&scalar_value);

        match result {
            Err(DataFusionError::Internal(error_message)) => assert_eq!(
                error_message,
                String::from(
                    "Cannot convert a ScalarValue with associated DataType LargeUtf8"
                )
            ),
            _ => panic!("Unexpected result"),
        }
    }

    #[test]
    fn size_of_group_by_scalar() {
        // Pins the in-memory size of the key type so that an accidental
        // change to its layout is caught here.
        assert_eq!(std::mem::size_of::<GroupByScalar>(), 16);
    }
}
196