1 // regarding copyright ownership.  The ASF licenses this file
2 // to you under the Apache License, Version 2.0 (the
3 // "License"); you may not use this file except in compliance
4 // with the License.  You may obtain a copy of the License at
5 //
6 //   http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing,
9 // software distributed under the License is distributed on an
10 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
11 // KIND, either express or implied.  See the License for the
12 // specific language governing permissions and limitations
13 // under the License.
14 
15 //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
16 
17 use crate::datasource::datasource::TableProviderFilterPushDown;
18 use crate::logical_plan::{and, LogicalPlan};
19 use crate::logical_plan::{DFSchema, Expr};
20 use crate::optimizer::optimizer::OptimizerRule;
21 use crate::optimizer::utils;
22 use crate::{error::Result, logical_plan::Operator};
23 use std::{
24     collections::{HashMap, HashSet},
25     sync::Arc,
26 };
27 
/// Filter Push Down optimizer rule pushes filter clauses down the plan
/// # Introduction
/// A filter-commutative operation is an operation whose result of filter(op(data)) = op(filter(data)).
/// An example of a filter-commutative operation is a projection; a counter-example is `limit`.
///
/// The filter-commutative property is column-specific. An aggregate grouped by A on SUM(B)
/// can commute with a filter that depends on A only, but does not commute with a filter that depends
/// on SUM(B).
///
/// This optimizer commutes filters with filter-commutative operations to push the filters
/// as close as possible to the scans, re-writing the filter expressions by every
/// projection that changes the filter's expression.
///
/// ```text
/// Filter: #b Gt Int64(10)
///     Projection: #a AS b
/// ```
///
/// is optimized to
///
/// ```text
/// Projection: #a AS b
///     Filter: #a Gt Int64(10)  <--- changed from #b to #a
/// ```
///
/// This performs a single pass through the plan. When it passes through a filter, it stores that filter,
/// and when it reaches a node that does not commute with it, it adds the filter to that place.
/// When it passes through a projection, it re-writes the filter's expression taking into account that projection.
/// When multiple filters would have been written, it `AND`s their expressions into a single expression.
pub struct FilterPushDown {}
54 
/// Filters collected while walking down the plan that have not been re-inserted yet.
#[derive(Debug, Clone, Default)]
struct State {
    /// One entry per pending predicate: (predicate, names of the columns the predicate references)
    filters: Vec<(Expr, HashSet<String>)>,
}
60 
/// Borrowed selection of pending filters, stored as two parallel vectors: the predicates and,
/// at the same position, the set of column names of each predicate.
type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<String>>);
62 
63 /// returns all predicates in `state` that depend on any of `used_columns`
get_predicates<'a>( state: &'a State, used_columns: &HashSet<String>, ) -> Predicates<'a>64 fn get_predicates<'a>(
65     state: &'a State,
66     used_columns: &HashSet<String>,
67 ) -> Predicates<'a> {
68     state
69         .filters
70         .iter()
71         .filter(|(_, columns)| {
72             !columns
73                 .intersection(used_columns)
74                 .collect::<HashSet<_>>()
75                 .is_empty()
76         })
77         .map(|&(ref a, ref b)| (a, b))
78         .unzip()
79 }
80 
81 // returns 3 (potentially overlaping) sets of predicates:
82 // * pushable to left: its columns are all on the left
83 // * pushable to right: its columns is all on the right
84 // * keep: the set of columns is not in only either left or right
85 // Note that a predicate can be both pushed to the left and to the right.
get_join_predicates<'a>( state: &'a State, left: &DFSchema, right: &DFSchema, ) -> ( Vec<&'a HashSet<String>>, Vec<&'a HashSet<String>>, Predicates<'a>, )86 fn get_join_predicates<'a>(
87     state: &'a State,
88     left: &DFSchema,
89     right: &DFSchema,
90 ) -> (
91     Vec<&'a HashSet<String>>,
92     Vec<&'a HashSet<String>>,
93     Predicates<'a>,
94 ) {
95     let left_columns = &left
96         .fields()
97         .iter()
98         .map(|f| f.name().clone())
99         .collect::<HashSet<_>>();
100     let right_columns = &right
101         .fields()
102         .iter()
103         .map(|f| f.name().clone())
104         .collect::<HashSet<_>>();
105 
106     let filters = state
107         .filters
108         .iter()
109         .map(|(predicate, columns)| {
110             (
111                 (predicate, columns),
112                 (
113                     columns,
114                     left_columns.intersection(columns).collect::<HashSet<_>>(),
115                     right_columns.intersection(columns).collect::<HashSet<_>>(),
116                 ),
117             )
118         })
119         .collect::<Vec<_>>();
120 
121     let pushable_to_left = filters
122         .iter()
123         .filter(|(_, (columns, left, _))| left.len() == columns.len())
124         .map(|((_, b), _)| *b)
125         .collect();
126     let pushable_to_right = filters
127         .iter()
128         .filter(|(_, (columns, _, right))| right.len() == columns.len())
129         .map(|((_, b), _)| *b)
130         .collect();
131     let keep = filters
132         .iter()
133         .filter(|(_, (columns, left, right))| {
134             // predicates whose columns are not in only one side of the join need to remain
135             let all_in_left = left.len() == columns.len();
136             let all_in_right = right.len() == columns.len();
137             !all_in_left && !all_in_right
138         })
139         .map(|((ref a, ref b), _)| (a, b))
140         .unzip();
141     (pushable_to_left, pushable_to_right, keep)
142 }
143 
144 /// Optimizes the plan
push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan>145 fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
146     let new_inputs = utils::inputs(&plan)
147         .iter()
148         .map(|input| optimize(input, state.clone()))
149         .collect::<Result<Vec<_>>>()?;
150 
151     let expr = utils::expressions(&plan);
152     utils::from_plan(&plan, &expr, &new_inputs)
153 }
154 
155 /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with
156 /// its predicate be all `predicates` ANDed.
add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan157 fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
158     // reduce filters to a single filter with an AND
159     let predicate = predicates
160         .iter()
161         .skip(1)
162         .fold(predicates[0].clone(), |acc, predicate| {
163             and(acc, (*predicate).to_owned())
164         });
165 
166     LogicalPlan::Filter {
167         predicate,
168         input: Arc::new(plan),
169     }
170 }
171 
172 // remove all filters from `filters` that are in `predicate_columns`
remove_filters( filters: &[(Expr, HashSet<String>)], predicate_columns: &[&HashSet<String>], ) -> Vec<(Expr, HashSet<String>)>173 fn remove_filters(
174     filters: &[(Expr, HashSet<String>)],
175     predicate_columns: &[&HashSet<String>],
176 ) -> Vec<(Expr, HashSet<String>)> {
177     filters
178         .iter()
179         .filter(|(_, columns)| !predicate_columns.contains(&columns))
180         .cloned()
181         .collect::<Vec<_>>()
182 }
183 
184 // keeps all filters from `filters` that are in `predicate_columns`
keep_filters( filters: &[(Expr, HashSet<String>)], predicate_columns: &[&HashSet<String>], ) -> Vec<(Expr, HashSet<String>)>185 fn keep_filters(
186     filters: &[(Expr, HashSet<String>)],
187     predicate_columns: &[&HashSet<String>],
188 ) -> Vec<(Expr, HashSet<String>)> {
189     filters
190         .iter()
191         .filter(|(_, columns)| predicate_columns.contains(&columns))
192         .cloned()
193         .collect::<Vec<_>>()
194 }
195 
196 /// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
197 /// in `state` depend on the columns `used_columns`.
issue_filters( mut state: State, used_columns: HashSet<String>, plan: &LogicalPlan, ) -> Result<LogicalPlan>198 fn issue_filters(
199     mut state: State,
200     used_columns: HashSet<String>,
201     plan: &LogicalPlan,
202 ) -> Result<LogicalPlan> {
203     let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
204 
205     if predicates.is_empty() {
206         // all filters can be pushed down => optimize inputs and return new plan
207         return push_down(&state, plan);
208     }
209 
210     let plan = add_filter(plan.clone(), &predicates);
211 
212     state.filters = remove_filters(&state.filters, &predicate_columns);
213 
214     // continue optimization over all input nodes by cloning the current state (i.e. each node is independent)
215     push_down(&state, &plan)
216 }
217 
218 /// converts "A AND B AND C" => [A, B, C]
split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>)219 fn split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) {
220     match predicate {
221         Expr::BinaryExpr {
222             right,
223             op: Operator::And,
224             left,
225         } => {
226             split_members(&left, predicates);
227             split_members(&right, predicates);
228         }
229         other => predicates.push(other),
230     }
231 }
232 
optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan>233 fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
234     match plan {
235         LogicalPlan::Filter { input, predicate } => {
236             let mut predicates = vec![];
237             split_members(predicate, &mut predicates);
238 
239             predicates
240                 .into_iter()
241                 .try_for_each::<_, Result<()>>(|predicate| {
242                     let mut columns: HashSet<String> = HashSet::new();
243                     utils::expr_to_column_names(predicate, &mut columns)?;
244                     // collect the predicate
245                     state.filters.push((predicate.clone(), columns));
246                     Ok(())
247                 })?;
248 
249             optimize(input, state)
250         }
251         LogicalPlan::Projection {
252             input,
253             expr,
254             schema,
255         } => {
256             // A projection is filter-commutable, but re-writes all predicate expressions
257             // collect projection.
258             let mut projection = HashMap::new();
259             schema.fields().iter().enumerate().for_each(|(i, field)| {
260                 // strip alias, as they should not be part of filters
261                 let expr = match &expr[i] {
262                     Expr::Alias(expr, _) => expr.as_ref().clone(),
263                     expr => expr.clone(),
264                 };
265 
266                 projection.insert(field.name().clone(), expr);
267             });
268 
269             // re-write all filters based on this projection
270             // E.g. in `Filter: #b\n  Projection: #a > 1 as b`, we can swap them, but the filter must be "#a > 1"
271             for (predicate, columns) in state.filters.iter_mut() {
272                 *predicate = rewrite(predicate, &projection)?;
273 
274                 columns.clear();
275                 utils::expr_to_column_names(predicate, columns)?;
276             }
277 
278             // optimize inner
279             let new_input = optimize(input, state)?;
280 
281             utils::from_plan(&plan, &expr, &[new_input])
282         }
283         LogicalPlan::Aggregate {
284             input, aggr_expr, ..
285         } => {
286             // An aggregate's aggreagate columns are _not_ filter-commutable => collect these:
287             // * columns whose aggregation expression depends on
288             // * the aggregation columns themselves
289 
290             // construct set of columns that `aggr_expr` depends on
291             let mut used_columns = HashSet::new();
292             utils::exprlist_to_column_names(aggr_expr, &mut used_columns)?;
293 
294             let agg_columns = aggr_expr
295                 .iter()
296                 .map(|x| x.name(input.schema()))
297                 .collect::<Result<HashSet<_>>>()?;
298             used_columns.extend(agg_columns);
299 
300             issue_filters(state, used_columns, plan)
301         }
302         LogicalPlan::Sort { .. } => {
303             // sort is filter-commutable
304             push_down(&state, plan)
305         }
306         LogicalPlan::Limit { input, .. } => {
307             // limit is _not_ filter-commutable => collect all columns from its input
308             let used_columns = input
309                 .schema()
310                 .fields()
311                 .iter()
312                 .map(|f| f.name().clone())
313                 .collect::<HashSet<_>>();
314             issue_filters(state, used_columns, plan)
315         }
316         LogicalPlan::Join { left, right, .. } => {
317             let (pushable_to_left, pushable_to_right, keep) =
318                 get_join_predicates(&state, &left.schema(), &right.schema());
319 
320             let mut left_state = state.clone();
321             left_state.filters = keep_filters(&left_state.filters, &pushable_to_left);
322             let left = optimize(left, left_state)?;
323 
324             let mut right_state = state.clone();
325             right_state.filters = keep_filters(&right_state.filters, &pushable_to_right);
326             let right = optimize(right, right_state)?;
327 
328             // create a new Join with the new `left` and `right`
329             let expr = utils::expressions(&plan);
330             let plan = utils::from_plan(&plan, &expr, &[left, right])?;
331 
332             if keep.0.is_empty() {
333                 Ok(plan)
334             } else {
335                 // wrap the join on the filter whose predicates must be kept
336                 let plan = add_filter(plan, &keep.0);
337                 state.filters = remove_filters(&state.filters, &keep.1);
338 
339                 Ok(plan)
340             }
341         }
342         LogicalPlan::TableScan {
343             source,
344             projected_schema,
345             filters,
346             projection,
347             table_name,
348         } => {
349             let mut used_columns = HashSet::new();
350             let mut new_filters = filters.clone();
351 
352             for (filter_expr, cols) in &state.filters {
353                 let (preserve_filter_node, add_to_provider) =
354                     match source.supports_filter_pushdown(filter_expr)? {
355                         TableProviderFilterPushDown::Unsupported => (true, false),
356                         TableProviderFilterPushDown::Inexact => (true, true),
357                         TableProviderFilterPushDown::Exact => (false, true),
358                     };
359 
360                 if preserve_filter_node {
361                     used_columns.extend(cols.clone());
362                 }
363 
364                 if add_to_provider {
365                     new_filters.push(filter_expr.clone());
366                 }
367             }
368 
369             issue_filters(
370                 state,
371                 used_columns,
372                 &LogicalPlan::TableScan {
373                     source: source.clone(),
374                     projection: projection.clone(),
375                     projected_schema: projected_schema.clone(),
376                     table_name: table_name.clone(),
377                     filters: new_filters,
378                 },
379             )
380         }
381         _ => {
382             // all other plans are _not_ filter-commutable
383             let used_columns = plan
384                 .schema()
385                 .fields()
386                 .iter()
387                 .map(|f| f.name().clone())
388                 .collect::<HashSet<_>>();
389             issue_filters(state, used_columns, plan)
390         }
391     }
392 }
393 
394 impl OptimizerRule for FilterPushDown {
name(&self) -> &str395     fn name(&self) -> &str {
396         "filter_push_down"
397     }
398 
optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan>399     fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
400         optimize(plan, State::default())
401     }
402 }
403 
404 impl FilterPushDown {
405     #[allow(missing_docs)]
new() -> Self406     pub fn new() -> Self {
407         Self {}
408     }
409 }
410 
411 /// replaces columns by its name on the projection.
rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr>412 fn rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr> {
413     let expressions = utils::expr_sub_expressions(&expr)?;
414 
415     let expressions = expressions
416         .iter()
417         .map(|e| rewrite(e, &projection))
418         .collect::<Result<Vec<_>>>()?;
419 
420     if let Expr::Column(name) = expr {
421         if let Some(expr) = projection.get(name) {
422             return Ok(expr.clone());
423         }
424     }
425 
426     utils::rewrite_expression(&expr, &expressions)
427 }
428 
429 #[cfg(test)]
430 mod tests {
431     use super::*;
432     use crate::datasource::datasource::Statistics;
433     use crate::datasource::TableProvider;
434     use crate::logical_plan::{lit, sum, DFSchema, Expr, LogicalPlanBuilder, Operator};
435     use crate::physical_plan::ExecutionPlan;
436     use crate::test::*;
437     use crate::{logical_plan::col, prelude::JoinType};
438     use arrow::datatypes::SchemaRef;
439 
assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str)440     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
441         let rule = FilterPushDown::new();
442         let optimized_plan = rule.optimize(plan).expect("failed to optimize plan");
443         let formatted_plan = format!("{:?}", optimized_plan);
444         assert_eq!(formatted_plan, expected);
445     }
446 
447     #[test]
filter_before_projection() -> Result<()>448     fn filter_before_projection() -> Result<()> {
449         let table_scan = test_table_scan()?;
450         let plan = LogicalPlanBuilder::from(&table_scan)
451             .project(&[col("a"), col("b")])?
452             .filter(col("a").eq(lit(1i64)))?
453             .build()?;
454         // filter is before projection
455         let expected = "\
456             Projection: #a, #b\
457             \n  Filter: #a Eq Int64(1)\
458             \n    TableScan: test projection=None";
459         assert_optimized_plan_eq(&plan, expected);
460         Ok(())
461     }
462 
463     #[test]
filter_after_limit() -> Result<()>464     fn filter_after_limit() -> Result<()> {
465         let table_scan = test_table_scan()?;
466         let plan = LogicalPlanBuilder::from(&table_scan)
467             .project(&[col("a"), col("b")])?
468             .limit(10)?
469             .filter(col("a").eq(lit(1i64)))?
470             .build()?;
471         // filter is before single projection
472         let expected = "\
473             Filter: #a Eq Int64(1)\
474             \n  Limit: 10\
475             \n    Projection: #a, #b\
476             \n      TableScan: test projection=None";
477         assert_optimized_plan_eq(&plan, expected);
478         Ok(())
479     }
480 
481     #[test]
filter_jump_2_plans() -> Result<()>482     fn filter_jump_2_plans() -> Result<()> {
483         let table_scan = test_table_scan()?;
484         let plan = LogicalPlanBuilder::from(&table_scan)
485             .project(&[col("a"), col("b"), col("c")])?
486             .project(&[col("c"), col("b")])?
487             .filter(col("a").eq(lit(1i64)))?
488             .build()?;
489         // filter is before double projection
490         let expected = "\
491             Projection: #c, #b\
492             \n  Projection: #a, #b, #c\
493             \n    Filter: #a Eq Int64(1)\
494             \n      TableScan: test projection=None";
495         assert_optimized_plan_eq(&plan, expected);
496         Ok(())
497     }
498 
499     #[test]
filter_move_agg() -> Result<()>500     fn filter_move_agg() -> Result<()> {
501         let table_scan = test_table_scan()?;
502         let plan = LogicalPlanBuilder::from(&table_scan)
503             .aggregate(&[col("a")], &[sum(col("b")).alias("total_salary")])?
504             .filter(col("a").gt(lit(10i64)))?
505             .build()?;
506         // filter of key aggregation is commutative
507         let expected = "\
508             Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS total_salary]]\
509             \n  Filter: #a Gt Int64(10)\
510             \n    TableScan: test projection=None";
511         assert_optimized_plan_eq(&plan, expected);
512         Ok(())
513     }
514 
515     #[test]
filter_keep_agg() -> Result<()>516     fn filter_keep_agg() -> Result<()> {
517         let table_scan = test_table_scan()?;
518         let plan = LogicalPlanBuilder::from(&table_scan)
519             .aggregate(&[col("a")], &[sum(col("b")).alias("b")])?
520             .filter(col("b").gt(lit(10i64)))?
521             .build()?;
522         // filter of aggregate is after aggregation since they are non-commutative
523         let expected = "\
524             Filter: #b Gt Int64(10)\
525             \n  Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS b]]\
526             \n    TableScan: test projection=None";
527         assert_optimized_plan_eq(&plan, expected);
528         Ok(())
529     }
530 
531     /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
532     #[test]
alias() -> Result<()>533     fn alias() -> Result<()> {
534         let table_scan = test_table_scan()?;
535         let plan = LogicalPlanBuilder::from(&table_scan)
536             .project(&[col("a").alias("b"), col("c")])?
537             .filter(col("b").eq(lit(1i64)))?
538             .build()?;
539         // filter is before projection
540         let expected = "\
541             Projection: #a AS b, #c\
542             \n  Filter: #a Eq Int64(1)\
543             \n    TableScan: test projection=None";
544         assert_optimized_plan_eq(&plan, expected);
545         Ok(())
546     }
547 
add(left: Expr, right: Expr) -> Expr548     fn add(left: Expr, right: Expr) -> Expr {
549         Expr::BinaryExpr {
550             left: Box::new(left),
551             op: Operator::Plus,
552             right: Box::new(right),
553         }
554     }
555 
multiply(left: Expr, right: Expr) -> Expr556     fn multiply(left: Expr, right: Expr) -> Expr {
557         Expr::BinaryExpr {
558             left: Box::new(left),
559             op: Operator::Multiply,
560             right: Box::new(right),
561         }
562     }
563 
564     /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
565     #[test]
complex_expression() -> Result<()>566     fn complex_expression() -> Result<()> {
567         let table_scan = test_table_scan()?;
568         let plan = LogicalPlanBuilder::from(&table_scan)
569             .project(&[
570                 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
571                 col("c"),
572             ])?
573             .filter(col("b").eq(lit(1i64)))?
574             .build()?;
575 
576         // not part of the test, just good to know:
577         assert_eq!(
578             format!("{:?}", plan),
579             "\
580             Filter: #b Eq Int64(1)\
581             \n  Projection: #a Multiply Int32(2) Plus #c AS b, #c\
582             \n    TableScan: test projection=None"
583         );
584 
585         // filter is before projection
586         let expected = "\
587             Projection: #a Multiply Int32(2) Plus #c AS b, #c\
588             \n  Filter: #a Multiply Int32(2) Plus #c Eq Int64(1)\
589             \n    TableScan: test projection=None";
590         assert_optimized_plan_eq(&plan, expected);
591         Ok(())
592     }
593 
594     /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
595     #[test]
complex_plan() -> Result<()>596     fn complex_plan() -> Result<()> {
597         let table_scan = test_table_scan()?;
598         let plan = LogicalPlanBuilder::from(&table_scan)
599             .project(&[
600                 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
601                 col("c"),
602             ])?
603             // second projection where we rename columns, just to make it difficult
604             .project(&[multiply(col("b"), lit(3)).alias("a"), col("c")])?
605             .filter(col("a").eq(lit(1i64)))?
606             .build()?;
607 
608         // not part of the test, just good to know:
609         assert_eq!(
610             format!("{:?}", plan),
611             "\
612             Filter: #a Eq Int64(1)\
613             \n  Projection: #b Multiply Int32(3) AS a, #c\
614             \n    Projection: #a Multiply Int32(2) Plus #c AS b, #c\
615             \n      TableScan: test projection=None"
616         );
617 
618         // filter is before the projections
619         let expected = "\
620         Projection: #b Multiply Int32(3) AS a, #c\
621         \n  Projection: #a Multiply Int32(2) Plus #c AS b, #c\
622         \n    Filter: #a Multiply Int32(2) Plus #c Multiply Int32(3) Eq Int64(1)\
623         \n      TableScan: test projection=None";
624         assert_optimized_plan_eq(&plan, expected);
625         Ok(())
626     }
627 
628     /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
629     /// and the other not.
630     #[test]
multi_filter() -> Result<()>631     fn multi_filter() -> Result<()> {
632         // the aggregation allows one filter to pass (b), and the other one to not pass (SUM(c))
633         let table_scan = test_table_scan()?;
634         let plan = LogicalPlanBuilder::from(&table_scan)
635             .project(&[col("a").alias("b"), col("c")])?
636             .aggregate(&[col("b")], &[sum(col("c"))])?
637             .filter(col("b").gt(lit(10i64)))?
638             .filter(col("SUM(c)").gt(lit(10i64)))?
639             .build()?;
640 
641         // not part of the test, just good to know:
642         assert_eq!(
643             format!("{:?}", plan),
644             "\
645             Filter: #SUM(c) Gt Int64(10)\
646             \n  Filter: #b Gt Int64(10)\
647             \n    Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
648             \n      Projection: #a AS b, #c\
649             \n        TableScan: test projection=None"
650         );
651 
652         // filter is before the projections
653         let expected = "\
654         Filter: #SUM(c) Gt Int64(10)\
655         \n  Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
656         \n    Projection: #a AS b, #c\
657         \n      Filter: #a Gt Int64(10)\
658         \n        TableScan: test projection=None";
659         assert_optimized_plan_eq(&plan, expected);
660 
661         Ok(())
662     }
663 
664     /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
665     /// and the other not.
666     #[test]
split_filter() -> Result<()>667     fn split_filter() -> Result<()> {
668         // the aggregation allows one filter to pass (b), and the other one to not pass (SUM(c))
669         let table_scan = test_table_scan()?;
670         let plan = LogicalPlanBuilder::from(&table_scan)
671             .project(&[col("a").alias("b"), col("c")])?
672             .aggregate(&[col("b")], &[sum(col("c"))])?
673             .filter(and(
674                 col("SUM(c)").gt(lit(10i64)),
675                 and(col("b").gt(lit(10i64)), col("SUM(c)").lt(lit(20i64))),
676             ))?
677             .build()?;
678 
679         // not part of the test, just good to know:
680         assert_eq!(
681             format!("{:?}", plan),
682             "\
683             Filter: #SUM(c) Gt Int64(10) And #b Gt Int64(10) And #SUM(c) Lt Int64(20)\
684             \n  Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
685             \n    Projection: #a AS b, #c\
686             \n      TableScan: test projection=None"
687         );
688 
689         // filter is before the projections
690         let expected = "\
691         Filter: #SUM(c) Gt Int64(10) And #SUM(c) Lt Int64(20)\
692         \n  Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
693         \n    Projection: #a AS b, #c\
694         \n      Filter: #a Gt Int64(10)\
695         \n        TableScan: test projection=None";
696         assert_optimized_plan_eq(&plan, expected);
697 
698         Ok(())
699     }
700 
701     /// verifies that when two limits are in place, we jump neither
702     #[test]
double_limit() -> Result<()>703     fn double_limit() -> Result<()> {
704         let table_scan = test_table_scan()?;
705         let plan = LogicalPlanBuilder::from(&table_scan)
706             .project(&[col("a"), col("b")])?
707             .limit(20)?
708             .limit(10)?
709             .project(&[col("a"), col("b")])?
710             .filter(col("a").eq(lit(1i64)))?
711             .build()?;
712         // filter does not just any of the limits
713         let expected = "\
714             Projection: #a, #b\
715             \n  Filter: #a Eq Int64(1)\
716             \n    Limit: 10\
717             \n      Limit: 20\
718             \n        Projection: #a, #b\
719             \n          TableScan: test projection=None";
720         assert_optimized_plan_eq(&plan, expected);
721         Ok(())
722     }
723 
724     /// verifies that filters with the same columns are correctly placed
725     #[test]
filter_2_breaks_limits() -> Result<()>726     fn filter_2_breaks_limits() -> Result<()> {
727         let table_scan = test_table_scan()?;
728         let plan = LogicalPlanBuilder::from(&table_scan)
729             .project(&[col("a")])?
730             .filter(col("a").lt_eq(lit(1i64)))?
731             .limit(1)?
732             .project(&[col("a")])?
733             .filter(col("a").gt_eq(lit(1i64)))?
734             .build()?;
735         // Should be able to move both filters below the projections
736 
737         // not part of the test
738         assert_eq!(
739             format!("{:?}", plan),
740             "Filter: #a GtEq Int64(1)\
741              \n  Projection: #a\
742              \n    Limit: 1\
743              \n      Filter: #a LtEq Int64(1)\
744              \n        Projection: #a\
745              \n          TableScan: test projection=None"
746         );
747 
748         let expected = "\
749         Projection: #a\
750         \n  Filter: #a GtEq Int64(1)\
751         \n    Limit: 1\
752         \n      Projection: #a\
753         \n        Filter: #a LtEq Int64(1)\
754         \n          TableScan: test projection=None";
755 
756         assert_optimized_plan_eq(&plan, expected);
757         Ok(())
758     }
759 
760     /// verifies that filters to be placed on the same depth are ANDed
761     #[test]
two_filters_on_same_depth() -> Result<()>762     fn two_filters_on_same_depth() -> Result<()> {
763         let table_scan = test_table_scan()?;
764         let plan = LogicalPlanBuilder::from(&table_scan)
765             .limit(1)?
766             .filter(col("a").lt_eq(lit(1i64)))?
767             .filter(col("a").gt_eq(lit(1i64)))?
768             .project(&[col("a")])?
769             .build()?;
770 
771         // not part of the test
772         assert_eq!(
773             format!("{:?}", plan),
774             "Projection: #a\
775             \n  Filter: #a GtEq Int64(1)\
776             \n    Filter: #a LtEq Int64(1)\
777             \n      Limit: 1\
778             \n        TableScan: test projection=None"
779         );
780 
781         let expected = "\
782         Projection: #a\
783         \n  Filter: #a GtEq Int64(1) And #a LtEq Int64(1)\
784         \n    Limit: 1\
785         \n      TableScan: test projection=None";
786 
787         assert_optimized_plan_eq(&plan, expected);
788         Ok(())
789     }
790 
791     /// verifies that filters on a plan with user nodes are not lost
792     /// (ARROW-10547)
793     #[test]
filters_user_defined_node() -> Result<()>794     fn filters_user_defined_node() -> Result<()> {
795         let table_scan = test_table_scan()?;
796         let plan = LogicalPlanBuilder::from(&table_scan)
797             .filter(col("a").lt_eq(lit(1i64)))?
798             .build()?;
799 
800         let plan = crate::test::user_defined::new(plan);
801 
802         let expected = "\
803             TestUserDefined\
804              \n  Filter: #a LtEq Int64(1)\
805              \n    TableScan: test projection=None";
806 
807         // not part of the test
808         assert_eq!(format!("{:?}", plan), expected);
809 
810         assert_optimized_plan_eq(&plan, expected);
811         Ok(())
812     }
813 
814     /// post-join predicates on a column common to both sides is pushed to both sides
815     #[test]
filter_join_on_common_independent() -> Result<()>816     fn filter_join_on_common_independent() -> Result<()> {
817         let table_scan = test_table_scan()?;
818         let left = LogicalPlanBuilder::from(&table_scan).build()?;
819         let right = LogicalPlanBuilder::from(&table_scan)
820             .project(&[col("a")])?
821             .build()?;
822         let plan = LogicalPlanBuilder::from(&left)
823             .join(&right, JoinType::Inner, &["a"], &["a"])?
824             .filter(col("a").lt_eq(lit(1i64)))?
825             .build()?;
826 
827         // not part of the test, just good to know:
828         assert_eq!(
829             format!("{:?}", plan),
830             "\
831             Filter: #a LtEq Int64(1)\
832             \n  Join: a = a\
833             \n    TableScan: test projection=None\
834             \n    Projection: #a\
835             \n      TableScan: test projection=None"
836         );
837 
838         // filter sent to side before the join
839         let expected = "\
840         Join: a = a\
841         \n  Filter: #a LtEq Int64(1)\
842         \n    TableScan: test projection=None\
843         \n  Projection: #a\
844         \n    Filter: #a LtEq Int64(1)\
845         \n      TableScan: test projection=None";
846         assert_optimized_plan_eq(&plan, expected);
847         Ok(())
848     }
849 
850     /// post-join predicates with columns from both sides are not pushed
851     #[test]
filter_join_on_common_dependent() -> Result<()>852     fn filter_join_on_common_dependent() -> Result<()> {
853         let table_scan = test_table_scan()?;
854         let left = LogicalPlanBuilder::from(&table_scan)
855             .project(&[col("a"), col("c")])?
856             .build()?;
857         let right = LogicalPlanBuilder::from(&table_scan)
858             .project(&[col("a"), col("b")])?
859             .build()?;
860         let plan = LogicalPlanBuilder::from(&left)
861             .join(&right, JoinType::Inner, &["a"], &["a"])?
862             // "b" and "c" are not shared by either side: they are only available together after the join
863             .filter(col("c").lt_eq(col("b")))?
864             .build()?;
865 
866         // not part of the test, just good to know:
867         assert_eq!(
868             format!("{:?}", plan),
869             "\
870             Filter: #c LtEq #b\
871             \n  Join: a = a\
872             \n    Projection: #a, #c\
873             \n      TableScan: test projection=None\
874             \n    Projection: #a, #b\
875             \n      TableScan: test projection=None"
876         );
877 
878         // expected is equal: no push-down
879         let expected = &format!("{:?}", plan);
880         assert_optimized_plan_eq(&plan, expected);
881         Ok(())
882     }
883 
884     /// post-join predicates with columns from one side of a join are pushed only to that side
885     #[test]
filter_join_on_one_side() -> Result<()>886     fn filter_join_on_one_side() -> Result<()> {
887         let table_scan = test_table_scan()?;
888         let left = LogicalPlanBuilder::from(&table_scan)
889             .project(&[col("a"), col("b")])?
890             .build()?;
891         let right = LogicalPlanBuilder::from(&table_scan)
892             .project(&[col("a"), col("c")])?
893             .build()?;
894         let plan = LogicalPlanBuilder::from(&left)
895             .join(&right, JoinType::Inner, &["a"], &["a"])?
896             .filter(col("b").lt_eq(lit(1i64)))?
897             .build()?;
898 
899         // not part of the test, just good to know:
900         assert_eq!(
901             format!("{:?}", plan),
902             "\
903             Filter: #b LtEq Int64(1)\
904             \n  Join: a = a\
905             \n    Projection: #a, #b\
906             \n      TableScan: test projection=None\
907             \n    Projection: #a, #c\
908             \n      TableScan: test projection=None"
909         );
910 
911         let expected = "\
912         Join: a = a\
913         \n  Projection: #a, #b\
914         \n    Filter: #b LtEq Int64(1)\
915         \n      TableScan: test projection=None\
916         \n  Projection: #a, #c\
917         \n    TableScan: test projection=None";
918         assert_optimized_plan_eq(&plan, expected);
919         Ok(())
920     }
921 
922     struct PushDownProvider {
923         pub filter_support: TableProviderFilterPushDown,
924     }
925 
926     impl TableProvider for PushDownProvider {
schema(&self) -> SchemaRef927         fn schema(&self) -> SchemaRef {
928             Arc::new(arrow::datatypes::Schema::new(vec![
929                 arrow::datatypes::Field::new(
930                     "a",
931                     arrow::datatypes::DataType::Int32,
932                     true,
933                 ),
934             ]))
935         }
936 
scan( &self, _: &Option<Vec<usize>>, _: usize, _: &[Expr], ) -> Result<Arc<dyn ExecutionPlan>>937         fn scan(
938             &self,
939             _: &Option<Vec<usize>>,
940             _: usize,
941             _: &[Expr],
942         ) -> Result<Arc<dyn ExecutionPlan>> {
943             unimplemented!()
944         }
945 
supports_filter_pushdown( &self, _: &Expr, ) -> Result<TableProviderFilterPushDown>946         fn supports_filter_pushdown(
947             &self,
948             _: &Expr,
949         ) -> Result<TableProviderFilterPushDown> {
950             Ok(self.filter_support.clone())
951         }
952 
as_any(&self) -> &dyn std::any::Any953         fn as_any(&self) -> &dyn std::any::Any {
954             self
955         }
956 
statistics(&self) -> Statistics957         fn statistics(&self) -> Statistics {
958             Statistics::default()
959         }
960     }
961 
table_scan_with_pushdown_provider( filter_support: TableProviderFilterPushDown, ) -> Result<LogicalPlan>962     fn table_scan_with_pushdown_provider(
963         filter_support: TableProviderFilterPushDown,
964     ) -> Result<LogicalPlan> {
965         let test_provider = PushDownProvider { filter_support };
966 
967         let table_scan = LogicalPlan::TableScan {
968             table_name: "".into(),
969             filters: vec![],
970             projected_schema: Arc::new(DFSchema::try_from_qualified(
971                 "",
972                 &*test_provider.schema(),
973             )?),
974             projection: None,
975             source: Arc::new(test_provider),
976         };
977 
978         LogicalPlanBuilder::from(&table_scan)
979             .filter(col("a").eq(lit(1i64)))?
980             .build()
981     }
982 
983     #[test]
filter_with_table_provider_exact() -> Result<()>984     fn filter_with_table_provider_exact() -> Result<()> {
985         let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
986 
987         let expected = "\
988         TableScan: projection=None, filters=[#a Eq Int64(1)]";
989         assert_optimized_plan_eq(&plan, expected);
990         Ok(())
991     }
992 
993     #[test]
filter_with_table_provider_inexact() -> Result<()>994     fn filter_with_table_provider_inexact() -> Result<()> {
995         let plan =
996             table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
997 
998         let expected = "\
999         Filter: #a Eq Int64(1)\
1000         \n  TableScan: projection=None, filters=[#a Eq Int64(1)]";
1001         assert_optimized_plan_eq(&plan, expected);
1002         Ok(())
1003     }
1004 
1005     #[test]
filter_with_table_provider_unsupported() -> Result<()>1006     fn filter_with_table_provider_unsupported() -> Result<()> {
1007         let plan =
1008             table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
1009 
1010         let expected = "\
1011         Filter: #a Eq Int64(1)\
1012         \n  TableScan: projection=None";
1013         assert_optimized_plan_eq(&plan, expected);
1014         Ok(())
1015     }
1016 }
1017