// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Implementation of the DataFrame API
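//!
//! A short usage sketch (hedged: the CSV path and column names are
//! illustrative, and the `datafusion::prelude` re-exports and
//! `ExecutionContext::read_csv` are assumed to be available):
//!
//! ```no_run
//! # use datafusion::prelude::*;
//! # use datafusion::error::Result;
//! # fn main() -> Result<()> {
//! let mut ctx = ExecutionContext::new();
//! // build a logical plan lazily; nothing is executed yet
//! let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
//! let df = df.filter(col("a").lt_eq(col("b")))?
//!     .aggregate(&[col("a")], &[min(col("b"))])?
//!     .limit(100)?;
//! // the plan can then be executed with `df.collect().await`
//! # Ok(())
//! # }
//! ```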

use std::sync::{Arc, Mutex};

use crate::arrow::record_batch::RecordBatch;
use crate::error::Result;
use crate::execution::context::{ExecutionContext, ExecutionContextState};
use crate::logical_plan::{
    col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder,
    Partitioning,
};
use crate::{
    dataframe::*,
    physical_plan::{collect, collect_partitioned},
};

use async_trait::async_trait;

/// Implementation of the DataFrame API
pub struct DataFrameImpl {
    ctx_state: Arc<Mutex<ExecutionContextState>>,
    plan: LogicalPlan,
}

impl DataFrameImpl {
    /// Create a new DataFrame based on an existing logical plan
    pub fn new(ctx_state: Arc<Mutex<ExecutionContextState>>, plan: &LogicalPlan) -> Self {
        Self {
            ctx_state,
            plan: plan.clone(),
        }
    }
}

#[async_trait]
impl DataFrame for DataFrameImpl {
    /// Apply a projection based on a list of column names
    fn select_columns(&self, columns: &[&str]) -> Result<Arc<dyn DataFrame>> {
        let fields = columns
            .iter()
            .map(|name| self.plan.schema().field_with_unqualified_name(name))
            .collect::<Result<Vec<_>>>()?;
        let expr: Vec<Expr> = fields.iter().map(|f| col(f.name())).collect();
        self.select(&expr)
    }

    /// Create a projection based on arbitrary expressions
    fn select(&self, expr_list: &[Expr]) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .project(expr_list)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Create a filter based on a predicate expression
    fn filter(&self, predicate: Expr) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .filter(predicate)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Perform an aggregate query
    fn aggregate(
        &self,
        group_expr: &[Expr],
        aggr_expr: &[Expr],
    ) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .aggregate(group_expr, aggr_expr)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Limit the number of rows
    fn limit(&self, n: usize) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan).limit(n)?.build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Sort by specified sorting expressions
    fn sort(&self, expr: &[Expr]) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan).sort(expr)?.build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Join with another DataFrame
    fn join(
        &self,
        right: Arc<dyn DataFrame>,
        join_type: JoinType,
        left_cols: &[&str],
        right_cols: &[&str],
    ) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .join(&right.to_logical_plan(), join_type, left_cols, right_cols)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Repartition the DataFrame based on the given partitioning scheme
    fn repartition(
        &self,
        partitioning_scheme: Partitioning,
    ) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .repartition(partitioning_scheme)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Convert to logical plan
    fn to_logical_plan(&self) -> LogicalPlan {
        self.plan.clone()
    }

    // Convert the logical plan represented by this DataFrame into a physical plan and
    // execute it, collecting all resulting batches into memory
    async fn collect(&self) -> Result<Vec<RecordBatch>> {
        let state = self.ctx_state.lock().unwrap().clone();
        let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
        let plan = ctx.optimize(&self.plan)?;
        let plan = ctx.create_physical_plan(&plan)?;
        Ok(collect(plan).await?)
    }

    // Convert the logical plan represented by this DataFrame into a physical plan and
    // execute it, returning the batches of each partition separately
    async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>> {
        let state = self.ctx_state.lock().unwrap().clone();
        let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
        let plan = ctx.optimize(&self.plan)?;
        let plan = ctx.create_physical_plan(&plan)?;
        Ok(collect_partitioned(plan).await?)
    }

    /// Returns the schema from the logical plan
    fn schema(&self) -> &DFSchema {
        self.plan.schema()
    }

    /// Return a DataFrame containing the explanation of this plan
    fn explain(&self, verbose: bool) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .explain(verbose)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Return a FunctionRegistry that can be used to look up registered
    /// functions such as UDFs
    fn registry(&self) -> Arc<dyn FunctionRegistry> {
        let registry = self.ctx_state.lock().unwrap().clone();
        Arc::new(registry)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::context::ExecutionContext;
    use crate::logical_plan::*;
    use crate::{datasource::csv::CsvReadOptions, physical_plan::ColumnarValue};
    use crate::{physical_plan::functions::ScalarFunctionImplementation, test};
    use arrow::datatypes::DataType;

    #[test]
    fn select_columns() -> Result<()> {
        // build plan using Table API
        let t = test_table()?;
        let t2 = t.select_columns(&["c1", "c2", "c11"])?;
        let plan = t2.to_logical_plan();

        // build query using SQL
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn select_expr() -> Result<()> {
        // build plan using Table API
        let t = test_table()?;
        let t2 = t.select(&[col("c1"), col("c2"), col("c11")])?;
        let plan = t2.to_logical_plan();

        // build query using SQL
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn aggregate() -> Result<()> {
        // build plan using DataFrame API
        let df = test_table()?;
        let group_expr = &[col("c1")];
        let aggr_expr = &[
            min(col("c12")),
            max(col("c12")),
            avg(col("c12")),
            sum(col("c12")),
            count(col("c12")),
            count_distinct(col("c12")),
        ];

        let df = df.aggregate(group_expr, aggr_expr)?;

        let plan = df.to_logical_plan();

        // build the same plan using SQL
        let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12), COUNT(DISTINCT c12) \
                   FROM aggregate_test_100 \
                   GROUP BY c1";
        let sql_plan = create_plan(sql)?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[tokio::test]
    async fn join() -> Result<()> {
        let left = test_table()?.select_columns(&["c1", "c2"])?;
        let right = test_table()?.select_columns(&["c1", "c3"])?;
        let left_rows = left.collect().await?;
        let right_rows = right.collect().await?;
        let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?;
        let join_rows = join.collect().await?;
        assert_eq!(1, left_rows.len());
        assert_eq!(100, left_rows[0].num_rows());
        assert_eq!(1, right_rows.len());
        assert_eq!(100, right_rows[0].num_rows());
        assert_eq!(1, join_rows.len());
        assert_eq!(2008, join_rows[0].num_rows());
        Ok(())
    }

    #[test]
    fn limit() -> Result<()> {
        // build query using Table API
        let t = test_table()?;
        let t2 = t.select_columns(&["c1", "c2", "c11"])?.limit(10)?;
        let plan = t2.to_logical_plan();

        // build query using SQL
        let sql_plan =
            create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn explain() -> Result<()> {
        // build query using Table API
        let df = test_table()?;
        let df = df
            .select_columns(&["c1", "c2", "c11"])?
            .limit(10)?
            .explain(false)?;
        let plan = df.to_logical_plan();

        // build query using SQL
        let sql_plan =
            create_plan("EXPLAIN SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn registry() -> Result<()> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;

        // declare the udf
        let my_fn: ScalarFunctionImplementation =
            Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not implemented"));

        // create and register the udf
        ctx.register_udf(create_udf(
            "my_fn",
            vec![DataType::Float64],
            Arc::new(DataType::Float64),
            my_fn,
        ));

        // build query with a UDF using DataFrame API
        let df = ctx.table("aggregate_test_100")?;

        let f = df.registry();

        let df = df.select(&[f.udf("my_fn")?.call(vec![col("c12")])])?;
        let plan = df.to_logical_plan();

        // build query using SQL
        let sql_plan =
            ctx.create_logical_plan("SELECT my_fn(c12) FROM aggregate_test_100")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[tokio::test]
    async fn sendable() {
        let df = test_table().unwrap();
        // dataframes should be sendable between threads/tasks
        let task = tokio::task::spawn(async move {
            df.select_columns(&["c1"])
                .expect("should be usable in a task")
        });
        task.await.expect("task completed successfully");
    }
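
    // A hedged sketch of a `filter` round-trip test, following the same
    // plan-comparison pattern as the tests above; the predicate on `c2` is
    // illustrative, and `lit(3_i64)` assumes the SQL planner types bare
    // integer literals as Int64.
    #[test]
    fn filter_expr() -> Result<()> {
        // build plan using Table API: filter before projecting so the
        // projection ends up above the filter, as the SQL planner builds it
        let t = test_table()?;
        let t2 = t
            .filter(col("c2").eq(lit(3_i64)))?
            .select_columns(&["c1", "c2"])?;
        let plan = t2.to_logical_plan();

        // build query using SQL
        let sql_plan =
            create_plan("SELECT c1, c2 FROM aggregate_test_100 WHERE c2 = 3")?;

        // the two plans should be identical
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }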

    /// Compare the formatted string representation of two plans for equality
    fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
        assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
    }

    /// Create a logical plan from a SQL query
    fn create_plan(sql: &str) -> Result<LogicalPlan> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;
        ctx.create_logical_plan(sql)
    }

    /// Create a DataFrame over the registered aggregate_test_100 table
    fn test_table() -> Result<Arc<dyn DataFrame + 'static>> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;
        ctx.table("aggregate_test_100")
    }

    /// Register the aggregate_test_100 CSV test data with the given context
    fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
        let schema = test::aggr_test_schema();
        let testdata = arrow::util::test_util::arrow_test_data();
        ctx.register_csv(
            "aggregate_test_100",
            &format!("{}/csv/aggregate_test_100.csv", testdata),
            CsvReadOptions::new().schema(&schema.as_ref()),
        )?;
        Ok(())
    }
}