1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 use arrow::array::Int32Array;
19 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
20 use arrow::error::Result as ArrowResult;
21 use arrow::record_batch::RecordBatch;
22 
23 use datafusion::error::{DataFusionError, Result};
24 use datafusion::{
25     datasource::{datasource::Statistics, TableProvider},
26     physical_plan::collect,
27 };
28 
29 use datafusion::execution::context::ExecutionContext;
30 use datafusion::logical_plan::{col, Expr, LogicalPlan, LogicalPlanBuilder};
31 use datafusion::physical_plan::{
32     ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
33 };
34 
35 use futures::stream::Stream;
36 use std::any::Any;
37 use std::pin::Pin;
38 use std::sync::Arc;
39 use std::task::{Context, Poll};
40 
41 use async_trait::async_trait;
42 
43 //// Custom source dataframe tests ////
44 
/// Stateless table provider used by the test below: it reports the fixed
/// two-column test schema and its `scan` returns a `CustomExecutionPlan`.
struct CustomTableProvider;
/// Leaf physical plan produced by `CustomTableProvider::scan`.
#[derive(Clone, Debug)]
struct CustomExecutionPlan {
    /// Column indices into the test schema that `schema()` exposes;
    /// `None` keeps every column.
    projection: Option<Vec<usize>>,
}
50 struct TestCustomRecordBatchStream {
51     /// the nb of batches of TEST_CUSTOM_RECORD_BATCH generated
52     nb_batch: i32,
53 }
54 macro_rules! TEST_CUSTOM_SCHEMA_REF {
55     () => {
56         Arc::new(Schema::new(vec![
57             Field::new("c1", DataType::Int32, false),
58             Field::new("c2", DataType::Int32, false),
59         ]))
60     };
61 }
62 macro_rules! TEST_CUSTOM_RECORD_BATCH {
63     () => {
64         RecordBatch::try_new(
65             TEST_CUSTOM_SCHEMA_REF!(),
66             vec![
67                 Arc::new(Int32Array::from(vec![1, 10, 10, 100])),
68                 Arc::new(Int32Array::from(vec![2, 12, 12, 120])),
69             ],
70         )
71     };
72 }
73 
74 impl RecordBatchStream for TestCustomRecordBatchStream {
schema(&self) -> SchemaRef75     fn schema(&self) -> SchemaRef {
76         TEST_CUSTOM_SCHEMA_REF!()
77     }
78 }
79 
80 impl Stream for TestCustomRecordBatchStream {
81     type Item = ArrowResult<RecordBatch>;
82 
poll_next( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Option<Self::Item>>83     fn poll_next(
84         self: Pin<&mut Self>,
85         _cx: &mut Context<'_>,
86     ) -> Poll<Option<Self::Item>> {
87         if self.nb_batch > 0 {
88             self.get_mut().nb_batch -= 1;
89             Poll::Ready(Some(TEST_CUSTOM_RECORD_BATCH!()))
90         } else {
91             Poll::Ready(None)
92         }
93     }
94 }
95 
96 #[async_trait]
97 impl ExecutionPlan for CustomExecutionPlan {
as_any(&self) -> &dyn Any98     fn as_any(&self) -> &dyn Any {
99         self
100     }
schema(&self) -> SchemaRef101     fn schema(&self) -> SchemaRef {
102         let schema = TEST_CUSTOM_SCHEMA_REF!();
103         match &self.projection {
104             None => schema,
105             Some(p) => Arc::new(Schema::new(
106                 p.iter().map(|i| schema.field(*i).clone()).collect(),
107             )),
108         }
109     }
output_partitioning(&self) -> Partitioning110     fn output_partitioning(&self) -> Partitioning {
111         Partitioning::UnknownPartitioning(1)
112     }
children(&self) -> Vec<Arc<dyn ExecutionPlan>>113     fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
114         vec![]
115     }
with_new_children( &self, children: Vec<Arc<dyn ExecutionPlan>>, ) -> Result<Arc<dyn ExecutionPlan>>116     fn with_new_children(
117         &self,
118         children: Vec<Arc<dyn ExecutionPlan>>,
119     ) -> Result<Arc<dyn ExecutionPlan>> {
120         if children.is_empty() {
121             Ok(Arc::new(self.clone()))
122         } else {
123             Err(DataFusionError::Internal(
124                 "Children cannot be replaced in CustomExecutionPlan".to_owned(),
125             ))
126         }
127     }
execute(&self, _partition: usize) -> Result<SendableRecordBatchStream>128     async fn execute(&self, _partition: usize) -> Result<SendableRecordBatchStream> {
129         Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 }))
130     }
131 }
132 
133 impl TableProvider for CustomTableProvider {
as_any(&self) -> &dyn Any134     fn as_any(&self) -> &dyn Any {
135         self
136     }
137 
schema(&self) -> SchemaRef138     fn schema(&self) -> SchemaRef {
139         TEST_CUSTOM_SCHEMA_REF!()
140     }
141 
scan( &self, projection: &Option<Vec<usize>>, _batch_size: usize, _filters: &[Expr], ) -> Result<Arc<dyn ExecutionPlan>>142     fn scan(
143         &self,
144         projection: &Option<Vec<usize>>,
145         _batch_size: usize,
146         _filters: &[Expr],
147     ) -> Result<Arc<dyn ExecutionPlan>> {
148         Ok(Arc::new(CustomExecutionPlan {
149             projection: projection.clone(),
150         }))
151     }
152 
statistics(&self) -> Statistics153     fn statistics(&self) -> Statistics {
154         Statistics::default()
155     }
156 }
157 
158 #[tokio::test]
custom_source_dataframe() -> Result<()>159 async fn custom_source_dataframe() -> Result<()> {
160     let mut ctx = ExecutionContext::new();
161 
162     let table = ctx.read_table(Arc::new(CustomTableProvider))?;
163     let logical_plan = LogicalPlanBuilder::from(&table.to_logical_plan())
164         .project(&[col("c2")])?
165         .build()?;
166 
167     let optimized_plan = ctx.optimize(&logical_plan)?;
168     match &optimized_plan {
169         LogicalPlan::Projection { input, .. } => match &**input {
170             LogicalPlan::TableScan {
171                 source,
172                 projected_schema,
173                 ..
174             } => {
175                 assert_eq!(source.schema().fields().len(), 2);
176                 assert_eq!(projected_schema.fields().len(), 1);
177             }
178             _ => panic!("input to projection should be TableScan"),
179         },
180         _ => panic!("expect optimized_plan to be projection"),
181     }
182 
183     let expected = "Projection: #c2\
184         \n  TableScan: projection=Some([1])";
185     assert_eq!(format!("{:?}", optimized_plan), expected);
186 
187     let physical_plan = ctx.create_physical_plan(&optimized_plan)?;
188 
189     assert_eq!(1, physical_plan.schema().fields().len());
190     assert_eq!("c2", physical_plan.schema().field(0).name().as_str());
191 
192     let batches = collect(physical_plan).await?;
193     let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?;
194     assert_eq!(1, batches.len());
195     assert_eq!(1, batches[0].num_columns());
196     assert_eq!(origin_rec_batch.num_rows(), batches[0].num_rows());
197 
198     Ok(())
199 }
200