// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! ExecutionContext contains methods for registering data sources and executing queries
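//!
//! # Example
//!
//! A minimal usage sketch; the crate and module paths (`datafusion::...`), the
//! file `/path/to/data.csv`, and its schema are assumptions for illustration:
//!
//! ```no_run
//! use arrow::datatypes::{DataType, Field, Schema};
//! use datafusion::error::Result;
//! use datafusion::execution::context::ExecutionContext;
//! use datafusion::execution::physical_plan::csv::CsvReadOptions;
//!
//! fn main() -> Result<()> {
//!     let mut ctx = ExecutionContext::new();
//!     // describe the columns of the CSV file being registered
//!     let schema = Schema::new(vec![
//!         Field::new("c1", DataType::UInt32, false),
//!         Field::new("c2", DataType::UInt64, false),
//!     ]);
//!     ctx.register_csv(
//!         "example",
//!         "/path/to/data.csv",
//!         CsvReadOptions::new().schema(&schema),
//!     )?;
//!     // plan and execute the query; 4096 is the batch size used when
//!     // scanning the data source
//!     let results = ctx.sql("SELECT c1, SUM(c2) FROM example GROUP BY c1", 4096)?;
//!     Ok(())
//! }
//! ```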

use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::string::String;
use std::sync::Arc;
use std::thread::{self, JoinHandle};

use arrow::csv;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;

use crate::datasource::csv::CsvFile;
use crate::datasource::parquet::ParquetTable;
use crate::datasource::TableProvider;
use crate::error::{ExecutionError, Result};
use crate::execution::physical_plan::common;
use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions};
use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::{
    Alias, Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, Sum,
};
use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
use crate::execution::physical_plan::limit::LimitExec;
use crate::execution::physical_plan::math_expressions::register_math_functions;
use crate::execution::physical_plan::memory::MemoryExec;
use crate::execution::physical_plan::merge::MergeExec;
use crate::execution::physical_plan::parquet::ParquetExec;
use crate::execution::physical_plan::projection::ProjectionExec;
use crate::execution::physical_plan::selection::SelectionExec;
use crate::execution::physical_plan::udf::{ScalarFunction, ScalarFunctionExpr};
use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr};
use crate::execution::table_impl::TableImpl;
use crate::logicalplan::*;
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::projection_push_down::ProjectionPushDown;
use crate::optimizer::resolve_columns::ResolveColumnsRule;
use crate::optimizer::type_coercion::TypeCoercionRule;
use crate::sql::parser::{DFASTNode, DFParser, FileType};
use crate::sql::planner::{SchemaProvider, SqlToRel};
use crate::table::Table;
use sqlparser::sqlast::{SQLColumnDef, SQLType};

/// Execution context for registering data sources and executing queries
pub struct ExecutionContext {
    datasources: HashMap<String, Box<dyn TableProvider>>,
    scalar_functions: HashMap<String, Box<ScalarFunction>>,
}

impl ExecutionContext {
    /// Create a new execution context for in-memory queries
    pub fn new() -> Self {
        let mut ctx = Self {
            datasources: HashMap::new(),
            scalar_functions: HashMap::new(),
        };
        register_math_functions(&mut ctx);
        ctx
    }

    /// Execute a SQL query against the registered data sources and collect
    /// the results as a vector of RecordBatch instances
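    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a table named `example` has already been
    /// registered with this context (module paths are assumptions):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # use datafusion::execution::context::ExecutionContext;
    /// # fn query(ctx: &mut ExecutionContext) -> Result<()> {
    /// // plan, optimize, and execute in one call; 1024 is the batch size
    /// // used when scanning the data source
    /// let results = ctx.sql("SELECT c1, c2 FROM example", 1024)?;
    /// # Ok(())
    /// # }
    /// ```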
    pub fn sql(&mut self, sql: &str, batch_size: usize) -> Result<Vec<RecordBatch>> {
        let plan = self.create_logical_plan(sql)?;
        self.collect_plan(&plan, batch_size)
    }

    /// Execute a logical plan and collect the results as a vector of
    /// RecordBatch instances; DDL statements such as CREATE EXTERNAL TABLE
    /// are executed directly and produce no result batches
    pub fn collect_plan(
        &mut self,
        plan: &LogicalPlan,
        batch_size: usize,
    ) -> Result<Vec<RecordBatch>> {
        match plan {
            LogicalPlan::CreateExternalTable {
                ref schema,
                ref name,
                ref location,
                ref file_type,
                ref has_header,
            } => match file_type {
                FileType::CSV => {
                    self.register_csv(
                        name,
                        location,
                        CsvReadOptions::new()
                            .schema(&schema)
                            .has_header(*has_header),
                    )?;
                    Ok(vec![])
                }
                FileType::Parquet => {
                    self.register_parquet(name, location)?;
                    Ok(vec![])
                }
                _ => Err(ExecutionError::ExecutionError(format!(
                    "Unsupported file type {:?}.",
                    file_type
                ))),
            },

            plan => {
                let plan = self.optimize(&plan)?;
                let plan = self.create_physical_plan(&plan, batch_size)?;
                Ok(self.collect(plan.as_ref())?)
            }
        }
    }

    /// Create a logical plan by parsing a SQL string
    pub fn create_logical_plan(&mut self, sql: &str) -> Result<LogicalPlan> {
        let ast = DFParser::parse_sql(String::from(sql))?;

        match ast {
            DFASTNode::ANSI(ansi) => {
                let schema_provider = ExecutionContextSchemaProvider {
                    datasources: &self.datasources,
                    scalar_functions: &self.scalar_functions,
                };

                // create a query planner
                let query_planner = SqlToRel::new(schema_provider);

                // plan the query (create a logical relational plan)
                let plan = query_planner.sql_to_rel(&ansi)?;

                Ok(plan)
            }
            DFASTNode::CreateExternalTable {
                name,
                columns,
                file_type,
                has_header,
                location,
            } => {
                let schema = Box::new(self.build_schema(columns)?);

                Ok(LogicalPlan::CreateExternalTable {
                    schema,
                    name,
                    location,
                    file_type,
                    has_header,
                })
            }
        }
    }

    /// Register a scalar UDF
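    ///
    /// # Example
    ///
    /// A minimal sketch of defining and registering a UDF, adapted from the
    /// `scalar_udf` test below (the name `my_add` and the module paths are
    /// assumptions):
    ///
    /// ```no_run
    /// # use arrow::array::{ArrayRef, Int32Array};
    /// # use arrow::compute::add;
    /// # use arrow::datatypes::{DataType, Field};
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::execution::physical_plan::udf::{ScalarFunction, ScalarUdf};
    /// # use std::sync::Arc;
    /// # let mut ctx = ExecutionContext::new();
    /// // the UDF receives its arguments as a slice of Arrow arrays
    /// let myfunc: ScalarUdf = |args: &[ArrayRef]| {
    ///     let l = args[0].as_any().downcast_ref::<Int32Array>().unwrap();
    ///     let r = args[1].as_any().downcast_ref::<Int32Array>().unwrap();
    ///     Ok(Arc::new(add(l, r)?))
    /// };
    /// ctx.register_udf(ScalarFunction::new(
    ///     "my_add",
    ///     vec![
    ///         Field::new("a", DataType::Int32, true),
    ///         Field::new("b", DataType::Int32, true),
    ///     ],
    ///     DataType::Int32,
    ///     myfunc,
    /// ));
    /// ```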
    pub fn register_udf(&mut self, f: ScalarFunction) {
        self.scalar_functions.insert(f.name.clone(), Box::new(f));
    }

    /// Get a reference to the registered scalar functions
    pub fn scalar_functions(&self) -> &HashMap<String, Box<ScalarFunction>> {
        &self.scalar_functions
    }

    fn build_schema(&self, columns: Vec<SQLColumnDef>) -> Result<Schema> {
        let mut fields = Vec::new();

        for column in columns {
            let data_type = self.make_data_type(column.data_type)?;
            fields.push(Field::new(&column.name, data_type, column.allow_null));
        }

        Ok(Schema::new(fields))
    }

    fn make_data_type(&self, sql_type: SQLType) -> Result<DataType> {
        match sql_type {
            SQLType::BigInt => Ok(DataType::Int64),
            SQLType::Int => Ok(DataType::Int32),
            SQLType::SmallInt => Ok(DataType::Int16),
            SQLType::Char(_) | SQLType::Varchar(_) | SQLType::Text => Ok(DataType::Utf8),
            SQLType::Decimal(_, _) => Ok(DataType::Float64),
            SQLType::Float(_) => Ok(DataType::Float32),
            SQLType::Real | SQLType::Double => Ok(DataType::Float64),
            SQLType::Boolean => Ok(DataType::Boolean),
            SQLType::Date => Ok(DataType::Date64(DateUnit::Day)),
            SQLType::Time => Ok(DataType::Time64(TimeUnit::Millisecond)),
            SQLType::Timestamp => Ok(DataType::Date64(DateUnit::Millisecond)),
            SQLType::Uuid
            | SQLType::Clob(_)
            | SQLType::Binary(_)
            | SQLType::Varbinary(_)
            | SQLType::Blob(_)
            | SQLType::Regclass
            | SQLType::Bytea
            | SQLType::Custom(_)
            | SQLType::Array(_) => Err(ExecutionError::General(format!(
                "Unsupported data type: {:?}.",
                sql_type
            ))),
        }
    }

    /// Register a CSV file as a table so that it can be queried from SQL
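    ///
    /// # Example
    ///
    /// A minimal sketch; the path and schema are assumptions for
    /// illustration, and the path may name a single file or a directory of
    /// partitioned files:
    ///
    /// ```no_run
    /// # use arrow::datatypes::{DataType, Field, Schema};
    /// # use datafusion::execution::context::ExecutionContext;
    /// # use datafusion::execution::physical_plan::csv::CsvReadOptions;
    /// # let mut ctx = ExecutionContext::new();
    /// let schema = Schema::new(vec![
    ///     Field::new("c1", DataType::UInt32, false),
    ///     Field::new("c2", DataType::UInt64, false),
    /// ]);
    /// ctx.register_csv(
    ///     "example",
    ///     "/path/to/data",
    ///     CsvReadOptions::new().schema(&schema).has_header(true),
    /// ).unwrap();
    /// ```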
    pub fn register_csv(
        &mut self,
        name: &str,
        filename: &str,
        options: CsvReadOptions,
    ) -> Result<()> {
        self.register_table(name, Box::new(CsvFile::try_new(filename, options)?));
        Ok(())
    }

    /// Register a Parquet file as a table so that it can be queried from SQL
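    ///
    /// # Example
    ///
    /// A minimal sketch (the path is an assumption for illustration); the
    /// table schema is read from the Parquet file itself:
    ///
    /// ```no_run
    /// # use datafusion::execution::context::ExecutionContext;
    /// # let mut ctx = ExecutionContext::new();
    /// ctx.register_parquet("example", "/path/to/file.parquet").unwrap();
    /// ```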
    pub fn register_parquet(&mut self, name: &str, filename: &str) -> Result<()> {
        let table = ParquetTable::try_new(&filename)?;
        self.register_table(name, Box::new(table));
        Ok(())
    }

    /// Register a table so that it can be queried from SQL
    pub fn register_table(&mut self, name: &str, provider: Box<dyn TableProvider>) {
        self.datasources.insert(name.to_string(), provider);
    }

    /// Get a table by name
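    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a table named `t` has been registered with
    /// this context (module paths are assumptions):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # use datafusion::execution::context::ExecutionContext;
    /// # fn example(ctx: &mut ExecutionContext) -> Result<()> {
    /// // obtain a Table handle wrapping a TableScan of the data source
    /// let t = ctx.table("t")?;
    /// let plan = t.to_logical_plan();
    /// # Ok(())
    /// # }
    /// ```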
    pub fn table(&mut self, table_name: &str) -> Result<Arc<dyn Table>> {
        match self.datasources.get(table_name) {
            Some(provider) => {
                let schema = provider.schema().as_ref().clone();
                let table_scan = LogicalPlan::TableScan {
                    schema_name: "".to_string(),
                    table_name: table_name.to_string(),
                    table_schema: Box::new(schema.to_owned()),
                    projected_schema: Box::new(schema),
                    projection: None,
                };
                Ok(Arc::new(TableImpl::new(
                    &LogicalPlanBuilder::from(&table_scan).build()?,
                )))
            }
            _ => Err(ExecutionError::General(format!(
                "No table named '{}'",
                table_name
            ))),
        }
    }

    /// Optimize the logical plan by applying optimizer rules
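    ///
    /// The rules are applied in sequence: column resolution, projection
    /// push-down, and type coercion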
    pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
        let rules: Vec<Box<dyn OptimizerRule>> = vec![
            Box::new(ResolveColumnsRule::new()),
            Box::new(ProjectionPushDown::new()),
            Box::new(TypeCoercionRule::new(&self.scalar_functions)),
        ];
        let mut plan = plan.clone();
        for mut rule in rules {
            plan = rule.optimize(&plan)?;
        }
        Ok(plan)
    }

    /// Create a physical plan from a logical plan
    pub fn create_physical_plan(
        &mut self,
        logical_plan: &LogicalPlan,
        batch_size: usize,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match logical_plan {
            LogicalPlan::TableScan {
                table_name,
                projection,
                ..
            } => match self.datasources.get(table_name) {
                Some(provider) => {
                    let partitions = provider.scan(projection, batch_size)?;
                    if partitions.is_empty() {
                        Err(ExecutionError::General(
                            "Table provider returned no partitions".to_string(),
                        ))
                    } else {
                        let partition = partitions[0].lock().unwrap();
                        let schema = partition.schema();
                        let exec =
                            DatasourceExec::new(schema.clone(), partitions.clone());
                        Ok(Arc::new(exec))
                    }
                }
                _ => Err(ExecutionError::General(format!(
                    "No table named {}",
                    table_name
                ))),
            },
            LogicalPlan::InMemoryScan {
                data,
                schema,
                projection,
                ..
            } => Ok(Arc::new(MemoryExec::try_new(
                data,
                Arc::new(schema.as_ref().to_owned()),
                projection.to_owned(),
            )?)),
            LogicalPlan::CsvScan {
                path,
                schema,
                has_header,
                delimiter,
                projection,
                ..
            } => Ok(Arc::new(CsvExec::try_new(
                path,
                CsvReadOptions::new()
                    .schema(schema.as_ref())
                    .delimiter_option(*delimiter)
                    .has_header(*has_header),
                projection.to_owned(),
                batch_size,
            )?)),
            LogicalPlan::ParquetScan {
                path, projection, ..
            } => Ok(Arc::new(ParquetExec::try_new(
                path,
                projection.to_owned(),
                batch_size,
            )?)),
            LogicalPlan::Projection { input, expr, .. } => {
                let input = self.create_physical_plan(input, batch_size)?;
                let input_schema = input.as_ref().schema().clone();
                let runtime_expr = expr
                    .iter()
                    .map(|e| self.create_physical_expr(e, &input_schema))
                    .collect::<Result<Vec<_>>>()?;
                Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input)?))
            }
            LogicalPlan::Aggregate {
                input,
                group_expr,
                aggr_expr,
                ..
            } => {
                // The aggregate runs in two phases: a partial aggregate is
                // computed for each input partition, then the partials are
                // merged into a single partition and aggregated again to
                // produce the final result
                let input = self.create_physical_plan(input, batch_size)?;
                let input_schema = input.as_ref().schema().clone();

                let group_expr = group_expr
                    .iter()
                    .map(|e| self.create_physical_expr(e, &input_schema))
                    .collect::<Result<Vec<_>>>()?;
                let aggr_expr = aggr_expr
                    .iter()
                    .map(|e| self.create_aggregate_expr(e, &input_schema))
                    .collect::<Result<Vec<_>>>()?;

                let initial_aggr =
                    HashAggregateExec::try_new(group_expr, aggr_expr, input)?;

                let schema = initial_aggr.schema();
                let partitions = initial_aggr.partitions()?;

                if partitions.len() == 1 {
                    return Ok(Arc::new(initial_aggr));
                }

                let (final_group, final_aggr) = initial_aggr.make_final_expr();

                let merge = Arc::new(MergeExec::new(schema.clone(), partitions));

                Ok(Arc::new(HashAggregateExec::try_new(
                    final_group,
                    final_aggr,
                    merge,
                )?))
            }
            LogicalPlan::Selection { input, expr, .. } => {
                let input = self.create_physical_plan(input, batch_size)?;
                let input_schema = input.as_ref().schema().clone();
                let runtime_expr = self.create_physical_expr(expr, &input_schema)?;
                Ok(Arc::new(SelectionExec::try_new(runtime_expr, input)?))
            }
            LogicalPlan::Limit { input, expr, .. } => {
                let input = self.create_physical_plan(input, batch_size)?;
                let input_schema = input.as_ref().schema().clone();

                match expr {
                    &Expr::Literal(ref scalar_value) => {
                        let limit: usize = match scalar_value {
                            ScalarValue::Int8(limit) if *limit >= 0 => {
                                Ok(*limit as usize)
                            }
                            ScalarValue::Int16(limit) if *limit >= 0 => {
                                Ok(*limit as usize)
                            }
                            ScalarValue::Int32(limit) if *limit >= 0 => {
                                Ok(*limit as usize)
                            }
                            ScalarValue::Int64(limit) if *limit >= 0 => {
                                Ok(*limit as usize)
                            }
                            ScalarValue::UInt8(limit) => Ok(*limit as usize),
                            ScalarValue::UInt16(limit) => Ok(*limit as usize),
                            ScalarValue::UInt32(limit) => Ok(*limit as usize),
                            ScalarValue::UInt64(limit) => Ok(*limit as usize),
                            _ => Err(ExecutionError::ExecutionError(
                                "Limit only supports non-negative integer literals"
                                    .to_string(),
                            )),
                        }?;
                        Ok(Arc::new(LimitExec::new(
                            input_schema.clone(),
                            input.partitions()?,
                            limit,
                        )))
                    }
                    _ => Err(ExecutionError::ExecutionError(
                        "Limit only supports non-negative integer literals".to_string(),
                    )),
                }
            }
            _ => Err(ExecutionError::General(
                "Unsupported logical plan variant".to_string(),
            )),
        }
    }

    /// Create a physical expression from a logical expression
    pub fn create_physical_expr(
        &self,
        e: &Expr,
        input_schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        match e {
            Expr::Alias(expr, name) => {
                let expr = self.create_physical_expr(expr, input_schema)?;
                Ok(Arc::new(Alias::new(expr, &name)))
            }
            Expr::Column(i) => {
                Ok(Arc::new(Column::new(*i, &input_schema.field(*i).name())))
            }
            Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
            Expr::BinaryExpr { left, op, right } => Ok(Arc::new(BinaryExpr::new(
                self.create_physical_expr(left, input_schema)?,
                op.clone(),
                self.create_physical_expr(right, input_schema)?,
            ))),
            Expr::Cast { expr, data_type } => Ok(Arc::new(CastExpr::try_new(
                self.create_physical_expr(expr, input_schema)?,
                input_schema,
                data_type.clone(),
            )?)),
            Expr::ScalarFunction {
                name,
                args,
                return_type,
            } => match &self.scalar_functions.get(name) {
                Some(f) => {
                    let mut physical_args = vec![];
                    for e in args {
                        physical_args.push(self.create_physical_expr(e, input_schema)?);
                    }
                    Ok(Arc::new(ScalarFunctionExpr::new(
                        name,
                        Box::new(f.fun.clone()),
                        physical_args,
                        return_type,
                    )))
                }
                _ => Err(ExecutionError::General(format!(
                    "Invalid scalar function '{:?}'",
                    name
                ))),
            },
            other => Err(ExecutionError::NotImplemented(format!(
                "Physical plan does not support logical expression {:?}",
                other
            ))),
        }
    }

    /// Create an aggregate expression from a logical expression
    pub fn create_aggregate_expr(
        &self,
        e: &Expr,
        input_schema: &Schema,
    ) -> Result<Arc<dyn AggregateExpr>> {
        match e {
            Expr::AggregateFunction { name, args, .. } => {
                match name.to_lowercase().as_ref() {
                    "sum" => Ok(Arc::new(Sum::new(
                        self.create_physical_expr(&args[0], input_schema)?,
                    ))),
                    "avg" => Ok(Arc::new(Avg::new(
                        self.create_physical_expr(&args[0], input_schema)?,
                    ))),
                    "max" => Ok(Arc::new(Max::new(
                        self.create_physical_expr(&args[0], input_schema)?,
                    ))),
                    "min" => Ok(Arc::new(Min::new(
                        self.create_physical_expr(&args[0], input_schema)?,
                    ))),
                    "count" => Ok(Arc::new(Count::new(
                        self.create_physical_expr(&args[0], input_schema)?,
                    ))),
                    other => Err(ExecutionError::NotImplemented(format!(
                        "Unsupported aggregate function '{}'",
                        other
                    ))),
                }
            }
            other => Err(ExecutionError::General(format!(
                "Invalid aggregate expression '{:?}'",
                other
            ))),
        }
    }

    /// Execute a physical plan and collect the results in memory
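    ///
    /// # Example
    ///
    /// A minimal sketch of running each planning stage explicitly instead of
    /// calling `sql`, mirroring the `collect` helper in the tests below
    /// (module paths are assumptions):
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # use datafusion::execution::context::ExecutionContext;
    /// # fn example(ctx: &mut ExecutionContext) -> Result<()> {
    /// let logical_plan = ctx.create_logical_plan("SELECT c1, c2 FROM test")?;
    /// let logical_plan = ctx.optimize(&logical_plan)?;
    /// let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;
    /// let results = ctx.collect(physical_plan.as_ref())?;
    /// # Ok(())
    /// # }
    /// ```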
    pub fn collect(&self, plan: &dyn ExecutionPlan) -> Result<Vec<RecordBatch>> {
        let partitions = plan.partitions()?;

        match partitions.len() {
            0 => Ok(vec![]),
            1 => {
                let it = partitions[0].execute()?;
                common::collect(it)
            }
            _ => {
                // merge into a single partition
                let plan = MergeExec::new(plan.schema().clone(), partitions);
                let partitions = plan.partitions()?;
                if partitions.len() == 1 {
                    common::collect(partitions[0].execute()?)
                } else {
                    Err(ExecutionError::InternalError(format!(
                        "MergeExec returned {} partitions",
                        partitions.len()
                    )))
                }
            }
        }
    }

    /// Execute a query and write the results to a partitioned CSV file
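    ///
    /// # Example
    ///
    /// A minimal sketch (the output directory is an assumption for
    /// illustration); one `part-<n>.csv` file is written per partition under
    /// the given path:
    ///
    /// ```no_run
    /// # use datafusion::error::Result;
    /// # use datafusion::execution::context::ExecutionContext;
    /// # fn example(ctx: &mut ExecutionContext) -> Result<()> {
    /// let plan = ctx.create_logical_plan("SELECT c1, c2 FROM test")?;
    /// let plan = ctx.optimize(&plan)?;
    /// let plan = ctx.create_physical_plan(&plan, 1024)?;
    /// ctx.write_csv(plan.as_ref(), "/tmp/out")?;
    /// # Ok(())
    /// # }
    /// ```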
    pub fn write_csv(&self, plan: &dyn ExecutionPlan, path: &str) -> Result<()> {
        // create directory to contain the CSV files (one per partition)
        let path = path.to_string();
        fs::create_dir(&path)?;

        let threads: Vec<JoinHandle<Result<()>>> = plan
            .partitions()?
            .iter()
            .enumerate()
            .map(|(i, p)| {
                let p = p.clone();
                let path = path.clone();
                thread::spawn(move || {
                    let filename = format!("part-{}.csv", i);
                    let path = Path::new(&path).join(&filename);
                    let file = fs::File::create(path)?;
                    let mut writer = csv::Writer::new(file);
                    let it = p.execute()?;
                    let mut it = it.lock().unwrap();
                    loop {
                        match it.next() {
                            Ok(Some(batch)) => {
                                writer.write(&batch)?;
                            }
                            Ok(None) => break,
                            Err(e) => return Err(e),
                        }
                    }
                    Ok(())
                })
            })
            .collect();

        // combine the results from each thread
        for thread in threads {
            let join = thread.join().expect("Failed to join thread");
            join?;
        }

        Ok(())
    }
}

struct ExecutionContextSchemaProvider<'a> {
    datasources: &'a HashMap<String, Box<dyn TableProvider>>,
    scalar_functions: &'a HashMap<String, Box<ScalarFunction>>,
}

impl SchemaProvider for ExecutionContextSchemaProvider<'_> {
    fn get_table_meta(&self, name: &str) -> Option<Arc<Schema>> {
        self.datasources.get(name).map(|ds| ds.schema().clone())
    }

    fn get_function_meta(&self, name: &str) -> Option<Arc<FunctionMeta>> {
        self.scalar_functions.get(name).map(|f| {
            Arc::new(FunctionMeta::new(
                name.to_owned(),
                f.args.clone(),
                f.return_type.clone(),
                FunctionType::Scalar,
            ))
        })
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::datasource::MemTable;
    use crate::execution::physical_plan::udf::ScalarUdf;
    use crate::test;
    use arrow::array::{ArrayRef, Int32Array};
    use arrow::compute::add;
    use std::fs::File;
    use std::io::prelude::*;
    use tempdir::TempDir;

    #[test]
    fn parallel_projection() -> Result<()> {
        let partition_count = 4;
        let results = execute("SELECT c1, c2 FROM test", partition_count)?;

        // there should be one batch per partition
        assert_eq!(results.len(), partition_count);

        // each batch should contain 2 columns and 10 rows
        for batch in &results {
            assert_eq!(batch.num_columns(), 2);
            assert_eq!(batch.num_rows(), 10);
        }

        Ok(())
    }

    #[test]
    fn parallel_selection() -> Result<()> {
        let tmp_dir = TempDir::new("parallel_selection")?;
        let partition_count = 4;
        let mut ctx = create_ctx(&tmp_dir, partition_count)?;

        let logical_plan =
            ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?;
        let logical_plan = ctx.optimize(&logical_plan)?;

        let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;

        let results = ctx.collect(physical_plan.as_ref())?;

        // there should be one batch per partition
        assert_eq!(results.len(), partition_count);

        let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 20);

        Ok(())
    }

    #[test]
    fn aggregate() -> Result<()> {
        let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["60,220"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);

        Ok(())
    }

    #[test]
    fn aggregate_avg() -> Result<()> {
        let results = execute("SELECT AVG(c1), AVG(c2) FROM test", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["1.5,5.5"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);

        Ok(())
    }

    #[test]
    fn aggregate_max() -> Result<()> {
        let results = execute("SELECT MAX(c1), MAX(c2) FROM test", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["3,10"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);

        Ok(())
    }

    #[test]
    fn aggregate_min() -> Result<()> {
        let results = execute("SELECT MIN(c1), MIN(c2) FROM test", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["0,1"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);

        Ok(())
    }

    #[test]
    fn aggregate_grouped() -> Result<()> {
        let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);

        Ok(())
    }

    #[test]
    fn aggregate_grouped_avg() -> Result<()> {
        let results = execute("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);

        Ok(())
    }

    #[test]
    fn aggregate_grouped_max() -> Result<()> {
        let results = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);

        Ok(())
    }

    #[test]
    fn aggregate_grouped_min() -> Result<()> {
        let results = execute("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);

        Ok(())
    }

    #[test]
    fn count_basic() -> Result<()> {
        let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["10,10"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);
        Ok(())
    }

    #[test]
    fn count_partitioned() -> Result<()> {
        let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected: Vec<&str> = vec!["40,40"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);
        Ok(())
    }

    #[test]
    fn count_aggregated() -> Result<()> {
        let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4)?;
        assert_eq!(results.len(), 1);

        let batch = &results[0];
        let expected = vec!["0,10", "1,10", "2,10", "3,10"];
        let mut rows = test::format_batch(&batch);
        rows.sort();
        assert_eq!(rows, expected);
        Ok(())
    }

    #[test]
    fn aggregate_with_alias() -> Result<()> {
        let tmp_dir = TempDir::new("execute")?;
        let mut ctx = create_ctx(&tmp_dir, 1)?;

        let schema = Arc::new(Schema::new(vec![
            Field::new("state", DataType::Utf8, false),
            Field::new("salary", DataType::UInt32, false),
        ]));

        let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)?
            .aggregate(
                vec![col("state")],
                vec![aggregate_expr("SUM", col("salary"), DataType::UInt32)],
            )?
            .project(vec![col("state"), col_index(1).alias("total_salary")])?
            .build()?;

        let plan = ctx.optimize(&plan)?;

        let physical_plan = ctx.create_physical_plan(&Arc::new(plan), 1024)?;
        assert_eq!("c1", physical_plan.schema().field(0).name().as_str());
        assert_eq!(
            "total_salary",
            physical_plan.schema().field(1).name().as_str()
        );
        Ok(())
    }

    #[test]
    fn write_csv_results() -> Result<()> {
        // create partitioned input file and context
        let tmp_dir = TempDir::new("write_csv_results_temp")?;
        let mut ctx = create_ctx(&tmp_dir, 4)?;

        // execute a simple query and write the results to CSV
        let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out";
        write_csv(&mut ctx, "SELECT c1, c2 FROM test", &out_dir)?;

        // create a new context and verify that the results were saved to a partitioned csv file
        let mut ctx = ExecutionContext::new();

        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::UInt32, false),
            Field::new("c2", DataType::UInt64, false),
        ]));

        // register each partition as well as the top level dir
        let csv_read_option = CsvReadOptions::new().schema(&schema);
        ctx.register_csv("part0", &format!("{}/part-0.csv", out_dir), csv_read_option)?;
        ctx.register_csv("part1", &format!("{}/part-1.csv", out_dir), csv_read_option)?;
        ctx.register_csv("part2", &format!("{}/part-2.csv", out_dir), csv_read_option)?;
        ctx.register_csv("part3", &format!("{}/part-3.csv", out_dir), csv_read_option)?;
        ctx.register_csv("allparts", &out_dir, csv_read_option)?;

        let part0 = collect(&mut ctx, "SELECT c1, c2 FROM part0")?;
        let part1 = collect(&mut ctx, "SELECT c1, c2 FROM part1")?;
        let part2 = collect(&mut ctx, "SELECT c1, c2 FROM part2")?;
        let part3 = collect(&mut ctx, "SELECT c1, c2 FROM part3")?;
        let allparts = collect(&mut ctx, "SELECT c1, c2 FROM allparts")?;

        let part0_count: usize = part0.iter().map(|batch| batch.num_rows()).sum();
        let part1_count: usize = part1.iter().map(|batch| batch.num_rows()).sum();
        let part2_count: usize = part2.iter().map(|batch| batch.num_rows()).sum();
        let part3_count: usize = part3.iter().map(|batch| batch.num_rows()).sum();
        let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum();

        assert_eq!(part0_count, 10);
        assert_eq!(part1_count, 10);
        assert_eq!(part2_count, 10);
        assert_eq!(part3_count, 10);
        assert_eq!(allparts_count, 40);

        Ok(())
    }

    #[test]
    fn scalar_udf() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![1, 10, 10, 100])),
                Arc::new(Int32Array::from(vec![2, 12, 12, 120])),
            ],
        )?;

        let mut ctx = ExecutionContext::new();

        let provider = MemTable::new(Arc::new(schema), vec![vec![batch]])?;
        ctx.register_table("t", Box::new(provider));

        let myfunc: ScalarUdf = |args: &[ArrayRef]| {
            let l = &args[0]
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("cast failed");
            let r = &args[1]
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("cast failed");
            Ok(Arc::new(add(l, r)?))
        };

        let my_add = ScalarFunction::new(
            "my_add",
            vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Int32, true),
            ],
            DataType::Int32,
            myfunc,
        );

        ctx.register_udf(my_add);

        let t = ctx.table("t")?;

        let plan = LogicalPlanBuilder::from(&t.to_logical_plan())
            .project(vec![
                col("a"),
                col("b"),
                scalar_function("my_add", vec![col("a"), col("b")], DataType::Int32),
            ])?
            .build()?;

        assert_eq!(
            format!("{:?}", plan),
            "Projection: #a, #b, my_add(#a, #b)\n  TableScan: t projection=None"
        );

        let plan = ctx.optimize(&plan)?;
        let plan = ctx.create_physical_plan(&plan, 1024)?;
        let result = ctx.collect(plan.as_ref())?;

        let batch = &result[0];
        assert_eq!(3, batch.num_columns());
        assert_eq!(4, batch.num_rows());

        let a = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("failed to cast a");
        let b = batch
            .column(1)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("failed to cast b");
        let sum = batch
            .column(2)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("failed to cast sum");

        assert_eq!(4, a.len());
        assert_eq!(4, b.len());
        assert_eq!(4, sum.len());
        for i in 0..sum.len() {
            assert_eq!(a.value(i) + b.value(i), sum.value(i));
        }

        Ok(())
    }

    /// Plan and execute a SQL statement against the given context, returning
    /// the results
    fn collect(ctx: &mut ExecutionContext, sql: &str) -> Result<Vec<RecordBatch>> {
        let logical_plan = ctx.create_logical_plan(sql)?;
        let logical_plan = ctx.optimize(&logical_plan)?;
        let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;
        ctx.collect(physical_plan.as_ref())
    }

    /// Create a partitioned test context, then plan and execute a SQL
    /// statement against it, returning the results
    fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
        let tmp_dir = TempDir::new("execute")?;
        let mut ctx = create_ctx(&tmp_dir, partition_count)?;
        collect(&mut ctx, sql)
    }

    /// Execute SQL and write the results to partitioned CSV files
    fn write_csv(ctx: &mut ExecutionContext, sql: &str, out_dir: &str) -> Result<()> {
        let logical_plan = ctx.create_logical_plan(sql)?;
        let logical_plan = ctx.optimize(&logical_plan)?;
        let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;
        ctx.write_csv(physical_plan.as_ref(), out_dir)
    }

    /// Generate a partitioned CSV file and register it with an execution context
    fn create_ctx(tmp_dir: &TempDir, partition_count: usize) -> Result<ExecutionContext> {
        let mut ctx = ExecutionContext::new();

        // define schema for data source (csv file)
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::UInt32, false),
            Field::new("c2", DataType::UInt64, false),
        ]));

        // generate a partitioned file
        for partition in 0..partition_count {
            let filename = format!("partition-{}.csv", partition);
            let file_path = tmp_dir.path().join(&filename);
            let mut file = File::create(file_path)?;

            // write 11 lines; the first is consumed as the header row, since
            // CsvReadOptions defaults to has_header = true, leaving 10 data
            // rows per partition
            for i in 0..=10 {
                let data = format!("{},{}\n", partition, i);
                file.write_all(data.as_bytes())?;
            }
        }

        // register csv file with the execution context
        ctx.register_csv(
            "test",
            tmp_dir.path().to_str().unwrap(),
            CsvReadOptions::new().schema(&schema),
        )?;

        Ok(ctx)
    }
}