// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Implementation of DataFrame API

use std::sync::{Arc, Mutex};

use crate::arrow::record_batch::RecordBatch;
use crate::error::Result;
use crate::execution::context::{ExecutionContext, ExecutionContextState};
use crate::logical_plan::{
    col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder,
    Partitioning,
};
use crate::{
    dataframe::*,
    physical_plan::{collect, collect_partitioned},
};

use async_trait::async_trait;

/// Implementation of DataFrame API
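///
/// A `DataFrameImpl` is normally obtained from an `ExecutionContext` rather
/// than constructed directly. A minimal sketch (the table name and CSV path
/// here are hypothetical; any CSV whose schema can be inferred works):
///
/// ```no_run
/// # use datafusion::error::Result;
/// # fn main() -> Result<()> {
/// use datafusion::execution::context::ExecutionContext;
/// use datafusion::datasource::csv::CsvReadOptions;
/// use datafusion::logical_plan::col;
///
/// let mut ctx = ExecutionContext::new();
/// // register a CSV file as a table, then work with it as a DataFrame
/// ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new())?;
/// let df = ctx.table("example")?.select(&[col("c1")])?.limit(10)?;
/// # Ok(())
/// # }
/// ```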
pub struct DataFrameImpl {
    ctx_state: Arc<Mutex<ExecutionContextState>>,
    plan: LogicalPlan,
}

impl DataFrameImpl {
    /// Create a new DataFrame based on an existing logical plan
    pub fn new(ctx_state: Arc<Mutex<ExecutionContextState>>, plan: &LogicalPlan) -> Self {
        Self {
            ctx_state,
            plan: plan.clone(),
        }
    }
}

#[async_trait]
impl DataFrame for DataFrameImpl {
    /// Apply a projection based on a list of column names
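    ///
    /// A sketch of typical use (assumes a registered table `t` with columns
    /// `c1` and `c2`; the table name and path are hypothetical):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// let df = ctx.table("t")?;
    /// // keep only the named columns, in the given order
    /// let df = df.select_columns(&["c1", "c2"])?;
    /// # Ok(())
    /// # }
    /// ```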
    fn select_columns(&self, columns: &[&str]) -> Result<Arc<dyn DataFrame>> {
        let fields = columns
            .iter()
            .map(|name| self.plan.schema().field_with_unqualified_name(name))
            .collect::<Result<Vec<_>>>()?;
        let expr: Vec<Expr> = fields.iter().map(|f| col(f.name())).collect();
        self.select(&expr)
    }

    /// Create a projection based on arbitrary expressions
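    ///
    /// Unlike `select_columns`, this accepts arbitrary expressions. A sketch
    /// (assumes a registered table `t`; `alias` renames the output column):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # use datafusion::logical_plan::col;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// let df = ctx.table("t")?;
    /// let df = df.select(&[col("c1"), col("c12").alias("c12_renamed")])?;
    /// # Ok(())
    /// # }
    /// ```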
    fn select(&self, expr_list: &[Expr]) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .project(expr_list)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Create a filter based on a predicate expression
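    ///
    /// Only rows for which the predicate evaluates to `true` are kept. A
    /// sketch, assuming the `lit` helper and the `Expr` comparison methods
    /// from `logical_plan`:
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # use datafusion::logical_plan::{col, lit};
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// let df = ctx.table("t")?;
    /// // equivalent to: SELECT * FROM t WHERE c2 = 4
    /// let df = df.filter(col("c2").eq(lit(4)))?;
    /// # Ok(())
    /// # }
    /// ```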
    fn filter(&self, predicate: Expr) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .filter(predicate)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Perform an aggregate query
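    ///
    /// A sketch grouping by one column and computing two aggregates (mirrors
    /// the `aggregate` test below):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # use datafusion::logical_plan::{col, max, min};
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// let df = ctx.table("t")?;
    /// // equivalent to: SELECT c1, MIN(c12), MAX(c12) FROM t GROUP BY c1
    /// let df = df.aggregate(&[col("c1")], &[min(col("c12")), max(col("c12"))])?;
    /// # Ok(())
    /// # }
    /// ```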
    fn aggregate(
        &self,
        group_expr: &[Expr],
        aggr_expr: &[Expr],
    ) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .aggregate(group_expr, aggr_expr)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Limit the number of rows
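    ///
    /// For example (assumes a registered table `t`):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// // equivalent to: SELECT * FROM t LIMIT 10
    /// let df = ctx.table("t")?.limit(10)?;
    /// # Ok(())
    /// # }
    /// ```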
    fn limit(&self, n: usize) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan).limit(n)?.build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Sort by specified sorting expressions
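    ///
    /// Each sort expression is typically built with `Expr::sort`, which takes
    /// `asc` and `nulls_first` flags. A sketch:
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # use datafusion::logical_plan::col;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// // sort ascending, with NULLs first
    /// let df = ctx.table("t")?.sort(&[col("c1").sort(true, true)])?;
    /// # Ok(())
    /// # }
    /// ```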
    fn sort(&self, expr: &[Expr]) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan).sort(expr)?.build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Join with another DataFrame
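    ///
    /// `left_cols` and `right_cols` are matched pairwise as equi-join keys.
    /// A sketch using two hypothetical registered tables `t1` and `t2`:
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # use datafusion::logical_plan::JoinType;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t1", "tests/example.csv", CsvReadOptions::new())?;
    /// # ctx.register_csv("t2", "tests/example.csv", CsvReadOptions::new())?;
    /// let left = ctx.table("t1")?;
    /// let right = ctx.table("t2")?;
    /// // equivalent to: SELECT * FROM t1 JOIN t2 ON t1.c1 = t2.c1
    /// let joined = left.join(right, JoinType::Inner, &["c1"], &["c1"])?;
    /// # Ok(())
    /// # }
    /// ```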
    fn join(
        &self,
        right: Arc<dyn DataFrame>,
        join_type: JoinType,
        left_cols: &[&str],
        right_cols: &[&str],
    ) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .join(&right.to_logical_plan(), join_type, left_cols, right_cols)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

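    /// Repartition a DataFrame based on a logical partitioning scheme. A
    /// sketch using round-robin partitioning:
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # use datafusion::logical_plan::Partitioning;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// // redistribute batches across 4 partitions, round-robin
    /// let df = ctx.table("t")?.repartition(Partitioning::RoundRobinBatch(4))?;
    /// # Ok(())
    /// # }
    /// ```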
    fn repartition(
        &self,
        partitioning_scheme: Partitioning,
    ) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .repartition(partitioning_scheme)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Convert to logical plan
    fn to_logical_plan(&self) -> LogicalPlan {
        self.plan.clone()
    }

    /// Convert the logical plan represented by this DataFrame into a physical
    /// plan and execute it, collecting all resulting batches into a single vector
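    ///
    /// A sketch (execution is asynchronous, so this assumes an async runtime
    /// such as tokio is available):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// let df = ctx.table("t")?;
    /// // run the query and gather all results into one Vec<RecordBatch>
    /// let batches = df.collect().await?;
    /// # Ok(())
    /// # }
    /// ```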
    async fn collect(&self) -> Result<Vec<RecordBatch>> {
        let state = self.ctx_state.lock().unwrap().clone();
        let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
        let plan = ctx.optimize(&self.plan)?;
        let plan = ctx.create_physical_plan(&plan)?;
        Ok(collect(plan).await?)
    }

    /// Convert the logical plan represented by this DataFrame into a physical
    /// plan and execute it, collecting the results separately for each partition
    async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>> {
        let state = self.ctx_state.lock().unwrap().clone();
        let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
        let plan = ctx.optimize(&self.plan)?;
        let plan = ctx.create_physical_plan(&plan)?;
        Ok(collect_partitioned(plan).await?)
    }

    /// Returns the schema from the logical plan
    fn schema(&self) -> &DFSchema {
        self.plan.schema()
    }

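    /// Create a plan that explains this one: collecting the returned DataFrame
    /// yields the plan as text rather than query results. A sketch:
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// // equivalent to: EXPLAIN SELECT * FROM t LIMIT 10
    /// let df = ctx.table("t")?.limit(10)?.explain(false)?;
    /// # Ok(())
    /// # }
    /// ```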
    fn explain(&self, verbose: bool) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .explain(verbose)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

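    /// Return a snapshot of the functions registered with this DataFrame's
    /// context, usable for looking up UDFs in expressions. A sketch (assumes
    /// a UDF named "my_fn" was registered beforehand, as in the `registry`
    /// test below):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # fn main() -> Result<()> {
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::datasource::csv::CsvReadOptions;
    /// # use datafusion::logical_plan::col;
    /// # let mut ctx = ExecutionContext::new();
    /// # ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())?;
    /// let df = ctx.table("t")?;
    /// let registry = df.registry();
    /// // build an expression that calls the UDF, then project it
    /// let expr = registry.udf("my_fn")?.call(vec![col("c12")]);
    /// let df = df.select(&[expr])?;
    /// # Ok(())
    /// # }
    /// ```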
    fn registry(&self) -> Arc<dyn FunctionRegistry> {
        let registry = self.ctx_state.lock().unwrap().clone();
        Arc::new(registry)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::context::ExecutionContext;
    use crate::logical_plan::*;
    use crate::{datasource::csv::CsvReadOptions, physical_plan::ColumnarValue};
    use crate::{physical_plan::functions::ScalarFunctionImplementation, test};
    use arrow::datatypes::DataType;

    #[test]
    fn select_columns() -> Result<()> {
        // build plan using DataFrame API
        let t = test_table()?;
        let t2 = t.select_columns(&["c1", "c2", "c11"])?;
        let plan = t2.to_logical_plan();

        // build query using SQL
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn select_expr() -> Result<()> {
        // build plan using DataFrame API
        let t = test_table()?;
        let t2 = t.select(&[col("c1"), col("c2"), col("c11")])?;
        let plan = t2.to_logical_plan();

        // build query using SQL
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn aggregate() -> Result<()> {
        // build plan using DataFrame API
        let df = test_table()?;
        let group_expr = &[col("c1")];
        let aggr_expr = &[
            min(col("c12")),
            max(col("c12")),
            avg(col("c12")),
            sum(col("c12")),
            count(col("c12")),
            count_distinct(col("c12")),
        ];

        let df = df.aggregate(group_expr, aggr_expr)?;

        let plan = df.to_logical_plan();

        // build same plan using SQL API
        let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12), COUNT(DISTINCT c12) \
                   FROM aggregate_test_100 \
                   GROUP BY c1";
        let sql_plan = create_plan(sql)?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[tokio::test]
    async fn join() -> Result<()> {
        let left = test_table()?.select_columns(&["c1", "c2"])?;
        let right = test_table()?.select_columns(&["c1", "c3"])?;
        let left_rows = left.collect().await?;
        let right_rows = right.collect().await?;
        let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?;
        let join_rows = join.collect().await?;
        assert_eq!(1, left_rows.len());
        assert_eq!(100, left_rows[0].num_rows());
        assert_eq!(1, right_rows.len());
        assert_eq!(100, right_rows[0].num_rows());
        assert_eq!(1, join_rows.len());
        assert_eq!(2008, join_rows[0].num_rows());
        Ok(())
    }

    #[test]
    fn limit() -> Result<()> {
        // build query using DataFrame API
        let t = test_table()?;
        let t2 = t.select_columns(&["c1", "c2", "c11"])?.limit(10)?;
        let plan = t2.to_logical_plan();

        // build query using SQL
        let sql_plan =
            create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn explain() -> Result<()> {
        // build query using DataFrame API
        let df = test_table()?;
        let df = df
            .select_columns(&["c1", "c2", "c11"])?
            .limit(10)?
            .explain(false)?;
        let plan = df.to_logical_plan();

        // build query using SQL
        let sql_plan =
            create_plan("EXPLAIN SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn registry() -> Result<()> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;

        // declare the UDF
        let my_fn: ScalarFunctionImplementation =
            Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not implemented"));

        // create and register the UDF
        ctx.register_udf(create_udf(
            "my_fn",
            vec![DataType::Float64],
            Arc::new(DataType::Float64),
            my_fn,
        ));

        // build query with a UDF using DataFrame API
        let df = ctx.table("aggregate_test_100")?;

        let f = df.registry();

        let df = df.select(&[f.udf("my_fn")?.call(vec![col("c12")])])?;
        let plan = df.to_logical_plan();

        // build query using SQL
        let sql_plan =
            ctx.create_logical_plan("SELECT my_fn(c12) FROM aggregate_test_100")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[tokio::test]
    async fn sendable() {
        let df = test_table().unwrap();
        // dataframes should be sendable between threads/tasks
        let task = tokio::task::spawn(async move {
            df.select_columns(&["c1"])
                .expect("should be usable in a task")
        });
        task.await.expect("task completed successfully");
    }

    /// Compare the formatted string representation of two plans for equality
    fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
        assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
    }

    /// Create a logical plan from a SQL query
    fn create_plan(sql: &str) -> Result<LogicalPlan> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;
        ctx.create_logical_plan(sql)
    }

    /// Create a DataFrame over the aggregate_test_100 test table
    fn test_table() -> Result<Arc<dyn DataFrame + 'static>> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;
        ctx.table("aggregate_test_100")
    }

    /// Register the aggregate_test_100 CSV file with the given context
    fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
        let schema = test::aggr_test_schema();
        let testdata = arrow::util::test_util::arrow_test_data();
        ctx.register_csv(
            "aggregate_test_100",
            &format!("{}/csv/aggregate_test_100.csv", testdata),
            CsvReadOptions::new().schema(schema.as_ref()),
        )?;
        Ok(())
    }
}