1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
//! Declaration of built-in (aggregate) functions.
//! This module contains the enumeration of built-in aggregates and their metadata.
//!
//! Generally, an aggregate has:
//! * a signature
//! * a return type, which is a function of the incoming arguments' types
//! * the computation, which must accept each valid signature
//!
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.
28
29 use super::{
30 functions::Signature,
31 type_coercion::{coerce, data_types},
32 Accumulator, AggregateExpr, PhysicalExpr,
33 };
34 use crate::error::{DataFusionError, Result};
35 use crate::physical_plan::distinct_expressions;
36 use crate::physical_plan::expressions;
37 use arrow::datatypes::{DataType, Schema};
38 use expressions::{avg_return_type, sum_return_type};
39 use std::{fmt, str::FromStr, sync::Arc};
40
41 /// the implementation of an aggregate function
42 pub type AccumulatorFunctionImplementation =
43 Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;
44
45 /// This signature corresponds to which types an aggregator serializes
46 /// its state, given its return datatype.
47 pub type StateTypeFunction =
48 Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
49
/// Enum of all built-in aggregate functions.
///
/// Each variant corresponds to one SQL aggregate (`COUNT`, `SUM`, `MIN`,
/// `MAX`, `AVG`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateFunction {
    /// The `COUNT` aggregate.
    Count,
    /// The `SUM` aggregate.
    Sum,
    /// The `MIN` aggregate.
    Min,
    /// The `MAX` aggregate.
    Max,
    /// The `AVG` aggregate.
    Avg,
}
64
65 impl fmt::Display for AggregateFunction {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result66 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67 // uppercase of the debug.
68 write!(f, "{}", format!("{:?}", self).to_uppercase())
69 }
70 }
71
72 impl FromStr for AggregateFunction {
73 type Err = DataFusionError;
from_str(name: &str) -> Result<AggregateFunction>74 fn from_str(name: &str) -> Result<AggregateFunction> {
75 Ok(match &*name.to_uppercase() {
76 "MIN" => AggregateFunction::Min,
77 "MAX" => AggregateFunction::Max,
78 "COUNT" => AggregateFunction::Count,
79 "AVG" => AggregateFunction::Avg,
80 "SUM" => AggregateFunction::Sum,
81 _ => {
82 return Err(DataFusionError::Plan(format!(
83 "There is no built-in function named {}",
84 name
85 )))
86 }
87 })
88 }
89 }
90
91 /// Returns the datatype of the scalar function
return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType>92 pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType> {
93 // Note that this function *must* return the same type that the respective physical expression returns
94 // or the execution panics.
95
96 // verify that this is a valid set of data types for this function
97 data_types(arg_types, &signature(fun))?;
98
99 match fun {
100 AggregateFunction::Count => Ok(DataType::UInt64),
101 AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()),
102 AggregateFunction::Sum => sum_return_type(&arg_types[0]),
103 AggregateFunction::Avg => avg_return_type(&arg_types[0]),
104 }
105 }
106
107 /// Create a physical (function) expression.
108 /// This function errors when `args`' can't be coerced to a valid argument type of the function.
create_aggregate_expr( fun: &AggregateFunction, distinct: bool, args: &[Arc<dyn PhysicalExpr>], input_schema: &Schema, name: String, ) -> Result<Arc<dyn AggregateExpr>>109 pub fn create_aggregate_expr(
110 fun: &AggregateFunction,
111 distinct: bool,
112 args: &[Arc<dyn PhysicalExpr>],
113 input_schema: &Schema,
114 name: String,
115 ) -> Result<Arc<dyn AggregateExpr>> {
116 // coerce
117 let arg = coerce(args, input_schema, &signature(fun))?[0].clone();
118
119 let arg_types = args
120 .iter()
121 .map(|e| e.data_type(input_schema))
122 .collect::<Result<Vec<_>>>()?;
123
124 let return_type = return_type(&fun, &arg_types)?;
125
126 Ok(match (fun, distinct) {
127 (AggregateFunction::Count, false) => {
128 Arc::new(expressions::Count::new(arg, name, return_type))
129 }
130 (AggregateFunction::Count, true) => {
131 Arc::new(distinct_expressions::DistinctCount::new(
132 arg_types,
133 args.to_vec(),
134 name,
135 return_type,
136 ))
137 }
138 (AggregateFunction::Sum, false) => {
139 Arc::new(expressions::Sum::new(arg, name, return_type))
140 }
141 (AggregateFunction::Sum, true) => {
142 return Err(DataFusionError::NotImplemented(
143 "SUM(DISTINCT) aggregations are not available".to_string(),
144 ));
145 }
146 (AggregateFunction::Min, _) => {
147 Arc::new(expressions::Min::new(arg, name, return_type))
148 }
149 (AggregateFunction::Max, _) => {
150 Arc::new(expressions::Max::new(arg, name, return_type))
151 }
152 (AggregateFunction::Avg, false) => {
153 Arc::new(expressions::Avg::new(arg, name, return_type))
154 }
155 (AggregateFunction::Avg, true) => {
156 return Err(DataFusionError::NotImplemented(
157 "AVG(DISTINCT) aggregations are not available".to_string(),
158 ));
159 }
160 })
161 }
162
/// Numeric data types accepted by `SUM` and `AVG`, and (after the string
/// types) by `MIN`/`MAX`, as listed in `signature`.
///
/// NOTE(review): the element order is presumably significant — the coercion
/// helpers appear to try valid types in list order — so confirm against
/// `type_coercion::data_types` before reordering.
static NUMERICS: &[DataType] = &[
    DataType::Int8,
    DataType::Int16,
    DataType::Int32,
    DataType::Int64,
    DataType::UInt8,
    DataType::UInt16,
    DataType::UInt32,
    DataType::UInt64,
    DataType::Float32,
    DataType::Float64,
];
175
176 /// the signatures supported by the function `fun`.
signature(fun: &AggregateFunction) -> Signature177 fn signature(fun: &AggregateFunction) -> Signature {
178 // note: the physical expression must accept the type returned by this function or the execution panics.
179 match fun {
180 AggregateFunction::Count => Signature::Any(1),
181 AggregateFunction::Min | AggregateFunction::Max => {
182 let mut valid = vec![DataType::Utf8, DataType::LargeUtf8];
183 valid.extend_from_slice(NUMERICS);
184 Signature::Uniform(1, valid)
185 }
186 AggregateFunction::Avg | AggregateFunction::Sum => {
187 Signature::Uniform(1, NUMERICS.to_vec())
188 }
189 }
190 }
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;

    #[test]
    fn test_min_max() -> Result<()> {
        let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
        assert_eq!(DataType::Utf8, observed);

        let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?;
        assert_eq!(DataType::Int32, observed);
        Ok(())
    }

    #[test]
    fn test_sum_no_utf8() {
        let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]);
        assert!(observed.is_err());
    }

    #[test]
    fn test_sum_upcasts() -> Result<()> {
        let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt32])?;
        assert_eq!(DataType::UInt64, observed);
        Ok(())
    }

    #[test]
    fn test_count_return_type() -> Result<()> {
        let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?;
        assert_eq!(DataType::UInt64, observed);

        let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?;
        assert_eq!(DataType::UInt64, observed);
        Ok(())
    }

    #[test]
    fn test_avg_return_type() -> Result<()> {
        let observed = return_type(&AggregateFunction::Avg, &[DataType::Float32])?;
        assert_eq!(DataType::Float64, observed);

        let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?;
        assert_eq!(DataType::Float64, observed);
        Ok(())
    }

    #[test]
    fn test_avg_no_utf8() {
        let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]);
        assert!(observed.is_err());
    }

    /// `Display` output must parse back to the same variant via `FromStr`.
    #[test]
    fn test_display_from_str_round_trip() -> Result<()> {
        let all = [
            AggregateFunction::Count,
            AggregateFunction::Sum,
            AggregateFunction::Min,
            AggregateFunction::Max,
            AggregateFunction::Avg,
        ];
        for fun in all.iter() {
            let parsed: AggregateFunction = fun.to_string().parse()?;
            assert_eq!(*fun, parsed);
        }
        Ok(())
    }

    #[test]
    fn test_display_is_upper_case() {
        assert_eq!("COUNT", AggregateFunction::Count.to_string());
        assert_eq!("AVG", AggregateFunction::Avg.to_string());
    }

    #[test]
    fn test_from_str_ignores_case() -> Result<()> {
        let observed: AggregateFunction = "sum".parse()?;
        assert_eq!(AggregateFunction::Sum, observed);

        let observed: AggregateFunction = "Max".parse()?;
        assert_eq!(AggregateFunction::Max, observed);
        Ok(())
    }

    #[test]
    fn test_from_str_unknown_name() {
        let observed = "median".parse::<AggregateFunction>();
        assert!(observed.is_err());
    }
}
246