1 // Licensed to the Apache Software Foundation (ASF) under one 2 // or more contributor license agreements. See the NOTICE file 3 // distributed with this work for additional information 4 // regarding copyright ownership. The ASF licenses this file 5 // to you under the Apache License, Version 2.0 (the 6 // "License"); you may not use this file except in compliance 7 // with the License. You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, 12 // software distributed under the License is distributed on an 13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 // KIND, either express or implied. See the License for the 15 // specific language governing permissions and limitations 16 // under the License. 17 18 //! ExecutionContext contains methods for registering data sources and executing queries 19 20 use std::collections::HashMap; 21 use std::fs; 22 use std::path::Path; 23 use std::string::String; 24 use std::sync::Arc; 25 use std::thread::{self, JoinHandle}; 26 27 use arrow::csv; 28 use arrow::datatypes::*; 29 use arrow::record_batch::RecordBatch; 30 31 use crate::datasource::csv::CsvFile; 32 use crate::datasource::parquet::ParquetTable; 33 use crate::datasource::TableProvider; 34 use crate::error::{ExecutionError, Result}; 35 use crate::execution::physical_plan::common; 36 use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; 37 use crate::execution::physical_plan::datasource::DatasourceExec; 38 use crate::execution::physical_plan::expressions::{ 39 Alias, Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, Sum, 40 }; 41 use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; 42 use crate::execution::physical_plan::limit::LimitExec; 43 use crate::execution::physical_plan::math_expressions::register_math_functions; 44 use crate::execution::physical_plan::memory::MemoryExec; 45 use crate::execution::physical_plan::merge::MergeExec; 46 use crate::execution::physical_plan::parquet::ParquetExec; 47 use crate::execution::physical_plan::projection::ProjectionExec; 48 use crate::execution::physical_plan::selection::SelectionExec; 49 use crate::execution::physical_plan::udf::{ScalarFunction, ScalarFunctionExpr}; 50 use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr}; 51 use crate::execution::table_impl::TableImpl; 52 use crate::logicalplan::*; 53 use crate::optimizer::optimizer::OptimizerRule; 54 use crate::optimizer::projection_push_down::ProjectionPushDown; 55 use crate::optimizer::resolve_columns::ResolveColumnsRule; 56 use crate::optimizer::type_coercion::TypeCoercionRule; 57 use crate::sql::parser::{DFASTNode, DFParser, FileType}; 58 use crate::sql::planner::{SchemaProvider, SqlToRel}; 59 use crate::table::Table; 60 use sqlparser::sqlast::{SQLColumnDef, SQLType}; 61 62 /// Execution context for registering data sources and executing queries 63 pub struct ExecutionContext { 64 datasources: HashMap<String, Box<dyn TableProvider>>, 65 scalar_functions: HashMap<String, Box<ScalarFunction>>, 66 } 67 68 impl ExecutionContext { 69 /// Create a new execution context for in-memory queries new() -> Self70 pub fn new() -> Self { 71 let mut ctx = Self { 72 datasources: HashMap::new(), 73 scalar_functions: HashMap::new(), 74 }; 75 register_math_functions(&mut ctx); 76 ctx 77 } 78 79 /// Execute a SQL query and produce a Relation (a schema-aware iterator over a series 80 /// of RecordBatch instances) sql(&mut self, sql: &str, batch_size: usize) -> Result<Vec<RecordBatch>>81 pub fn sql(&mut self, sql: &str, batch_size: usize) -> Result<Vec<RecordBatch>> { 82 let plan = self.create_logical_plan(sql)?; 83 84 return self.collect_plan(&plan, batch_size); 85 } 86 87 /// Executes a logical plan and produce a Relation (a schema-aware iterator over a series 88 /// of RecordBatch instances) collect_plan( &mut self, plan: &LogicalPlan, batch_size: usize, ) -> Result<Vec<RecordBatch>>89 pub fn collect_plan( 90 &mut self, 91 plan: &LogicalPlan, 92 batch_size: usize, 93 ) -> Result<Vec<RecordBatch>> { 94 match plan { 95 LogicalPlan::CreateExternalTable { 96 ref schema, 97 ref name, 98 ref location, 99 ref file_type, 100 ref has_header, 101 } => match file_type { 102 FileType::CSV => { 103 self.register_csv( 104 name, 105 location, 106 CsvReadOptions::new() 107 .schema(&schema) 108 .has_header(*has_header), 109 )?; 110 Ok(vec![]) 111 } 112 FileType::Parquet => { 113 self.register_parquet(name, location)?; 114 Ok(vec![]) 115 } 116 _ => Err(ExecutionError::ExecutionError(format!( 117 "Unsupported file type {:?}.", 118 file_type 119 ))), 120 }, 121 122 plan => { 123 let plan = self.optimize(&plan)?; 124 let plan = self.create_physical_plan(&plan, batch_size)?; 125 Ok(self.collect(plan.as_ref())?) 126 } 127 } 128 } 129 130 /// Creates a logical plan create_logical_plan(&mut self, sql: &str) -> Result<LogicalPlan>131 pub fn create_logical_plan(&mut self, sql: &str) -> Result<LogicalPlan> { 132 let ast = DFParser::parse_sql(String::from(sql))?; 133 134 match ast { 135 DFASTNode::ANSI(ansi) => { 136 let schema_provider = ExecutionContextSchemaProvider { 137 datasources: &self.datasources, 138 scalar_functions: &self.scalar_functions, 139 }; 140 141 // create a query planner 142 let query_planner = SqlToRel::new(schema_provider); 143 144 // plan the query (create a logical relational plan) 145 let plan = query_planner.sql_to_rel(&ansi)?; 146 147 Ok(plan) 148 } 149 DFASTNode::CreateExternalTable { 150 name, 151 columns, 152 file_type, 153 has_header, 154 location, 155 } => { 156 let schema = Box::new(self.build_schema(columns)?); 157 158 Ok(LogicalPlan::CreateExternalTable { 159 schema, 160 name, 161 location, 162 file_type, 163 has_header, 164 }) 165 } 166 } 167 } 168 169 /// Register a scalar UDF register_udf(&mut self, f: ScalarFunction)170 pub fn register_udf(&mut self, f: ScalarFunction) { 171 self.scalar_functions.insert(f.name.clone(), Box::new(f)); 172 } 173 174 /// Get a reference to the registered scalar functions scalar_functions(&self) -> &HashMap<String, Box<ScalarFunction>>175 pub fn scalar_functions(&self) -> &HashMap<String, Box<ScalarFunction>> { 176 &self.scalar_functions 177 } 178 build_schema(&self, columns: Vec<SQLColumnDef>) -> Result<Schema>179 fn build_schema(&self, columns: Vec<SQLColumnDef>) -> Result<Schema> { 180 let mut fields = Vec::new(); 181 182 for column in columns { 183 let data_type = self.make_data_type(column.data_type)?; 184 fields.push(Field::new(&column.name, data_type, column.allow_null)); 185 } 186 187 Ok(Schema::new(fields)) 188 } 189 make_data_type(&self, sql_type: SQLType) -> Result<DataType>190 fn make_data_type(&self, sql_type: SQLType) -> Result<DataType> { 191 match sql_type { 192 SQLType::BigInt => Ok(DataType::Int64), 193 SQLType::Int => Ok(DataType::Int32), 194 SQLType::SmallInt => Ok(DataType::Int16), 195 SQLType::Char(_) | SQLType::Varchar(_) | SQLType::Text => Ok(DataType::Utf8), 196 SQLType::Decimal(_, _) => Ok(DataType::Float64), 197 SQLType::Float(_) => Ok(DataType::Float32), 198 SQLType::Real | SQLType::Double => Ok(DataType::Float64), 199 SQLType::Boolean => Ok(DataType::Boolean), 200 SQLType::Date => Ok(DataType::Date64(DateUnit::Day)), 201 SQLType::Time => Ok(DataType::Time64(TimeUnit::Millisecond)), 202 SQLType::Timestamp => Ok(DataType::Date64(DateUnit::Millisecond)), 203 SQLType::Uuid 204 | SQLType::Clob(_) 205 | SQLType::Binary(_) 206 | SQLType::Varbinary(_) 207 | SQLType::Blob(_) 208 | SQLType::Regclass 209 | SQLType::Bytea 210 | SQLType::Custom(_) 211 | SQLType::Array(_) => Err(ExecutionError::General(format!( 212 "Unsupported data type: {:?}.", 213 sql_type 214 ))), 215 } 216 } 217 218 /// Register a CSV file as a table so that it can be queried from SQL register_csv( &mut self, name: &str, filename: &str, options: CsvReadOptions, ) -> Result<()>219 pub fn register_csv( 220 &mut self, 221 name: &str, 222 filename: &str, 223 options: CsvReadOptions, 224 ) -> Result<()> { 225 self.register_table(name, Box::new(CsvFile::try_new(filename, options)?)); 226 Ok(()) 227 } 228 229 /// Register a Parquet file as a table so that it can be queried from SQL register_parquet(&mut self, name: &str, filename: &str) -> Result<()>230 pub fn register_parquet(&mut self, name: &str, filename: &str) -> Result<()> { 231 let table = ParquetTable::try_new(&filename)?; 232 self.register_table(name, Box::new(table)); 233 Ok(()) 234 } 235 236 /// Register a table so that it can be queried from SQL register_table(&mut self, name: &str, provider: Box<dyn TableProvider>)237 pub fn register_table(&mut self, name: &str, provider: Box<dyn TableProvider>) { 238 self.datasources.insert(name.to_string(), provider); 239 } 240 241 /// Get a table by name table(&mut self, table_name: &str) -> Result<Arc<dyn Table>>242 pub fn table(&mut self, table_name: &str) -> Result<Arc<dyn Table>> { 243 match self.datasources.get(table_name) { 244 Some(provider) => { 245 let schema = provider.schema().as_ref().clone(); 246 let table_scan = LogicalPlan::TableScan { 247 schema_name: "".to_string(), 248 table_name: table_name.to_string(), 249 table_schema: Box::new(schema.to_owned()), 250 projected_schema: Box::new(schema), 251 projection: None, 252 }; 253 Ok(Arc::new(TableImpl::new( 254 &LogicalPlanBuilder::from(&table_scan).build()?, 255 ))) 256 } 257 _ => Err(ExecutionError::General(format!( 258 "No table named '{}'", 259 table_name 260 ))), 261 } 262 } 263 264 /// Optimize the logical plan by applying optimizer rules optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan>265 pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> { 266 let rules: Vec<Box<dyn OptimizerRule>> = vec![ 267 Box::new(ResolveColumnsRule::new()), 268 Box::new(ProjectionPushDown::new()), 269 Box::new(TypeCoercionRule::new(&self.scalar_functions)), 270 ]; 271 let mut plan = plan.clone(); 272 for mut rule in rules { 273 plan = rule.optimize(&plan)?; 274 } 275 Ok(plan) 276 } 277 278 /// Create a physical plan from a logical plan create_physical_plan( &mut self, logical_plan: &LogicalPlan, batch_size: usize, ) -> Result<Arc<dyn ExecutionPlan>>279 pub fn create_physical_plan( 280 &mut self, 281 logical_plan: &LogicalPlan, 282 batch_size: usize, 283 ) -> Result<Arc<dyn ExecutionPlan>> { 284 match logical_plan { 285 LogicalPlan::TableScan { 286 table_name, 287 projection, 288 .. 289 } => match self.datasources.get(table_name) { 290 Some(provider) => { 291 let partitions = provider.scan(projection, batch_size)?; 292 if partitions.is_empty() { 293 Err(ExecutionError::General( 294 "Table provider returned no partitions".to_string(), 295 )) 296 } else { 297 let partition = partitions[0].lock().unwrap(); 298 let schema = partition.schema(); 299 let exec = 300 DatasourceExec::new(schema.clone(), partitions.clone()); 301 Ok(Arc::new(exec)) 302 } 303 } 304 _ => Err(ExecutionError::General(format!( 305 "No table named {}", 306 table_name 307 ))), 308 }, 309 LogicalPlan::InMemoryScan { 310 data, 311 schema, 312 projection, 313 .. 314 } => Ok(Arc::new(MemoryExec::try_new( 315 data, 316 Arc::new(schema.as_ref().to_owned()), 317 projection.to_owned(), 318 )?)), 319 LogicalPlan::CsvScan { 320 path, 321 schema, 322 has_header, 323 delimiter, 324 projection, 325 .. 326 } => Ok(Arc::new(CsvExec::try_new( 327 path, 328 CsvReadOptions::new() 329 .schema(schema.as_ref()) 330 .delimiter_option(*delimiter) 331 .has_header(*has_header), 332 projection.to_owned(), 333 batch_size, 334 )?)), 335 LogicalPlan::ParquetScan { 336 path, projection, .. 337 } => Ok(Arc::new(ParquetExec::try_new( 338 path, 339 projection.to_owned(), 340 batch_size, 341 )?)), 342 LogicalPlan::Projection { input, expr, .. } => { 343 let input = self.create_physical_plan(input, batch_size)?; 344 let input_schema = input.as_ref().schema().clone(); 345 let runtime_expr = expr 346 .iter() 347 .map(|e| self.create_physical_expr(e, &input_schema)) 348 .collect::<Result<Vec<_>>>()?; 349 Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input)?)) 350 } 351 LogicalPlan::Aggregate { 352 input, 353 group_expr, 354 aggr_expr, 355 .. 356 } => { 357 // Initially need to perform the aggregate and then merge the partitions 358 let input = self.create_physical_plan(input, batch_size)?; 359 let input_schema = input.as_ref().schema().clone(); 360 361 let group_expr = group_expr 362 .iter() 363 .map(|e| self.create_physical_expr(e, &input_schema)) 364 .collect::<Result<Vec<_>>>()?; 365 let aggr_expr = aggr_expr 366 .iter() 367 .map(|e| self.create_aggregate_expr(e, &input_schema)) 368 .collect::<Result<Vec<_>>>()?; 369 370 let initial_aggr = 371 HashAggregateExec::try_new(group_expr, aggr_expr, input)?; 372 373 let schema = initial_aggr.schema(); 374 let partitions = initial_aggr.partitions()?; 375 376 if partitions.len() == 1 { 377 return Ok(Arc::new(initial_aggr)); 378 } 379 380 let (final_group, final_aggr) = initial_aggr.make_final_expr(); 381 382 let merge = Arc::new(MergeExec::new(schema.clone(), partitions)); 383 384 Ok(Arc::new(HashAggregateExec::try_new( 385 final_group, 386 final_aggr, 387 merge, 388 )?)) 389 } 390 LogicalPlan::Selection { input, expr, .. } => { 391 let input = self.create_physical_plan(input, batch_size)?; 392 let input_schema = input.as_ref().schema().clone(); 393 let runtime_expr = self.create_physical_expr(expr, &input_schema)?; 394 Ok(Arc::new(SelectionExec::try_new(runtime_expr, input)?)) 395 } 396 LogicalPlan::Limit { input, expr, .. } => { 397 let input = self.create_physical_plan(input, batch_size)?; 398 let input_schema = input.as_ref().schema().clone(); 399 400 match expr { 401 &Expr::Literal(ref scalar_value) => { 402 let limit: usize = match scalar_value { 403 ScalarValue::Int8(limit) if *limit >= 0 => { 404 Ok(*limit as usize) 405 } 406 ScalarValue::Int16(limit) if *limit >= 0 => { 407 Ok(*limit as usize) 408 } 409 ScalarValue::Int32(limit) if *limit >= 0 => { 410 Ok(*limit as usize) 411 } 412 ScalarValue::Int64(limit) if *limit >= 0 => { 413 Ok(*limit as usize) 414 } 415 ScalarValue::UInt8(limit) => Ok(*limit as usize), 416 ScalarValue::UInt16(limit) => Ok(*limit as usize), 417 ScalarValue::UInt32(limit) => Ok(*limit as usize), 418 ScalarValue::UInt64(limit) => Ok(*limit as usize), 419 _ => Err(ExecutionError::ExecutionError( 420 "Limit only supports non-negative integer literals" 421 .to_string(), 422 )), 423 }?; 424 Ok(Arc::new(LimitExec::new( 425 input_schema.clone(), 426 input.partitions()?, 427 limit, 428 ))) 429 } 430 _ => Err(ExecutionError::ExecutionError( 431 "Limit only supports non-negative integer literals".to_string(), 432 )), 433 } 434 } 435 _ => Err(ExecutionError::General( 436 "Unsupported logical plan variant".to_string(), 437 )), 438 } 439 } 440 441 /// Create a physical expression from a logical expression create_physical_expr( &self, e: &Expr, input_schema: &Schema, ) -> Result<Arc<dyn PhysicalExpr>>442 pub fn create_physical_expr( 443 &self, 444 e: &Expr, 445 input_schema: &Schema, 446 ) -> Result<Arc<dyn PhysicalExpr>> { 447 match e { 448 Expr::Alias(expr, name) => { 449 let expr = self.create_physical_expr(expr, input_schema)?; 450 Ok(Arc::new(Alias::new(expr, &name))) 451 } 452 Expr::Column(i) => { 453 Ok(Arc::new(Column::new(*i, &input_schema.field(*i).name()))) 454 } 455 Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), 456 Expr::BinaryExpr { left, op, right } => Ok(Arc::new(BinaryExpr::new( 457 self.create_physical_expr(left, input_schema)?, 458 op.clone(), 459 self.create_physical_expr(right, input_schema)?, 460 ))), 461 Expr::Cast { expr, data_type } => Ok(Arc::new(CastExpr::try_new( 462 self.create_physical_expr(expr, input_schema)?, 463 input_schema, 464 data_type.clone(), 465 )?)), 466 Expr::ScalarFunction { 467 name, 468 args, 469 return_type, 470 } => match &self.scalar_functions.get(name) { 471 Some(f) => { 472 let mut physical_args = vec![]; 473 for e in args { 474 physical_args.push(self.create_physical_expr(e, input_schema)?); 475 } 476 Ok(Arc::new(ScalarFunctionExpr::new( 477 name, 478 Box::new(f.fun.clone()), 479 physical_args, 480 return_type, 481 ))) 482 } 483 _ => Err(ExecutionError::General(format!( 484 "Invalid scalar function '{:?}'", 485 name 486 ))), 487 }, 488 other => Err(ExecutionError::NotImplemented(format!( 489 "Physical plan does not support logical expression {:?}", 490 other 491 ))), 492 } 493 } 494 495 /// Create an aggregate expression from a logical expression create_aggregate_expr( &self, e: &Expr, input_schema: &Schema, ) -> Result<Arc<dyn AggregateExpr>>496 pub fn create_aggregate_expr( 497 &self, 498 e: &Expr, 499 input_schema: &Schema, 500 ) -> Result<Arc<dyn AggregateExpr>> { 501 match e { 502 Expr::AggregateFunction { name, args, .. } => { 503 match name.to_lowercase().as_ref() { 504 "sum" => Ok(Arc::new(Sum::new( 505 self.create_physical_expr(&args[0], input_schema)?, 506 ))), 507 "avg" => Ok(Arc::new(Avg::new( 508 self.create_physical_expr(&args[0], input_schema)?, 509 ))), 510 "max" => Ok(Arc::new(Max::new( 511 self.create_physical_expr(&args[0], input_schema)?, 512 ))), 513 "min" => Ok(Arc::new(Min::new( 514 self.create_physical_expr(&args[0], input_schema)?, 515 ))), 516 "count" => Ok(Arc::new(Count::new( 517 self.create_physical_expr(&args[0], input_schema)?, 518 ))), 519 other => Err(ExecutionError::NotImplemented(format!( 520 "Unsupported aggregate function '{}'", 521 other 522 ))), 523 } 524 } 525 other => Err(ExecutionError::General(format!( 526 "Invalid aggregate expression '{:?}'", 527 other 528 ))), 529 } 530 } 531 532 /// Execute a physical plan and collect the results in memory collect(&self, plan: &dyn ExecutionPlan) -> Result<Vec<RecordBatch>>533 pub fn collect(&self, plan: &dyn ExecutionPlan) -> Result<Vec<RecordBatch>> { 534 let partitions = plan.partitions()?; 535 536 match partitions.len() { 537 0 => Ok(vec![]), 538 1 => { 539 let it = partitions[0].execute()?; 540 common::collect(it) 541 } 542 _ => { 543 // merge into a single partition 544 let plan = MergeExec::new(plan.schema().clone(), partitions); 545 let partitions = plan.partitions()?; 546 if partitions.len() == 1 { 547 common::collect(partitions[0].execute()?) 548 } else { 549 Err(ExecutionError::InternalError(format!( 550 "MergeExec returned {} partitions", 551 partitions.len() 552 ))) 553 } 554 } 555 } 556 } 557 558 /// Execute a query and write the results to a partitioned CSV file write_csv(&self, plan: &dyn ExecutionPlan, path: &str) -> Result<()>559 pub fn write_csv(&self, plan: &dyn ExecutionPlan, path: &str) -> Result<()> { 560 // create directory to contain the CSV files (one per partition) 561 let path = path.to_string(); 562 fs::create_dir(&path)?; 563 564 let threads: Vec<JoinHandle<Result<()>>> = plan 565 .partitions()? 566 .iter() 567 .enumerate() 568 .map(|(i, p)| { 569 let p = p.clone(); 570 let path = path.clone(); 571 thread::spawn(move || { 572 let filename = format!("part-{}.csv", i); 573 let path = Path::new(&path).join(&filename); 574 let file = fs::File::create(path)?; 575 let mut writer = csv::Writer::new(file); 576 let it = p.execute()?; 577 let mut it = it.lock().unwrap(); 578 loop { 579 match it.next() { 580 Ok(Some(batch)) => { 581 writer.write(&batch)?; 582 } 583 Ok(None) => break, 584 Err(e) => return Err(e), 585 } 586 } 587 Ok(()) 588 }) 589 }) 590 .collect(); 591 592 // combine the results from each thread 593 for thread in threads { 594 let join = thread.join().expect("Failed to join thread"); 595 join?; 596 } 597 598 Ok(()) 599 } 600 } 601 602 struct ExecutionContextSchemaProvider<'a> { 603 datasources: &'a HashMap<String, Box<dyn TableProvider>>, 604 scalar_functions: &'a HashMap<String, Box<ScalarFunction>>, 605 } 606 607 impl SchemaProvider for ExecutionContextSchemaProvider<'_> { get_table_meta(&self, name: &str) -> Option<Arc<Schema>>608 fn get_table_meta(&self, name: &str) -> Option<Arc<Schema>> { 609 self.datasources.get(name).map(|ds| ds.schema().clone()) 610 } 611 get_function_meta(&self, name: &str) -> Option<Arc<FunctionMeta>>612 fn get_function_meta(&self, name: &str) -> Option<Arc<FunctionMeta>> { 613 self.scalar_functions.get(name).map(|f| { 614 Arc::new(FunctionMeta::new( 615 name.to_owned(), 616 f.args.clone(), 617 f.return_type.clone(), 618 FunctionType::Scalar, 619 )) 620 }) 621 } 622 } 623 624 #[cfg(test)] 625 mod tests { 626 627 use super::*; 628 use crate::datasource::MemTable; 629 use crate::execution::physical_plan::udf::ScalarUdf; 630 use crate::test; 631 use arrow::array::{ArrayRef, Int32Array}; 632 use arrow::compute::add; 633 use std::fs::File; 634 use std::io::prelude::*; 635 use tempdir::TempDir; 636 637 #[test] parallel_projection() -> Result<()>638 fn parallel_projection() -> Result<()> { 639 let partition_count = 4; 640 let results = execute("SELECT c1, c2 FROM test", partition_count)?; 641 642 // there should be one batch per partition 643 assert_eq!(results.len(), partition_count); 644 645 // each batch should contain 2 columns and 10 rows 646 for batch in &results { 647 assert_eq!(batch.num_columns(), 2); 648 assert_eq!(batch.num_rows(), 10); 649 } 650 651 Ok(()) 652 } 653 654 #[test] parallel_selection() -> Result<()>655 fn parallel_selection() -> Result<()> { 656 let tmp_dir = TempDir::new("parallel_selection")?; 657 let partition_count = 4; 658 let mut ctx = create_ctx(&tmp_dir, partition_count)?; 659 660 let logical_plan = 661 ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?; 662 let logical_plan = ctx.optimize(&logical_plan)?; 663 664 let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?; 665 666 let results = ctx.collect(physical_plan.as_ref())?; 667 668 // there should be one batch per partition 669 assert_eq!(results.len(), partition_count); 670 671 let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum(); 672 assert_eq!(row_count, 20); 673 674 Ok(()) 675 } 676 677 #[test] aggregate() -> Result<()>678 fn aggregate() -> Result<()> { 679 let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4)?; 680 assert_eq!(results.len(), 1); 681 682 let batch = &results[0]; 683 let expected: Vec<&str> = vec!["60,220"]; 684 let mut rows = test::format_batch(&batch); 685 rows.sort(); 686 assert_eq!(rows, expected); 687 688 Ok(()) 689 } 690 691 #[test] aggregate_avg() -> Result<()>692 fn aggregate_avg() -> Result<()> { 693 let results = execute("SELECT AVG(c1), AVG(c2) FROM test", 4)?; 694 assert_eq!(results.len(), 1); 695 696 let batch = &results[0]; 697 let expected: Vec<&str> = vec!["1.5,5.5"]; 698 let mut rows = test::format_batch(&batch); 699 rows.sort(); 700 assert_eq!(rows, expected); 701 702 Ok(()) 703 } 704 705 #[test] aggregate_max() -> Result<()>706 fn aggregate_max() -> Result<()> { 707 let results = execute("SELECT MAX(c1), MAX(c2) FROM test", 4)?; 708 assert_eq!(results.len(), 1); 709 710 let batch = &results[0]; 711 let expected: Vec<&str> = vec!["3,10"]; 712 let mut rows = test::format_batch(&batch); 713 rows.sort(); 714 assert_eq!(rows, expected); 715 716 Ok(()) 717 } 718 719 #[test] aggregate_min() -> Result<()>720 fn aggregate_min() -> Result<()> { 721 let results = execute("SELECT MIN(c1), MIN(c2) FROM test", 4)?; 722 assert_eq!(results.len(), 1); 723 724 let batch = &results[0]; 725 let expected: Vec<&str> = vec!["0,1"]; 726 let mut rows = test::format_batch(&batch); 727 rows.sort(); 728 assert_eq!(rows, expected); 729 730 Ok(()) 731 } 732 733 #[test] aggregate_grouped() -> Result<()>734 fn aggregate_grouped() -> Result<()> { 735 let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4)?; 736 assert_eq!(results.len(), 1); 737 738 let batch = &results[0]; 739 let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"]; 740 let mut rows = test::format_batch(&batch); 741 rows.sort(); 742 assert_eq!(rows, expected); 743 744 Ok(()) 745 } 746 747 #[test] aggregate_grouped_avg() -> Result<()>748 fn aggregate_grouped_avg() -> Result<()> { 749 let results = execute("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4)?; 750 assert_eq!(results.len(), 1); 751 752 let batch = &results[0]; 753 let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"]; 754 let mut rows = test::format_batch(&batch); 755 rows.sort(); 756 assert_eq!(rows, expected); 757 758 Ok(()) 759 } 760 761 #[test] aggregate_grouped_max() -> Result<()>762 fn aggregate_grouped_max() -> Result<()> { 763 let results = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4)?; 764 assert_eq!(results.len(), 1); 765 766 let batch = &results[0]; 767 let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"]; 768 let mut rows = test::format_batch(&batch); 769 rows.sort(); 770 assert_eq!(rows, expected); 771 772 Ok(()) 773 } 774 775 #[test] aggregate_grouped_min() -> Result<()>776 fn aggregate_grouped_min() -> Result<()> { 777 let results = execute("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4)?; 778 assert_eq!(results.len(), 1); 779 780 let batch = &results[0]; 781 let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"]; 782 let mut rows = test::format_batch(&batch); 783 rows.sort(); 784 assert_eq!(rows, expected); 785 786 Ok(()) 787 } 788 789 #[test] count_basic() -> Result<()>790 fn count_basic() -> Result<()> { 791 let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1)?; 792 assert_eq!(results.len(), 1); 793 794 let batch = &results[0]; 795 let expected: Vec<&str> = vec!["10,10"]; 796 let mut rows = test::format_batch(&batch); 797 rows.sort(); 798 assert_eq!(rows, expected); 799 Ok(()) 800 } 801 802 #[test] count_partitioned() -> Result<()>803 fn count_partitioned() -> Result<()> { 804 let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 4)?; 805 assert_eq!(results.len(), 1); 806 807 let batch = &results[0]; 808 let expected: Vec<&str> = vec!["40,40"]; 809 let mut rows = test::format_batch(&batch); 810 rows.sort(); 811 assert_eq!(rows, expected); 812 Ok(()) 813 } 814 815 #[test] count_aggregated() -> Result<()>816 fn count_aggregated() -> Result<()> { 817 let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4)?; 818 assert_eq!(results.len(), 1); 819 820 let batch = &results[0]; 821 let expected = vec!["0,10", "1,10", "2,10", "3,10"]; 822 let mut rows = test::format_batch(&batch); 823 rows.sort(); 824 assert_eq!(rows, expected); 825 Ok(()) 826 } 827 828 #[test] aggregate_with_alias() -> Result<()>829 fn aggregate_with_alias() -> Result<()> { 830 let tmp_dir = TempDir::new("execute")?; 831 let mut ctx = create_ctx(&tmp_dir, 1)?; 832 833 let schema = Arc::new(Schema::new(vec![ 834 Field::new("state", DataType::Utf8, false), 835 Field::new("salary", DataType::UInt32, false), 836 ])); 837 838 let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)? 839 .aggregate( 840 vec![col("state")], 841 vec![aggregate_expr("SUM", col("salary"), DataType::UInt32)], 842 )? 843 .project(vec![col("state"), col_index(1).alias("total_salary")])? 844 .build()?; 845 846 let plan = ctx.optimize(&plan)?; 847 848 let physical_plan = ctx.create_physical_plan(&Arc::new(plan), 1024)?; 849 assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); 850 assert_eq!( 851 "total_salary", 852 physical_plan.schema().field(1).name().as_str() 853 ); 854 Ok(()) 855 } 856 857 #[test] write_csv_results() -> Result<()>858 fn write_csv_results() -> Result<()> { 859 // create partitioned input file and context 860 let tmp_dir = TempDir::new("write_csv_results_temp")?; 861 let mut ctx = create_ctx(&tmp_dir, 4)?; 862 863 // execute a simple query and write the results to CSV 864 let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; 865 write_csv(&mut ctx, "SELECT c1, c2 FROM test", &out_dir)?; 866 867 // create a new context and verify that the results were saved to a partitioned csv file 868 let mut ctx = ExecutionContext::new(); 869 870 let schema = Arc::new(Schema::new(vec![ 871 Field::new("c1", DataType::UInt32, false), 872 Field::new("c2", DataType::UInt64, false), 873 ])); 874 875 // register each partition as well as the top level dir 876 let csv_read_option = CsvReadOptions::new().schema(&schema); 877 ctx.register_csv("part0", &format!("{}/part-0.csv", out_dir), csv_read_option)?; 878 ctx.register_csv("part1", &format!("{}/part-1.csv", out_dir), csv_read_option)?; 879 ctx.register_csv("part2", &format!("{}/part-2.csv", out_dir), csv_read_option)?; 880 ctx.register_csv("part3", &format!("{}/part-3.csv", out_dir), csv_read_option)?; 881 ctx.register_csv("allparts", &out_dir, csv_read_option)?; 882 883 let part0 = collect(&mut ctx, "SELECT c1, c2 FROM part0")?; 884 let part1 = collect(&mut ctx, "SELECT c1, c2 FROM part1")?; 885 let part2 = collect(&mut ctx, "SELECT c1, c2 FROM part2")?; 886 let part3 = collect(&mut ctx, "SELECT c1, c2 FROM part3")?; 887 let allparts = collect(&mut ctx, "SELECT c1, c2 FROM allparts")?; 888 889 let part0_count: usize = part0.iter().map(|batch| batch.num_rows()).sum(); 890 let part1_count: usize = part1.iter().map(|batch| batch.num_rows()).sum(); 891 let part2_count: usize = part2.iter().map(|batch| batch.num_rows()).sum(); 892 let part3_count: usize = part3.iter().map(|batch| batch.num_rows()).sum(); 893 let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum(); 894 895 assert_eq!(part0_count, 10); 896 assert_eq!(part1_count, 10); 897 assert_eq!(part2_count, 10); 898 assert_eq!(part3_count, 10); 899 assert_eq!(allparts_count, 40); 900 901 Ok(()) 902 } 903 904 #[test] scalar_udf() -> Result<()>905 fn scalar_udf() -> Result<()> { 906 let schema = Schema::new(vec![ 907 Field::new("a", DataType::Int32, false), 908 Field::new("b", DataType::Int32, false), 909 ]); 910 911 let batch = RecordBatch::try_new( 912 Arc::new(schema.clone()), 913 vec![ 914 Arc::new(Int32Array::from(vec![1, 10, 10, 100])), 915 Arc::new(Int32Array::from(vec![2, 12, 12, 120])), 916 ], 917 )?; 918 919 let mut ctx = ExecutionContext::new(); 920 921 let provider = MemTable::new(Arc::new(schema), vec![vec![batch]])?; 922 ctx.register_table("t", Box::new(provider)); 923 924 let myfunc: ScalarUdf = |args: &[ArrayRef]| { 925 let l = &args[0] 926 .as_any() 927 .downcast_ref::<Int32Array>() 928 .expect("cast failed"); 929 let r = &args[1] 930 .as_any() 931 .downcast_ref::<Int32Array>() 932 .expect("cast failed"); 933 Ok(Arc::new(add(l, r)?)) 934 }; 935 936 let my_add = ScalarFunction::new( 937 "my_add", 938 vec![ 939 Field::new("a", DataType::Int32, true), 940 Field::new("b", DataType::Int32, true), 941 ], 942 DataType::Int32, 943 myfunc, 944 ); 945 946 ctx.register_udf(my_add); 947 948 let t = ctx.table("t")?; 949 950 let plan = LogicalPlanBuilder::from(&t.to_logical_plan()) 951 .project(vec![ 952 col("a"), 953 col("b"), 954 scalar_function("my_add", vec![col("a"), col("b")], DataType::Int32), 955 ])? 956 .build()?; 957 958 assert_eq!( 959 format!("{:?}", plan), 960 "Projection: #a, #b, my_add(#a, #b)\n TableScan: t projection=None" 961 ); 962 963 let plan = ctx.optimize(&plan)?; 964 let plan = ctx.create_physical_plan(&plan, 1024)?; 965 let result = ctx.collect(plan.as_ref())?; 966 967 let batch = &result[0]; 968 assert_eq!(3, batch.num_columns()); 969 assert_eq!(4, batch.num_rows()); 970 971 let a = batch 972 .column(0) 973 .as_any() 974 .downcast_ref::<Int32Array>() 975 .expect("failed to cast a"); 976 let b = batch 977 .column(1) 978 .as_any() 979 .downcast_ref::<Int32Array>() 980 .expect("failed to cast b"); 981 let sum = batch 982 .column(2) 983 .as_any() 984 .downcast_ref::<Int32Array>() 985 .expect("failed to cast sum"); 986 987 assert_eq!(4, a.len()); 988 assert_eq!(4, b.len()); 989 assert_eq!(4, sum.len()); 990 for i in 0..sum.len() { 991 assert_eq!(a.value(i) + b.value(i), sum.value(i)); 992 } 993 994 Ok(()) 995 } 996 997 /// Execute SQL and return results collect(ctx: &mut ExecutionContext, sql: &str) -> Result<Vec<RecordBatch>>998 fn collect(ctx: &mut ExecutionContext, sql: &str) -> Result<Vec<RecordBatch>> { 999 let logical_plan = ctx.create_logical_plan(sql)?; 1000 let logical_plan = ctx.optimize(&logical_plan)?; 1001 let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?; 1002 ctx.collect(physical_plan.as_ref()) 1003 } 1004 1005 /// Execute SQL and return results execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>>1006 fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> { 1007 let tmp_dir = TempDir::new("execute")?; 1008 let mut ctx = create_ctx(&tmp_dir, partition_count)?; 1009 collect(&mut ctx, sql) 1010 } 1011 1012 /// Execute SQL and write results to partitioned csv files write_csv(ctx: &mut ExecutionContext, sql: &str, out_dir: &str) -> Result<()>1013 fn write_csv(ctx: &mut ExecutionContext, sql: &str, out_dir: &str) -> Result<()> { 1014 let logical_plan = ctx.create_logical_plan(sql)?; 1015 let logical_plan = ctx.optimize(&logical_plan)?; 1016 let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?; 1017 ctx.write_csv(physical_plan.as_ref(), out_dir) 1018 } 1019 1020 /// Generate a partitioned CSV file and register it with an execution context create_ctx(tmp_dir: &TempDir, partition_count: usize) -> Result<ExecutionContext>1021 fn create_ctx(tmp_dir: &TempDir, partition_count: usize) -> Result<ExecutionContext> { 1022 let mut ctx = ExecutionContext::new(); 1023 1024 // define schema for data source (csv file) 1025 let schema = Arc::new(Schema::new(vec![ 1026 Field::new("c1", DataType::UInt32, false), 1027 Field::new("c2", DataType::UInt64, false), 1028 ])); 1029 1030 // generate a partitioned file 1031 for partition in 0..partition_count { 1032 let filename = format!("partition-{}.csv", partition); 1033 let file_path = tmp_dir.path().join(&filename); 1034 let mut file = File::create(file_path)?; 1035 1036 // generate some data 1037 for i in 0..=10 { 1038 let data = format!("{},{}\n", partition, i); 1039 file.write_all(data.as_bytes())?; 1040 } 1041 } 1042 1043 // register csv file with the execution context 1044 ctx.register_csv( 1045 "test", 1046 tmp_dir.path().to_str().unwrap(), 1047 CsvReadOptions::new().schema(&schema), 1048 )?; 1049 1050 Ok(ctx) 1051 } 1052 } 1053