1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License
17 
18 //! Optimizer rule to switch build and probe order of hash join
19 //! based on statistics of a `TableProvider`. If the number of
20 //! rows of both sources is known, the order can be switched
21 //! for a faster hash join.
22 
23 use std::sync::Arc;
24 
25 use crate::logical_plan::LogicalPlan;
26 use crate::optimizer::optimizer::OptimizerRule;
27 use crate::{error::Result, prelude::JoinType};
28 
29 use super::utils;
30 
/// Optimizer rule that reorders the build and probe sides of hash joins.
///
/// The decision is based on the number of rows each data source reports:
/// the rule arranges the join so that the left (build) side is the smallest.
/// If that information is not available, the order stays the same,
/// so that it can still be optimized manually in the query.
pub struct HashBuildProbeOrder {}
38 
39 // Gets exact number of rows, if known by the statistics of the underlying
get_num_rows(logical_plan: &LogicalPlan) -> Option<usize>40 fn get_num_rows(logical_plan: &LogicalPlan) -> Option<usize> {
41     match logical_plan {
42         LogicalPlan::TableScan { source, .. } => source.statistics().num_rows,
43         LogicalPlan::EmptyRelation {
44             produce_one_row, ..
45         } => {
46             if *produce_one_row {
47                 Some(1)
48             } else {
49                 Some(0)
50             }
51         }
52         LogicalPlan::Limit { n: limit, input } => {
53             let num_rows_input = get_num_rows(input);
54             num_rows_input.map(|rows| std::cmp::min(*limit, rows))
55         }
56         LogicalPlan::Aggregate { .. } => {
57             // we cannot yet predict how many rows will be produced by an aggregate because
58             // we do not know the cardinality of the grouping keys
59             None
60         }
61         LogicalPlan::Filter { .. } => {
62             // we cannot yet predict how many rows will be produced by a filter because
63             // we don't know how selective it is (how many rows it will filter out)
64             None
65         }
66         LogicalPlan::Join { .. } => {
67             // we cannot predict the cardinality of the join output
68             None
69         }
70         LogicalPlan::Repartition { .. } => {
71             // we cannot predict how rows will be repartitioned
72             None
73         }
74         // the following operators are special cases and not querying data
75         LogicalPlan::CreateExternalTable { .. } => None,
76         LogicalPlan::Explain { .. } => None,
77         // we do not support estimating rows with extensions yet
78         LogicalPlan::Extension { .. } => None,
79         // the following operators do not modify row count in any way
80         LogicalPlan::Projection { input, .. } => get_num_rows(input),
81         LogicalPlan::Sort { input, .. } => get_num_rows(input),
82     }
83 }
84 
85 // Finds out whether to swap left vs right order based on statistics
should_swap_join_order(left: &LogicalPlan, right: &LogicalPlan) -> bool86 fn should_swap_join_order(left: &LogicalPlan, right: &LogicalPlan) -> bool {
87     let left_rows = get_num_rows(left);
88     let right_rows = get_num_rows(right);
89 
90     match (left_rows, right_rows) {
91         (Some(l), Some(r)) => l > r,
92         _ => false,
93     }
94 }
95 
96 impl OptimizerRule for HashBuildProbeOrder {
name(&self) -> &str97     fn name(&self) -> &str {
98         "hash_build_probe_order"
99     }
100 
optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan>101     fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
102         match plan {
103             // Main optimization rule, swaps order of left and right
104             // based on number of rows in each table
105             LogicalPlan::Join {
106                 left,
107                 right,
108                 on,
109                 join_type,
110                 schema,
111             } => {
112                 let left = self.optimize(left)?;
113                 let right = self.optimize(right)?;
114                 if should_swap_join_order(&left, &right) {
115                     // Swap left and right, change join type and (equi-)join key order
116                     Ok(LogicalPlan::Join {
117                         left: Arc::new(right),
118                         right: Arc::new(left),
119                         on: on
120                             .iter()
121                             .map(|(l, r)| (r.to_string(), l.to_string()))
122                             .collect(),
123                         join_type: swap_join_type(*join_type),
124                         schema: schema.clone(),
125                     })
126                 } else {
127                     // Keep join as is
128                     Ok(LogicalPlan::Join {
129                         left: Arc::new(left),
130                         right: Arc::new(right),
131                         on: on.clone(),
132                         join_type: *join_type,
133                         schema: schema.clone(),
134                     })
135                 }
136             }
137             // Rest: recurse into plan, apply optimization where possible
138             LogicalPlan::Projection { .. }
139             | LogicalPlan::Aggregate { .. }
140             | LogicalPlan::TableScan { .. }
141             | LogicalPlan::Limit { .. }
142             | LogicalPlan::Filter { .. }
143             | LogicalPlan::Repartition { .. }
144             | LogicalPlan::EmptyRelation { .. }
145             | LogicalPlan::Sort { .. }
146             | LogicalPlan::CreateExternalTable { .. }
147             | LogicalPlan::Explain { .. }
148             | LogicalPlan::Extension { .. } => {
149                 let expr = utils::expressions(plan);
150 
151                 // apply the optimization to all inputs of the plan
152                 let inputs = utils::inputs(plan);
153                 let new_inputs = inputs
154                     .iter()
155                     .map(|plan| self.optimize(plan))
156                     .collect::<Result<Vec<_>>>()?;
157 
158                 utils::from_plan(plan, &expr, &new_inputs)
159             }
160         }
161     }
162 }
163 
164 impl HashBuildProbeOrder {
165     #[allow(missing_docs)]
new() -> Self166     pub fn new() -> Self {
167         Self {}
168     }
169 }
170 
swap_join_type(join_type: JoinType) -> JoinType171 fn swap_join_type(join_type: JoinType) -> JoinType {
172     match join_type {
173         JoinType::Inner => JoinType::Inner,
174         JoinType::Left => JoinType::Right,
175         JoinType::Right => JoinType::Left,
176     }
177 }
178 
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use crate::{
        datasource::{datasource::Statistics, TableProvider},
        logical_plan::{DFSchema, Expr},
        test::*,
    };

    // Table provider stub that only knows its row count. Every method other
    // than `statistics` is unimplemented and panics when called.
    struct TestTableProvider {
        num_rows: usize,
    }

    impl TableProvider for TestTableProvider {
        fn as_any(&self) -> &dyn std::any::Any {
            unimplemented!()
        }

        fn schema(&self) -> arrow::datatypes::SchemaRef {
            unimplemented!()
        }

        fn scan(
            &self,
            _projection: &Option<Vec<usize>>,
            _batch_size: usize,
            _filters: &[Expr],
        ) -> Result<Arc<dyn crate::physical_plan::ExecutionPlan>> {
            unimplemented!()
        }

        // Reports the configured row count; no other statistics are provided.
        fn statistics(&self) -> Statistics {
            Statistics {
                num_rows: Some(self.num_rows),
                total_byte_size: None,
                column_statistics: None,
            }
        }
    }

    // Builds an unprojected, unfiltered scan over a `TestTableProvider` that
    // reports `num_rows` rows.
    fn scan_with_row_count(table_name: &str, num_rows: usize) -> LogicalPlan {
        LogicalPlan::TableScan {
            table_name: table_name.to_string(),
            projection: None,
            source: Arc::new(TestTableProvider { num_rows }),
            projected_schema: Arc::new(DFSchema::empty()),
            filters: vec![],
        }
    }

    #[test]
    fn test_num_rows() -> Result<()> {
        // the scan built by `test_table_scan` is expected to report zero rows
        let scan = test_table_scan()?;
        assert_eq!(get_num_rows(&scan), Some(0));

        Ok(())
    }

    #[test]
    fn test_swap_order() {
        let lp_left = scan_with_row_count("left", 1000);
        let lp_right = scan_with_row_count("right", 100);

        // the larger input on the left must be swapped ...
        assert!(should_swap_join_order(&lp_left, &lp_right));
        // ... while the smaller input on the left stays where it is
        assert!(!should_swap_join_order(&lp_right, &lp_left));
    }
}
250