1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 //! Declaration of built-in (aggregate) functions.
19 //! This module contains built-in aggregates' enumeration and metadata.
20 //!
//! Generally, an aggregate has:
//! * a signature
//! * a return type, which is a function of the incoming arguments' types
//! * the computation, which must accept each valid signature
25 //!
26 //! * Signature: see `Signature`
27 //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.
28 
29 use super::{
30     functions::Signature,
31     type_coercion::{coerce, data_types},
32     Accumulator, AggregateExpr, PhysicalExpr,
33 };
34 use crate::error::{DataFusionError, Result};
35 use crate::physical_plan::distinct_expressions;
36 use crate::physical_plan::expressions;
37 use arrow::datatypes::{DataType, Schema};
38 use expressions::{avg_return_type, sum_return_type};
39 use std::{fmt, str::FromStr, sync::Arc};
40 
/// The implementation of an aggregate function: a factory closure that, on
/// each call, returns a boxed [`Accumulator`] (or an error).
///
/// The closure is shared through an `Arc` and is bounded by `Send + Sync`,
/// so it can be handed to and invoked from other threads.
pub type AccumulatorFunctionImplementation =
    Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;
44 
/// Given an aggregate's return [`DataType`], returns the data types in which
/// the aggregator serializes its state.
///
/// The closure is shared through an `Arc` and is bounded by `Send + Sync`,
/// so it can be handed to and invoked from other threads.
pub type StateTypeFunction =
    Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
