// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! FilterExec evaluates a boolean predicate against all input batches to determine which rows to
//! include in its output batches.
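//!
//! A minimal usage sketch (adapted from the test at the bottom of this file;
//! `input` stands in for any `Arc<dyn ExecutionPlan>` whose schema has a `c2` column):
//!
//! ```ignore
//! // keep only the rows where c2 > 1
//! let predicate = binary(col("c2"), Operator::Gt, lit(ScalarValue::from(1u32)), &schema)?;
//! let filter: Arc<dyn ExecutionPlan> = Arc::new(FilterExec::try_new(predicate, input)?);
//! let results = collect(filter).await?;
//! ```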

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr};
use arrow::array::BooleanArray;
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;

use async_trait::async_trait;

use futures::stream::{Stream, StreamExt};

/// FilterExec evaluates a boolean predicate against all input batches to determine which rows to
/// include in its output batches.
#[derive(Debug)]
pub struct FilterExec {
    /// The expression to filter on. This expression must evaluate to a boolean value.
    predicate: Arc<dyn PhysicalExpr>,
    /// The input plan
    input: Arc<dyn ExecutionPlan>,
}

impl FilterExec {
    /// Create a FilterExec on an input
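    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Plan` if `predicate` does not evaluate to a
    /// boolean type against the input schema.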
    pub fn try_new(
        predicate: Arc<dyn PhysicalExpr>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        match predicate.data_type(input.schema().as_ref())? {
            DataType::Boolean => Ok(Self { predicate, input }),
            other => Err(DataFusionError::Plan(format!(
                "Filter predicate must return boolean values, not {:?}",
                other
            ))),
        }
    }

    /// The expression to filter on. This expression must evaluate to a boolean value.
    pub fn predicate(&self) -> &Arc<dyn PhysicalExpr> {
        &self.predicate
    }

    /// The input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
}

#[async_trait]
impl ExecutionPlan for FilterExec {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Get the schema for this execution plan
    fn schema(&self) -> SchemaRef {
        // The filter operator does not make any changes to the schema of its input
        self.input.schema()
    }

    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![self.input.clone()]
    }

    /// Get the output partitioning of this plan
    fn output_partitioning(&self) -> Partitioning {
        self.input.output_partitioning()
    }

    fn with_new_children(
        &self,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(FilterExec::try_new(
                self.predicate.clone(),
                children[0].clone(),
            )?)),
            _ => Err(DataFusionError::Internal(
                "FilterExec expects exactly one child".to_string(),
            )),
        }
    }

    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
        Ok(Box::pin(FilterExecStream {
            schema: self.input.schema(),
            predicate: self.predicate.clone(),
            input: self.input.execute(partition).await?,
        }))
    }
}

/// FilterExecStream wraps the input stream and applies the predicate expression to
/// determine which rows to include in its output batches.
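/// Produced by [`FilterExec::execute`], one stream per input partition.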
struct FilterExecStream {
    /// Output schema, which is the same as the input schema for this operator
    schema: SchemaRef,
    /// The expression to filter on. This expression must evaluate to a boolean value.
    predicate: Arc<dyn PhysicalExpr>,
    /// The input partition to filter.
    input: SendableRecordBatchStream,
}

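/// Evaluate `predicate` against `batch`, downcast the result to a
/// `BooleanArray`, and use it as a mask to keep only the rows for which the
/// predicate is true.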
fn batch_filter(
    batch: &RecordBatch,
    predicate: &Arc<dyn PhysicalExpr>,
) -> ArrowResult<RecordBatch> {
    predicate
        .evaluate(batch)
        .map(|v| v.into_array(batch.num_rows()))
        .map_err(DataFusionError::into_arrow_external_error)
        .and_then(|array| {
            array
                .as_any()
                .downcast_ref::<BooleanArray>()
                .ok_or_else(|| {
                    DataFusionError::Internal(
                        "Filter predicate evaluated to non-boolean value".to_string(),
                    )
                    .into_arrow_external_error()
                })
                // apply the boolean filter array to the record batch
                .and_then(|filter_array| filter_record_batch(batch, filter_array))
        })
}

impl Stream for FilterExecStream {
    type Item = ArrowResult<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.input.poll_next_unpin(cx).map(|x| match x {
            // filter each ready batch; errors and end-of-stream pass through unchanged
            Some(Ok(batch)) => Some(batch_filter(&batch, &self.predicate)),
            other => other,
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // the filter emits one output batch per input batch, so the input's hint applies
        self.input.size_hint()
    }
}

impl RecordBatchStream for FilterExecStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
    use crate::physical_plan::expressions::*;
    use crate::physical_plan::ExecutionPlan;
    use crate::scalar::ScalarValue;
    use crate::test;
    use crate::{logical_plan::Operator, physical_plan::collect};
    use std::iter::Iterator;

    #[tokio::test]
    async fn simple_predicate() -> Result<()> {
        let schema = test::aggr_test_schema();

        let partitions = 4;
        let path = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;

        let csv =
            CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?;

        // predicate: c2 > 1 AND c2 < 4
        let predicate: Arc<dyn PhysicalExpr> = binary(
            binary(
                col("c2"),
                Operator::Gt,
                lit(ScalarValue::from(1u32)),
                &schema,
            )?,
            Operator::And,
            binary(
                col("c2"),
                Operator::Lt,
                lit(ScalarValue::from(4u32)),
                &schema,
            )?,
            &schema,
        )?;

        let filter: Arc<dyn ExecutionPlan> =
            Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?);

        let results = collect(filter).await?;

        results
            .iter()
            .for_each(|batch| assert_eq!(13, batch.num_columns()));
        let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(41, row_count);

        Ok(())
    }
}