// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! FilterExec evaluates a boolean predicate against all input batches to determine which rows to
//! include in its output batches.
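//!
//! A minimal construction sketch (marked `ignore` so it is not compiled as a
//! doctest; the `scan` input plan and the column name `a` are hypothetical):
//!
//! ```ignore
//! use std::sync::Arc;
//! use datafusion::logical_plan::Operator;
//! use datafusion::physical_plan::expressions::{binary, col, lit};
//! use datafusion::physical_plan::filter::FilterExec;
//! use datafusion::scalar::ScalarValue;
//!
//! // `scan: Arc<dyn ExecutionPlan>` is any input plan, e.g. a CSV scan.
//! let predicate = binary(col("a"), Operator::Gt, lit(ScalarValue::from(1u32)), &scan.schema())?;
//! let filter = FilterExec::try_new(predicate, scan)?;
//! ```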

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr};
use arrow::array::BooleanArray;
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;

use async_trait::async_trait;

use futures::stream::{Stream, StreamExt};

/// FilterExec evaluates a boolean predicate against all input batches to determine which rows to
/// include in its output batches.
#[derive(Debug)]
pub struct FilterExec {
    /// The expression to filter on. This expression must evaluate to a boolean value.
    predicate: Arc<dyn PhysicalExpr>,
    /// The input plan
    input: Arc<dyn ExecutionPlan>,
}

impl FilterExec {
    /// Create a FilterExec on an input, checking that the predicate evaluates to a boolean
    pub fn try_new(
        predicate: Arc<dyn PhysicalExpr>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        match predicate.data_type(input.schema().as_ref())? {
            DataType::Boolean => Ok(Self { predicate, input }),
            other => Err(DataFusionError::Plan(format!(
                "Filter predicate must return boolean values, not {:?}",
                other
            ))),
        }
    }

    /// The expression to filter on. This expression must evaluate to a boolean value.
    pub fn predicate(&self) -> &Arc<dyn PhysicalExpr> {
        &self.predicate
    }

    /// The input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
}

#[async_trait]
impl ExecutionPlan for FilterExec {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Get the schema for this execution plan
    fn schema(&self) -> SchemaRef {
        // The filter operator does not make any changes to the schema of its input
        self.input.schema()
    }

    /// Get the single input plan of this operator
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![self.input.clone()]
    }

    /// Get the output partitioning of this plan
    fn output_partitioning(&self) -> Partitioning {
        // Filtering is applied independently per partition, so the input partitioning is preserved
        self.input.output_partitioning()
    }

    /// Return a new FilterExec with the same predicate over the given children
    fn with_new_children(
        &self,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(FilterExec::try_new(
                self.predicate.clone(),
                children[0].clone(),
            )?)),
            _ => Err(DataFusionError::Internal(
                "FilterExec wrong number of children".to_string(),
            )),
        }
    }

    /// Execute one input partition, returning a stream of filtered record batches
    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
        Ok(Box::pin(FilterExecStream {
            schema: self.input.schema(),
            predicate: self.predicate.clone(),
            input: self.input.execute(partition).await?,
        }))
    }
}

/// The FilterExec stream wraps the input stream and applies the predicate expression to
/// determine which rows to include in its output batches
struct FilterExecStream {
    /// Output schema, which is the same as the input schema for this operator
    schema: SchemaRef,
    /// The expression to filter on. This expression must evaluate to a boolean value.
    predicate: Arc<dyn PhysicalExpr>,
    /// The input partition to filter.
    input: SendableRecordBatchStream,
}

fn batch_filter(
    batch: &RecordBatch,
    predicate: &Arc<dyn PhysicalExpr>,
) -> ArrowResult<RecordBatch> {
    predicate
        .evaluate(batch)
        // materialize the result (which may be a scalar) into an array as long as the batch
        .map(|v| v.into_array(batch.num_rows()))
        .map_err(DataFusionError::into_arrow_external_error)
        .and_then(|array| {
            array
                .as_any()
                .downcast_ref::<BooleanArray>()
                .ok_or_else(|| {
                    DataFusionError::Internal(
                        "Filter predicate evaluated to non-boolean value".to_string(),
                    )
                    .into_arrow_external_error()
                })
                // apply the boolean filter array to the record batch, keeping rows where it is true
                .and_then(|filter_array| filter_record_batch(batch, filter_array))
        })
}

impl Stream for FilterExecStream {
    type Item = ArrowResult<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // filter each ready input batch; errors and end-of-stream pass through unchanged
        self.input.poll_next_unpin(cx).map(|x| match x {
            Some(Ok(batch)) => Some(batch_filter(&batch, &self.predicate)),
            other => other,
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // filtering changes row counts, not the number of record batches,
        // so the input's size hint still applies
        self.input.size_hint()
    }
}

impl RecordBatchStream for FilterExecStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
    use crate::physical_plan::expressions::*;
    use crate::physical_plan::ExecutionPlan;
    use crate::scalar::ScalarValue;
    use crate::test;
    use crate::{logical_plan::Operator, physical_plan::collect};
    use std::iter::Iterator;

    #[tokio::test]
    async fn simple_predicate() -> Result<()> {
        let schema = test::aggr_test_schema();

        let partitions = 4;
        let path = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;

        let csv =
            CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?;

        // predicate: c2 > 1 AND c2 < 4
        let predicate: Arc<dyn PhysicalExpr> = binary(
            binary(
                col("c2"),
                Operator::Gt,
                lit(ScalarValue::from(1u32)),
                &schema,
            )?,
            Operator::And,
            binary(
                col("c2"),
                Operator::Lt,
                lit(ScalarValue::from(4u32)),
                &schema,
            )?,
            &schema,
        )?;

        let filter: Arc<dyn ExecutionPlan> =
            Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?);

        let results = collect(filter).await?;

        // the filter does not change the schema: all 13 input columns survive
        results
            .iter()
            .for_each(|batch| assert_eq!(13, batch.num_columns()));
        let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(41, row_count);

        Ok(())
    }
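
    // A sketch of the error path, assuming column "c1" in the test schema is
    // non-boolean (it is a string column in aggregate_test_100.csv): `try_new`
    // should reject a predicate whose type is not Boolean.
    #[tokio::test]
    async fn non_boolean_predicate_rejected() -> Result<()> {
        let schema = test::aggr_test_schema();
        let path = test::create_partitioned_csv("aggregate_test_100.csv", 1)?;
        let csv =
            CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?;

        // `col("c1")` evaluates to a string array, not a BooleanArray
        let result = FilterExec::try_new(col("c1"), Arc::new(csv));
        assert!(result.is_err());
        Ok(())
    }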
}