49 
/// Enum of all built-in aggregate functions.
///
/// Each variant corresponds to the SQL aggregate of the same (upper-cased) name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
    /// `AVG`
    Avg,
}
64 
65 impl fmt::Display for AggregateFunction {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result66     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67         // uppercase of the debug.
68         write!(f, "{}", format!("{:?}", self).to_uppercase())
69     }
70 }
71 
72 impl FromStr for AggregateFunction {
73     type Err = DataFusionError;
from_str(name: &str) -> Result<AggregateFunction>74     fn from_str(name: &str) -> Result<AggregateFunction> {
75         Ok(match &*name.to_uppercase() {
76             "MIN" => AggregateFunction::Min,
77             "MAX" => AggregateFunction::Max,
78             "COUNT" => AggregateFunction::Count,
79             "AVG" => AggregateFunction::Avg,
80             "SUM" => AggregateFunction::Sum,
81             _ => {
82                 return Err(DataFusionError::Plan(format!(
83                     "There is no built-in function named {}",
84                     name
85                 )))
86             }
87         })
88     }
89 }
90 
91 /// Returns the datatype of the scalar function
return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType>92 pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType> {
93     // Note that this function *must* return the same type that the respective physical expression returns
94     // or the execution panics.
95 
96     // verify that this is a valid set of data types for this function
97     data_types(arg_types, &signature(fun))?;
98 
99     match fun {
100         AggregateFunction::Count => Ok(DataType::UInt64),
101         AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()),
102         AggregateFunction::Sum => sum_return_type(&arg_types[0]),
103         AggregateFunction::Avg => avg_return_type(&arg_types[0]),
104     }
105 }
106 
107 /// Create a physical (function) expression.
108 /// This function errors when `args`' can't be coerced to a valid argument type of the function.
create_aggregate_expr( fun: &AggregateFunction, distinct: bool, args: &[Arc<dyn PhysicalExpr>], input_schema: &Schema, name: String, ) -> Result<Arc<dyn AggregateExpr>>109 pub fn create_aggregate_expr(
110     fun: &AggregateFunction,
111     distinct: bool,
112     args: &[Arc<dyn PhysicalExpr>],
113     input_schema: &Schema,
114     name: String,
115 ) -> Result<Arc<dyn AggregateExpr>> {
116     // coerce
117     let arg = coerce(args, input_schema, &signature(fun))?[0].clone();
118 
119     let arg_types = args
120         .iter()
121         .map(|e| e.data_type(input_schema))
122         .collect::<Result<Vec<_>>>()?;
123 
124     let return_type = return_type(&fun, &arg_types)?;
125 
126     Ok(match (fun, distinct) {
127         (AggregateFunction::Count, false) => {
128             Arc::new(expressions::Count::new(arg, name, return_type))
129         }
130         (AggregateFunction::Count, true) => {
131             Arc::new(distinct_expressions::DistinctCount::new(
132                 arg_types,
133                 args.to_vec(),
134                 name,
135                 return_type,
136             ))
137         }
138         (AggregateFunction::Sum, false) => {
139             Arc::new(expressions::Sum::new(arg, name, return_type))
140         }
141         (AggregateFunction::Sum, true) => {
142             return Err(DataFusionError::NotImplemented(
143                 "SUM(DISTINCT) aggregations are not available".to_string(),
144             ));
145         }
146         (AggregateFunction::Min, _) => {
147             Arc::new(expressions::Min::new(arg, name, return_type))
148         }
149         (AggregateFunction::Max, _) => {
150             Arc::new(expressions::Max::new(arg, name, return_type))
151         }
152         (AggregateFunction::Avg, false) => {
153             Arc::new(expressions::Avg::new(arg, name, return_type))
154         }
155         (AggregateFunction::Avg, true) => {
156             return Err(DataFusionError::NotImplemented(
157                 "AVG(DISTINCT) aggregations are not available".to_string(),
158             ));
159         }
160     })
161 }
162 
/// Numeric argument types accepted by the built-in aggregates.
///
/// `signature` uses this list as the complete accepted set for `SUM`/`AVG`,
/// and appends it after the string types for `MIN`/`MAX`, preserving its order.
/// NOTE(review): if signature-based coercion tries candidate types in list
/// order, reordering these entries could change coercion results — confirm
/// before editing.
static NUMERICS: &[DataType] = &[
    DataType::Int8,
    DataType::Int16,
    DataType::Int32,
    DataType::Int64,
    DataType::UInt8,
    DataType::UInt16,
    DataType::UInt32,
    DataType::UInt64,
    DataType::Float32,
    DataType::Float64,
];
175 
176 /// the signatures supported by the function `fun`.
signature(fun: &AggregateFunction) -> Signature177 fn signature(fun: &AggregateFunction) -> Signature {
178     // note: the physical expression must accept the type returned by this function or the execution panics.
179     match fun {
180         AggregateFunction::Count => Signature::Any(1),
181         AggregateFunction::Min | AggregateFunction::Max => {
182             let mut valid = vec![DataType::Utf8, DataType::LargeUtf8];
183             valid.extend_from_slice(NUMERICS);
184             Signature::Uniform(1, valid)
185         }
186         AggregateFunction::Avg | AggregateFunction::Sum => {
187             Signature::Uniform(1, NUMERICS.to_vec())
188         }
189     }
190 }
191 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;

    #[test]
    fn test_min_max() -> Result<()> {
        let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
        assert_eq!(DataType::Utf8, observed);

        let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?;
        assert_eq!(DataType::Int32, observed);
        Ok(())
    }

    #[test]
    fn test_sum_no_utf8() {
        let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]);
        assert!(observed.is_err());
    }

    #[test]
    fn test_sum_upcasts() -> Result<()> {
        let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt32])?;
        assert_eq!(DataType::UInt64, observed);
        Ok(())
    }

    #[test]
    fn test_count_return_type() -> Result<()> {
        let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?;
        assert_eq!(DataType::UInt64, observed);

        let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?;
        assert_eq!(DataType::UInt64, observed);
        Ok(())
    }

    #[test]
    fn test_avg_return_type() -> Result<()> {
        let observed = return_type(&AggregateFunction::Avg, &[DataType::Float32])?;
        assert_eq!(DataType::Float64, observed);

        let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?;
        assert_eq!(DataType::Float64, observed);
        Ok(())
    }

    #[test]
    fn test_avg_no_utf8() {
        let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]);
        assert!(observed.is_err());
    }

    /// Names are matched case-insensitively when parsing.
    #[test]
    fn test_from_str_is_case_insensitive() -> Result<()> {
        assert_eq!(AggregateFunction::Count, "count".parse::<AggregateFunction>()?);
        assert_eq!(AggregateFunction::Sum, "Sum".parse::<AggregateFunction>()?);
        assert_eq!(AggregateFunction::Min, "MIN".parse::<AggregateFunction>()?);
        assert_eq!(AggregateFunction::Max, "mAx".parse::<AggregateFunction>()?);
        assert_eq!(AggregateFunction::Avg, "avg".parse::<AggregateFunction>()?);
        Ok(())
    }

    #[test]
    fn test_from_str_unknown_name() {
        assert!("median".parse::<AggregateFunction>().is_err());
        assert!("".parse::<AggregateFunction>().is_err());
    }

    #[test]
    fn test_display_names() {
        assert_eq!("COUNT", AggregateFunction::Count.to_string());
        assert_eq!("SUM", AggregateFunction::Sum.to_string());
        assert_eq!("MIN", AggregateFunction::Min.to_string());
        assert_eq!("MAX", AggregateFunction::Max.to_string());
        assert_eq!("AVG", AggregateFunction::Avg.to_string());
    }

    /// `Display` output must be accepted back by `FromStr`.
    #[test]
    fn test_display_round_trips_through_from_str() -> Result<()> {
        let all = [
            AggregateFunction::Count,
            AggregateFunction::Sum,
            AggregateFunction::Min,
            AggregateFunction::Max,
            AggregateFunction::Avg,
        ];
        for fun in all.iter() {
            let displayed = fun.to_string();
            assert_eq!(*fun, displayed.parse::<AggregateFunction>()?);
        }
        Ok(())
    }
}
246