1 // Licensed to the Apache Software Foundation (ASF) under one 2 // or more contributor license agreements. See the NOTICE file 3 // distributed with this work for additional information 4 // regarding copyright ownership. The ASF licenses this file 5 // to you under the Apache License, Version 2.0 (the 6 // "License"); you may not use this file except in compliance 7 // with the License. You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, 12 // software distributed under the License is distributed on an 13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 // KIND, either express or implied. See the License for the 15 // specific language governing permissions and limitations 16 // under the License. 17 18 //! Defines scalars used to construct groups, ex. in GROUP BY clauses. 19 20 use ordered_float::OrderedFloat; 21 use std::convert::{From, TryFrom}; 22 23 use crate::error::{DataFusionError, Result}; 24 use crate::scalar::ScalarValue; 25 26 /// Enumeration of types that can be used in a GROUP BY expression 27 #[derive(Debug, PartialEq, Eq, Hash, Clone)] 28 pub(crate) enum GroupByScalar { 29 Float32(OrderedFloat<f32>), 30 Float64(OrderedFloat<f64>), 31 UInt8(u8), 32 UInt16(u16), 33 UInt32(u32), 34 UInt64(u64), 35 Int8(i8), 36 Int16(i16), 37 Int32(i32), 38 Int64(i64), 39 Utf8(Box<String>), 40 Boolean(bool), 41 TimeMicrosecond(i64), 42 TimeNanosecond(i64), 43 Date32(i32), 44 } 45 46 impl TryFrom<&ScalarValue> for GroupByScalar { 47 type Error = DataFusionError; 48 try_from(scalar_value: &ScalarValue) -> Result<Self>49 fn try_from(scalar_value: &ScalarValue) -> Result<Self> { 50 Ok(match scalar_value { 51 ScalarValue::Float32(Some(v)) => { 52 GroupByScalar::Float32(OrderedFloat::from(*v)) 53 } 54 ScalarValue::Float64(Some(v)) => { 55 GroupByScalar::Float64(OrderedFloat::from(*v)) 56 } 57 ScalarValue::Boolean(Some(v)) => GroupByScalar::Boolean(*v), 58 ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v), 59 ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v), 60 ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v), 61 ScalarValue::Int64(Some(v)) => GroupByScalar::Int64(*v), 62 ScalarValue::UInt8(Some(v)) => GroupByScalar::UInt8(*v), 63 ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v), 64 ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v), 65 ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v), 66 ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())), 67 ScalarValue::Float32(None) 68 | ScalarValue::Float64(None) 69 | ScalarValue::Boolean(None) 70 | ScalarValue::Int8(None) 71 | ScalarValue::Int16(None) 72 | ScalarValue::Int32(None) 73 | ScalarValue::Int64(None) 74 | ScalarValue::UInt8(None) 75 | ScalarValue::UInt16(None) 76 | ScalarValue::UInt32(None) 77 | ScalarValue::UInt64(None) 78 | ScalarValue::Utf8(None) => { 79 return Err(DataFusionError::Internal(format!( 80 "Cannot convert a ScalarValue holding NULL ({:?})", 81 scalar_value 82 ))); 83 } 84 v => { 85 return Err(DataFusionError::Internal(format!( 86 "Cannot convert a ScalarValue with associated DataType {:?}", 87 v.get_datatype() 88 ))) 89 } 90 }) 91 } 92 } 93 94 impl From<&GroupByScalar> for ScalarValue { from(group_by_scalar: &GroupByScalar) -> Self95 fn from(group_by_scalar: &GroupByScalar) -> Self { 96 match group_by_scalar { 97 GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())), 98 GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())), 99 GroupByScalar::Boolean(v) => ScalarValue::Boolean(Some(*v)), 100 GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)), 101 GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)), 102 GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)), 103 GroupByScalar::Int64(v) => ScalarValue::Int64(Some(*v)), 104 GroupByScalar::UInt8(v) => ScalarValue::UInt8(Some(*v)), 105 GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)), 106 GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)), 107 GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)), 108 GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())), 109 GroupByScalar::TimeMicrosecond(v) => ScalarValue::TimeMicrosecond(Some(*v)), 110 GroupByScalar::TimeNanosecond(v) => ScalarValue::TimeNanosecond(Some(*v)), 111 GroupByScalar::Date32(v) => ScalarValue::Date32(Some(*v)), 112 } 113 } 114 } 115 116 #[cfg(test)] 117 mod tests { 118 use super::*; 119 120 use crate::error::DataFusionError; 121 122 macro_rules! scalar_eq_test { 123 ($TYPE:expr, $VALUE:expr) => {{ 124 let scalar_value = $TYPE($VALUE); 125 let a = GroupByScalar::try_from(&scalar_value).unwrap(); 126 127 let scalar_value = $TYPE($VALUE); 128 let b = GroupByScalar::try_from(&scalar_value).unwrap(); 129 130 assert_eq!(a, b); 131 }}; 132 } 133 134 #[test] test_scalar_ne_non_std()135 fn test_scalar_ne_non_std() { 136 // Test only Scalars with non native Eq, Hash 137 scalar_eq_test!(ScalarValue::Float32, Some(1.0)); 138 scalar_eq_test!(ScalarValue::Float64, Some(1.0)); 139 } 140 141 macro_rules! scalar_ne_test { 142 ($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{ 143 let scalar_value = $TYPE($LVALUE); 144 let a = GroupByScalar::try_from(&scalar_value).unwrap(); 145 146 let scalar_value = $TYPE($RVALUE); 147 let b = GroupByScalar::try_from(&scalar_value).unwrap(); 148 149 assert_ne!(a, b); 150 }}; 151 } 152 153 #[test] test_scalar_eq_non_std()154 fn test_scalar_eq_non_std() { 155 // Test only Scalars with non native Eq, Hash 156 scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0)); 157 scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0)); 158 } 159 160 #[test] from_scalar_holding_none()161 fn from_scalar_holding_none() { 162 let scalar_value = ScalarValue::Int8(None); 163 let result = GroupByScalar::try_from(&scalar_value); 164 165 match result { 166 Err(DataFusionError::Internal(error_message)) => assert_eq!( 167 error_message, 168 String::from("Cannot convert a ScalarValue holding NULL (Int8(NULL))") 169 ), 170 _ => panic!("Unexpected result"), 171 } 172 } 173 174 #[test] from_scalar_unsupported()175 fn from_scalar_unsupported() { 176 // Use any ScalarValue type not supported by GroupByScalar. 177 let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string())); 178 let result = GroupByScalar::try_from(&scalar_value); 179 180 match result { 181 Err(DataFusionError::Internal(error_message)) => assert_eq!( 182 error_message, 183 String::from( 184 "Cannot convert a ScalarValue with associated DataType LargeUtf8" 185 ) 186 ), 187 _ => panic!("Unexpected result"), 188 } 189 } 190 191 #[test] size_of_group_by_scalar()192 fn size_of_group_by_scalar() { 193 assert_eq!(std::mem::size_of::<GroupByScalar>(), 16); 194 } 195 } 196