1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
//! In this example we declare a UDAF with a single input type and a single return type
//! that computes the geometric mean.
//! The geometric mean is described here: <https://en.wikipedia.org/wiki/Geometric_mean>
20 use arrow::{
21     array::Float32Array, array::Float64Array, datatypes::DataType,
22     record_batch::RecordBatch,
23 };
24 
25 use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator};
26 use datafusion::{prelude::*, scalar::ScalarValue};
27 use std::sync::Arc;
28 
29 // create local execution context with an in-memory table
create_context() -> Result<ExecutionContext>30 fn create_context() -> Result<ExecutionContext> {
31     use arrow::datatypes::{Field, Schema};
32     use datafusion::datasource::MemTable;
33     // define a schema.
34     let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
35 
36     // define data in two partitions
37     let batch1 = RecordBatch::try_new(
38         schema.clone(),
39         vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
40     )?;
41     let batch2 = RecordBatch::try_new(
42         schema.clone(),
43         vec![Arc::new(Float32Array::from(vec![64.0]))],
44     )?;
45 
46     // declare a new context. In spark API, this corresponds to a new spark SQLsession
47     let mut ctx = ExecutionContext::new();
48 
49     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
50     let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
51     ctx.register_table("t", Arc::new(provider));
52     Ok(ctx)
53 }
54 
/// Per-group state of the geometric-mean UDAF.
///
/// A UDAF keeps state across many rows, so it needs a `struct` to hold it: here,
/// how many non-null values have been seen and their running product.
#[derive(Debug)]
struct GeometricMean {
    /// Number of non-null values multiplied into `prod` so far.
    n: u32,
    /// Running product of the non-null values seen so far.
    prod: f64,
}

impl GeometricMean {
    /// Creates an empty accumulator: no values seen yet, and a running product
    /// equal to the multiplicative identity.
    pub fn new() -> Self {
        Self { n: 0, prod: 1.0 }
    }
}
68 
69 // UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
70 // to use them.
71 impl Accumulator for GeometricMean {
72     // this function serializes our state to `ScalarValue`, which DataFusion uses
73     // to pass this state between execution stages.
74     // Note that this can be arbitrary data.
state(&self) -> Result<Vec<ScalarValue>>75     fn state(&self) -> Result<Vec<ScalarValue>> {
76         Ok(vec![
77             ScalarValue::from(self.prod),
78             ScalarValue::from(self.n),
79         ])
80     }
81 
82     // this function receives one entry per argument of this accumulator.
83     // DataFusion calls this function on every row, and expects this function to update the accumulator's state.
update(&mut self, values: &[ScalarValue]) -> Result<()>84     fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
85         // this is a one-argument UDAF, and thus we use `0`.
86         let value = &values[0];
87         match value {
88             // here we map `ScalarValue` to our internal state. `Float64` indicates that this function
89             // only accepts Float64 as its argument (DataFusion does try to coerce arguments to this type)
90             //
91             // Note that `.map` here ensures that we ignore Nulls.
92             ScalarValue::Float64(e) => e.map(|value| {
93                 self.prod *= value;
94                 self.n += 1;
95             }),
96             _ => unreachable!(""),
97         };
98         Ok(())
99     }
100 
101     // this function receives states from other accumulators (Vec<ScalarValue>)
102     // and updates the accumulator.
merge(&mut self, states: &[ScalarValue]) -> Result<()>103     fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
104         let prod = &states[0];
105         let n = &states[1];
106         match (prod, n) {
107             (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) => {
108                 self.prod *= prod;
109                 self.n += n;
110             }
111             _ => unreachable!(""),
112         };
113         Ok(())
114     }
115 
116     // DataFusion expects this function to return the final value of this aggregator.
117     // in this case, this is the formula of the geometric mean
evaluate(&self) -> Result<ScalarValue>118     fn evaluate(&self) -> Result<ScalarValue> {
119         let value = self.prod.powf(1.0 / self.n as f64);
120         Ok(ScalarValue::from(value))
121     }
122 
123     // Optimization hint: this trait also supports `update_batch` and `merge_batch`,
124     // that can be used to perform these operations on arrays instead of single values.
125     // By default, these methods call `update` and `merge` row by row
126 }
127 
128 #[tokio::main]
main() -> Result<()>129 async fn main() -> Result<()> {
130     let ctx = create_context()?;
131 
132     // here is where we define the UDAF. We also declare its signature:
133     let geometric_mean = create_udaf(
134         // the name; used to represent it in plan descriptions and in the registry, to use in SQL.
135         "geo_mean",
136         // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
137         DataType::Float64,
138         // the return type; DataFusion expects this to match the type returned by `evaluate`.
139         Arc::new(DataType::Float64),
140         // This is the accumulator factory; DataFusion uses it to create new accumulators.
141         Arc::new(|| Ok(Box::new(GeometricMean::new()))),
142         // This is the description of the state. `state()` must match the types here.
143         Arc::new(vec![DataType::Float64, DataType::UInt32]),
144     );
145 
146     // get a DataFrame from the context
147     // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
148     let df = ctx.table("t")?;
149 
150     // perform the aggregation
151     let df = df.aggregate(&[], &[geometric_mean.call(vec![col("a")])])?;
152 
153     // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
154 
155     // execute the query
156     let results = df.collect().await?;
157 
158     // downcast the array to the expected type
159     let result = results[0]
160         .column(0)
161         .as_any()
162         .downcast_ref::<Float64Array>()
163         .unwrap();
164 
165     // verify that the calculation is correct
166     assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
167     println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
168 
169     Ok(())
170 }
171