1 /*-------------------------------------------------------------------------
2 *
3 * multi_logical_planner.c
4 *
5 * Routines for constructing a logical plan tree from the given Query tree
6 * structure. This new logical plan is based on multi-relational algebra rules.
7 *
8 * Copyright (c) Citus Data, Inc.
9 *
10 * $Id$
11 *
12 *-------------------------------------------------------------------------
13 */
14
15 #include "postgres.h"
16
17 #include "distributed/pg_version_constants.h"
18
19 #include "access/heapam.h"
20 #include "access/nbtree.h"
21 #include "catalog/pg_am.h"
22 #include "catalog/pg_class.h"
23 #include "commands/defrem.h"
24 #include "distributed/citus_clauses.h"
25 #include "distributed/colocation_utils.h"
26 #include "distributed/metadata_cache.h"
27 #include "distributed/insert_select_planner.h"
28 #include "distributed/listutils.h"
29 #include "distributed/multi_logical_optimizer.h"
30 #include "distributed/multi_logical_planner.h"
31 #include "distributed/multi_physical_planner.h"
32 #include "distributed/reference_table_utils.h"
33 #include "distributed/relation_restriction_equivalence.h"
34 #include "distributed/query_pushdown_planning.h"
35 #include "distributed/query_utils.h"
36 #include "distributed/multi_router_planner.h"
37 #include "distributed/worker_protocol.h"
38 #include "distributed/version_compat.h"
39 #include "nodes/makefuncs.h"
40 #include "nodes/nodeFuncs.h"
41 #include "nodes/pathnodes.h"
42 #include "optimizer/optimizer.h"
43 #include "optimizer/clauses.h"
44 #include "optimizer/prep.h"
45 #include "optimizer/tlist.h"
46 #include "parser/parsetree.h"
47 #include "utils/builtins.h"
48 #include "utils/datum.h"
49 #include "utils/lsyscache.h"
50 #include "utils/syscache.h"
51 #include "utils/rel.h"
52 #include "utils/relcache.h"
53
54
55 /* Struct to differentiate different qualifier types in an expression tree walker */
56 typedef struct QualifierWalkerContext
57 {
58 List *baseQualifierList;
59 List *outerJoinQualifierList;
60 } QualifierWalkerContext;
61
62
63 /* Function pointer type definition for apply join rule functions */
64 typedef MultiNode *(*RuleApplyFunction) (MultiNode *leftNode, MultiNode *rightNode,
65 List *partitionColumnList, JoinType joinType,
66 List *joinClauses);
67
68 typedef bool (*CheckNodeFunc)(Node *);
69
70 static RuleApplyFunction RuleApplyFunctionArray[JOIN_RULE_LAST] = { 0 }; /* join rules */
71
72 /* Local functions forward declarations */
73 static FieldSelect * CompositeFieldRecursive(Expr *expression, Query *query);
74 static Oid NodeTryGetRteRelid(Node *node);
75 static bool FullCompositeFieldList(List *compositeFieldList);
76 static bool HasUnsupportedJoinWalker(Node *node, void *context);
77 static bool ErrorHintRequired(const char *errorHint, Query *queryTree);
78 static bool HasTablesample(Query *queryTree);
79 static bool HasComplexRangeTableType(Query *queryTree);
80 static bool IsReadIntermediateResultFunction(Node *node);
81 static bool IsReadIntermediateResultArrayFunction(Node *node);
82 static bool IsCitusExtraDataContainerFunc(Node *node);
83 static bool IsFunctionWithOid(Node *node, Oid funcOid);
84 static bool IsGroupingFunc(Node *node);
85 static bool ExtractFromExpressionWalker(Node *node,
86 QualifierWalkerContext *walkerContext);
87 static List * MultiTableNodeList(List *tableEntryList, List *rangeTableList);
88 static List * AddMultiCollectNodes(List *tableNodeList);
89 static MultiNode * MultiJoinTree(List *joinOrderList, List *collectTableList,
90 List *joinClauseList);
91 static MultiCollect * CollectNodeForTable(List *collectTableList, uint32 rangeTableId);
92 static MultiSelect * MultiSelectNode(List *whereClauseList);
93 static bool IsSelectClause(Node *clause);
94
95 /* Local functions forward declarations for applying joins */
96 static MultiNode * ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode,
97 JoinRuleType ruleType, List *partitionColumnList,
98 JoinType joinType, List *joinClauseList);
99 static RuleApplyFunction JoinRuleApplyFunction(JoinRuleType ruleType);
100 static MultiNode * ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
101 List *partitionColumnList, JoinType joinType,
102 List *joinClauses);
103 static MultiNode * ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode,
104 List *partitionColumnList, JoinType joinType,
105 List *joinClauses);
106 static MultiNode * ApplySingleRangePartitionJoin(MultiNode *leftNode,
107 MultiNode *rightNode,
108 List *partitionColumnList,
109 JoinType joinType,
110 List *applicableJoinClauses);
111 static MultiNode * ApplySingleHashPartitionJoin(MultiNode *leftNode,
112 MultiNode *rightNode,
113 List *partitionColumnList,
114 JoinType joinType,
115 List *applicableJoinClauses);
116 static MultiJoin * ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
117 List *partitionColumnList, JoinType joinType,
118 List *joinClauses);
119 static MultiNode * ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
120 List *partitionColumnList, JoinType joinType,
121 List *joinClauses);
122 static MultiNode * ApplyCartesianProductReferenceJoin(MultiNode *leftNode,
123 MultiNode *rightNode,
124 List *partitionColumnList,
125 JoinType joinType,
126 List *joinClauses);
127 static MultiNode * ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode,
128 List *partitionColumnList, JoinType joinType,
129 List *joinClauses);
130
131
132 /*
133 * MultiLogicalPlanCreate takes in both the original query and its corresponding modified
134 * query tree yield by the standard planner. It uses helper functions to create logical
135 * plan and adds a root node to top of it. The original query is only used for subquery
136 * pushdown planning.
137 *
138 * We also pass queryTree and plannerRestrictionContext to the planner. They
139 * are primarily used to decide whether the subquery is safe to pushdown.
140 * If not, it helps to produce meaningful error messages for subquery
141 * pushdown planning.
142 */
143 MultiTreeRoot *
MultiLogicalPlanCreate(Query * originalQuery,Query * queryTree,PlannerRestrictionContext * plannerRestrictionContext)144 MultiLogicalPlanCreate(Query *originalQuery, Query *queryTree,
145 PlannerRestrictionContext *plannerRestrictionContext)
146 {
147 MultiNode *multiQueryNode = NULL;
148
149
150 if (ShouldUseSubqueryPushDown(originalQuery, queryTree, plannerRestrictionContext))
151 {
152 multiQueryNode = SubqueryMultiNodeTree(originalQuery, queryTree,
153 plannerRestrictionContext);
154 }
155 else
156 {
157 multiQueryNode = MultiNodeTree(queryTree);
158 }
159
160 /* add a root node to serve as the permanent handle to the tree */
161 MultiTreeRoot *rootNode = CitusMakeNode(MultiTreeRoot);
162 SetChild((MultiUnaryNode *) rootNode, multiQueryNode);
163
164 return rootNode;
165 }
166
167
168 /*
169 * FindNodeMatchingCheckFunction finds a node for which the checker function returns true.
170 *
171 * To call this function directly with an RTE, use:
172 * range_table_walker(rte, FindNodeMatchingCheckFunction, checker, QTW_EXAMINE_RTES_BEFORE)
173 */
174 bool
FindNodeMatchingCheckFunction(Node * node,CheckNodeFunc checker)175 FindNodeMatchingCheckFunction(Node *node, CheckNodeFunc checker)
176 {
177 if (node == NULL)
178 {
179 return false;
180 }
181
182 if (checker(node))
183 {
184 return true;
185 }
186
187 if (IsA(node, RangeTblEntry))
188 {
189 /* query_tree_walker descends into RTEs */
190 return false;
191 }
192 else if (IsA(node, Query))
193 {
194 return query_tree_walker((Query *) node, FindNodeMatchingCheckFunction, checker,
195 QTW_EXAMINE_RTES_BEFORE);
196 }
197
198 return expression_tree_walker(node, FindNodeMatchingCheckFunction, checker);
199 }
200
201
202 /*
203 * TargetListOnPartitionColumn checks if at least one target list entry is on
204 * partition column.
205 */
206 bool
TargetListOnPartitionColumn(Query * query,List * targetEntryList)207 TargetListOnPartitionColumn(Query *query, List *targetEntryList)
208 {
209 bool targetListOnPartitionColumn = false;
210 List *compositeFieldList = NIL;
211
212 ListCell *targetEntryCell = NULL;
213 foreach(targetEntryCell, targetEntryList)
214 {
215 TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
216 Expr *targetExpression = targetEntry->expr;
217
218 bool skipOuterVars = true;
219 bool isPartitionColumn = IsPartitionColumn(targetExpression, query,
220 skipOuterVars);
221 Var *column = NULL;
222 RangeTblEntry *rte = NULL;
223
224 FindReferencedTableColumn(targetExpression, NIL, query, &column, &rte,
225 skipOuterVars);
226 Oid relationId = rte ? rte->relid : InvalidOid;
227
228 /*
229 * If the expression belongs to a non-distributed table continue searching for
230 * other partition keys.
231 */
232 if (IsCitusTableType(relationId, CITUS_TABLE_WITH_NO_DIST_KEY))
233 {
234 continue;
235 }
236
237 if (isPartitionColumn)
238 {
239 FieldSelect *compositeField = CompositeFieldRecursive(targetExpression,
240 query);
241 if (compositeField)
242 {
243 compositeFieldList = lappend(compositeFieldList, compositeField);
244 }
245 else
246 {
247 targetListOnPartitionColumn = true;
248 break;
249 }
250 }
251 }
252
253 /* check composite fields */
254 if (!targetListOnPartitionColumn)
255 {
256 bool fullCompositeFieldList = FullCompositeFieldList(compositeFieldList);
257 if (fullCompositeFieldList)
258 {
259 targetListOnPartitionColumn = true;
260 }
261 }
262
263 /*
264 * We could still behave as if the target list is on partition column if
265 * range table entries don't contain a distributed table.
266 */
267 if (!targetListOnPartitionColumn)
268 {
269 if (!FindNodeMatchingCheckFunctionInRangeTableList(query->rtable,
270 IsDistributedTableRTE))
271 {
272 targetListOnPartitionColumn = true;
273 }
274 }
275
276 return targetListOnPartitionColumn;
277 }
278
279
280 /*
281 * FindNodeMatchingCheckFunctionInRangeTableList finds a node for which the checker
282 * function returns true.
283 *
284 * FindNodeMatchingCheckFunctionInRangeTableList relies on
285 * FindNodeMatchingCheckFunction() but only considers the range table entries.
286 */
287 bool
FindNodeMatchingCheckFunctionInRangeTableList(List * rtable,CheckNodeFunc checker)288 FindNodeMatchingCheckFunctionInRangeTableList(List *rtable, CheckNodeFunc checker)
289 {
290 return range_table_walker(rtable, FindNodeMatchingCheckFunction, checker,
291 QTW_EXAMINE_RTES_BEFORE);
292 }
293
294
295 /*
296 * NodeTryGetRteRelid returns the relid of the given RTE_RELATION RangeTableEntry.
297 * Returns InvalidOid if any of these assumptions fail for given node.
298 */
299 static Oid
NodeTryGetRteRelid(Node * node)300 NodeTryGetRteRelid(Node *node)
301 {
302 if (node == NULL)
303 {
304 return InvalidOid;
305 }
306
307 if (!IsA(node, RangeTblEntry))
308 {
309 return InvalidOid;
310 }
311
312 RangeTblEntry *rangeTableEntry = (RangeTblEntry *) node;
313
314 if (rangeTableEntry->rtekind != RTE_RELATION)
315 {
316 return InvalidOid;
317 }
318
319 return rangeTableEntry->relid;
320 }
321
322
323 /*
324 * IsCitusTableRTE gets a node and returns true if the node is a
325 * range table relation entry that points to a distributed relation.
326 */
327 bool
IsCitusTableRTE(Node * node)328 IsCitusTableRTE(Node *node)
329 {
330 Oid relationId = NodeTryGetRteRelid(node);
331 return relationId != InvalidOid && IsCitusTable(relationId);
332 }
333
334
335 /*
336 * IsDistributedOrReferenceTableRTE returns true if the given node
337 * is eeither a distributed(hash/range/append) or reference table.
338 */
339 bool
IsDistributedOrReferenceTableRTE(Node * node)340 IsDistributedOrReferenceTableRTE(Node *node)
341 {
342 Oid relationId = NodeTryGetRteRelid(node);
343 if (!OidIsValid(relationId))
344 {
345 return false;
346 }
347 return IsCitusTableType(relationId, DISTRIBUTED_TABLE) ||
348 IsCitusTableType(relationId, REFERENCE_TABLE);
349 }
350
351
352 /*
353 * IsDistributedTableRTE gets a node and returns true if the node
354 * is a range table relation entry that points to a distributed relation,
355 * returning false still if the relation is a reference table.
356 */
357 bool
IsDistributedTableRTE(Node * node)358 IsDistributedTableRTE(Node *node)
359 {
360 Oid relationId = NodeTryGetRteRelid(node);
361 return relationId != InvalidOid && IsCitusTableType(relationId, DISTRIBUTED_TABLE);
362 }
363
364
365 /*
366 * IsReferenceTableRTE gets a node and returns true if the node
367 * is a range table relation entry that points to a reference table.
368 */
369 bool
IsReferenceTableRTE(Node * node)370 IsReferenceTableRTE(Node *node)
371 {
372 Oid relationId = NodeTryGetRteRelid(node);
373 return relationId != InvalidOid && IsCitusTableType(relationId, REFERENCE_TABLE);
374 }
375
376
377 /*
378 * FullCompositeFieldList gets a composite field list, and checks if all fields
379 * of composite type are used in the list.
380 */
381 static bool
FullCompositeFieldList(List * compositeFieldList)382 FullCompositeFieldList(List *compositeFieldList)
383 {
384 bool fullCompositeFieldList = true;
385 bool *compositeFieldArray = NULL;
386 uint32 compositeFieldCount = 0;
387
388 ListCell *fieldSelectCell = NULL;
389 foreach(fieldSelectCell, compositeFieldList)
390 {
391 FieldSelect *fieldSelect = (FieldSelect *) lfirst(fieldSelectCell);
392
393 Expr *fieldExpression = fieldSelect->arg;
394 if (!IsA(fieldExpression, Var))
395 {
396 continue;
397 }
398
399 if (compositeFieldArray == NULL)
400 {
401 Var *compositeColumn = (Var *) fieldExpression;
402 Oid compositeTypeId = compositeColumn->vartype;
403 Oid compositeRelationId = get_typ_typrelid(compositeTypeId);
404
405 /* get composite type attribute count */
406 Relation relation = relation_open(compositeRelationId, AccessShareLock);
407 compositeFieldCount = relation->rd_att->natts;
408 compositeFieldArray = palloc0(compositeFieldCount * sizeof(bool));
409 relation_close(relation, AccessShareLock);
410
411 for (uint32 compositeFieldIndex = 0;
412 compositeFieldIndex < compositeFieldCount;
413 compositeFieldIndex++)
414 {
415 compositeFieldArray[compositeFieldIndex] = false;
416 }
417 }
418
419 uint32 compositeFieldIndex = fieldSelect->fieldnum - 1;
420 compositeFieldArray[compositeFieldIndex] = true;
421 }
422
423 for (uint32 fieldIndex = 0; fieldIndex < compositeFieldCount; fieldIndex++)
424 {
425 if (!compositeFieldArray[fieldIndex])
426 {
427 fullCompositeFieldList = false;
428 }
429 }
430
431 if (compositeFieldCount == 0)
432 {
433 fullCompositeFieldList = false;
434 }
435
436 return fullCompositeFieldList;
437 }
438
439
440 /*
441 * CompositeFieldRecursive recursively finds composite field in the query tree
442 * referred by given expression. If expression does not refer to a composite
443 * field, then it returns NULL.
444 *
445 * If expression is a field select we directly return composite field. If it is
446 * a column is referenced from a subquery, then we recursively check that subquery
447 * until we reach the source of that column, and find composite field. If this
448 * column is referenced from join range table entry, then we resolve which join
449 * column it refers and recursively use this column with the same query.
450 */
451 static FieldSelect *
CompositeFieldRecursive(Expr * expression,Query * query)452 CompositeFieldRecursive(Expr *expression, Query *query)
453 {
454 FieldSelect *compositeField = NULL;
455 List *rangetableList = query->rtable;
456 Var *candidateColumn = NULL;
457
458 if (IsA(expression, FieldSelect))
459 {
460 compositeField = (FieldSelect *) expression;
461 return compositeField;
462 }
463
464 if (IsA(expression, Var))
465 {
466 candidateColumn = (Var *) expression;
467 }
468 else
469 {
470 return NULL;
471 }
472
473 Index rangeTableEntryIndex = candidateColumn->varno - 1;
474 RangeTblEntry *rangeTableEntry = list_nth(rangetableList, rangeTableEntryIndex);
475
476 if (rangeTableEntry->rtekind == RTE_SUBQUERY)
477 {
478 Query *subquery = rangeTableEntry->subquery;
479 List *targetEntryList = subquery->targetList;
480 AttrNumber targetEntryIndex = candidateColumn->varattno - 1;
481 TargetEntry *subqueryTargetEntry = list_nth(targetEntryList, targetEntryIndex);
482
483 Expr *subqueryExpression = subqueryTargetEntry->expr;
484 compositeField = CompositeFieldRecursive(subqueryExpression, subquery);
485 }
486 else if (rangeTableEntry->rtekind == RTE_JOIN)
487 {
488 List *joinColumnList = rangeTableEntry->joinaliasvars;
489 AttrNumber joinColumnIndex = candidateColumn->varattno - 1;
490 Expr *joinColumn = list_nth(joinColumnList, joinColumnIndex);
491
492 compositeField = CompositeFieldRecursive(joinColumn, query);
493 }
494
495 return compositeField;
496 }
497
498
499 /*
500 * SubqueryEntryList finds the subquery nodes in the range table entry list, and
501 * builds a list of subquery range table entries from these subquery nodes. Range
502 * table entry list also includes subqueries which are pulled up. We don't want
503 * to add pulled up subqueries to list, so we walk over join tree indexes and
504 * check range table entries referenced in the join tree.
505 */
506 List *
SubqueryEntryList(Query * queryTree)507 SubqueryEntryList(Query *queryTree)
508 {
509 List *rangeTableList = queryTree->rtable;
510 List *subqueryEntryList = NIL;
511 List *joinTreeTableIndexList = NIL;
512 ListCell *joinTreeTableIndexCell = NULL;
513
514 /*
515 * Extract all range table indexes from the join tree. Note that here we
516 * only walk over range table entries at this level and do not recurse into
517 * subqueries.
518 */
519 ExtractRangeTableIndexWalker((Node *) queryTree->jointree, &joinTreeTableIndexList);
520 foreach(joinTreeTableIndexCell, joinTreeTableIndexList)
521 {
522 /*
523 * Join tree's range table index starts from 1 in the query tree. But,
524 * list indexes start from 0.
525 */
526 int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell);
527 int rangeTableListIndex = joinTreeTableIndex - 1;
528 RangeTblEntry *rangeTableEntry =
529 (RangeTblEntry *) list_nth(rangeTableList, rangeTableListIndex);
530
531 if (rangeTableEntry->rtekind == RTE_SUBQUERY)
532 {
533 subqueryEntryList = lappend(subqueryEntryList, rangeTableEntry);
534 }
535 }
536
537 return subqueryEntryList;
538 }
539
540
541 /*
542 * MultiNodeTree takes in a parsed query tree and uses that tree to construct a
543 * logical plan. This plan is based on multi-relational algebra. This function
544 * creates the logical plan in several steps.
545 *
546 * First, the function checks if there is a subquery. If there is a subquery
547 * it recursively creates nested multi trees. If this query has a subquery, the
548 * function does not create any join trees and jumps to last step.
549 *
550 * If there is no subquery, the function calculates the join order using tables
551 * in the query and join clauses between the tables. Second, the function
552 * starts building the logical plan from the bottom-up, and begins with the table
553 * and collect nodes. Third, the function builds the join tree using the join
554 * order information and table nodes.
555 *
556 * In the last step, the function adds the select, project, aggregate, sort,
557 * group, and limit nodes if they appear in the original query tree.
558 */
559 MultiNode *
MultiNodeTree(Query * queryTree)560 MultiNodeTree(Query *queryTree)
561 {
562 List *rangeTableList = queryTree->rtable;
563 List *targetEntryList = queryTree->targetList;
564 List *joinClauseList = NIL;
565 List *joinOrderList = NIL;
566 List *tableEntryList = NIL;
567 List *tableNodeList = NIL;
568 List *collectTableList = NIL;
569 MultiNode *joinTreeNode = NULL;
570 MultiNode *currentTopNode = NULL;
571
572 /* verify we can perform distributed planning on this query */
573 DeferredErrorMessage *unsupportedQueryError = DeferErrorIfQueryNotSupported(
574 queryTree);
575 if (unsupportedQueryError != NULL)
576 {
577 RaiseDeferredError(unsupportedQueryError, ERROR);
578 }
579
580 /* extract where clause qualifiers and verify we can plan for them */
581 List *whereClauseList = WhereClauseList(queryTree->jointree);
582 unsupportedQueryError = DeferErrorIfUnsupportedClause(whereClauseList);
583 if (unsupportedQueryError)
584 {
585 RaiseDeferredErrorInternal(unsupportedQueryError, ERROR);
586 }
587
588 /*
589 * If we have a subquery, build a multi table node for the subquery and
590 * add a collect node on top of the multi table node.
591 */
592 List *subqueryEntryList = SubqueryEntryList(queryTree);
593 if (subqueryEntryList != NIL)
594 {
595 MultiCollect *subqueryCollectNode = CitusMakeNode(MultiCollect);
596 ListCell *columnCell = NULL;
597
598 /* we only support single subquery in the entry list */
599 Assert(list_length(subqueryEntryList) == 1);
600
601 RangeTblEntry *subqueryRangeTableEntry = (RangeTblEntry *) linitial(
602 subqueryEntryList);
603 Query *subqueryTree = subqueryRangeTableEntry->subquery;
604
605 /* ensure if subquery satisfies preconditions */
606 Assert(DeferErrorIfUnsupportedSubqueryRepartition(subqueryTree) == NULL);
607
608 MultiTable *subqueryNode = CitusMakeNode(MultiTable);
609 subqueryNode->relationId = SUBQUERY_RELATION_ID;
610 subqueryNode->rangeTableId = SUBQUERY_RANGE_TABLE_ID;
611 subqueryNode->partitionColumn = NULL;
612 subqueryNode->alias = NULL;
613 subqueryNode->referenceNames = NULL;
614
615 /*
616 * We disregard pulled subqueries. This changes order of range table list.
617 * We do not allow subquery joins, so we will have only one range table
618 * entry in range table list after dropping pulled subquery. For this
619 * reason, here we are updating columns in the most outer query for where
620 * clause list and target list accordingly.
621 */
622 Assert(list_length(subqueryEntryList) == 1);
623
624 List *whereClauseColumnList = pull_var_clause_default((Node *) whereClauseList);
625 List *targetListColumnList = pull_var_clause_default((Node *) targetEntryList);
626
627 List *columnList = list_concat(whereClauseColumnList, targetListColumnList);
628 foreach(columnCell, columnList)
629 {
630 Var *column = (Var *) lfirst(columnCell);
631 column->varno = 1;
632 }
633
634 /* recursively create child nested multitree */
635 MultiNode *subqueryExtendedNode = MultiNodeTree(subqueryTree);
636
637 SetChild((MultiUnaryNode *) subqueryCollectNode, (MultiNode *) subqueryNode);
638 SetChild((MultiUnaryNode *) subqueryNode, subqueryExtendedNode);
639
640 currentTopNode = (MultiNode *) subqueryCollectNode;
641 }
642 else
643 {
644 /*
645 * We calculate the join order using the list of tables in the query and
646 * the join clauses between them. Note that this function owns the table
647 * entry list's memory, and JoinOrderList() shallow copies the list's
648 * elements.
649 */
650 joinClauseList = JoinClauseList(whereClauseList);
651 tableEntryList = UsedTableEntryList(queryTree);
652
653 /* build the list of multi table nodes */
654 tableNodeList = MultiTableNodeList(tableEntryList, rangeTableList);
655
656 /* add collect nodes on top of the multi table nodes */
657 collectTableList = AddMultiCollectNodes(tableNodeList);
658
659 /* find best join order for commutative inner joins */
660 joinOrderList = JoinOrderList(tableEntryList, joinClauseList);
661
662 /* build join tree using the join order and collected tables */
663 joinTreeNode = MultiJoinTree(joinOrderList, collectTableList, joinClauseList);
664
665 currentTopNode = joinTreeNode;
666 }
667
668 Assert(currentTopNode != NULL);
669
670 /* build select node if the query has selection criteria */
671 MultiSelect *selectNode = MultiSelectNode(whereClauseList);
672 if (selectNode != NULL)
673 {
674 SetChild((MultiUnaryNode *) selectNode, currentTopNode);
675 currentTopNode = (MultiNode *) selectNode;
676 }
677
678 /* build project node for the columns to project */
679 MultiProject *projectNode = MultiProjectNode(targetEntryList);
680 SetChild((MultiUnaryNode *) projectNode, currentTopNode);
681 currentTopNode = (MultiNode *) projectNode;
682
683 /*
684 * We build the extended operator node to capture aggregate functions, group
685 * clauses, sort clauses, limit/offset clauses, and expressions. We need to
686 * distinguish between aggregates and expressions; and we address this later
687 * in the logical optimizer.
688 */
689 MultiExtendedOp *extendedOpNode = MultiExtendedOpNode(queryTree, queryTree);
690 SetChild((MultiUnaryNode *) extendedOpNode, currentTopNode);
691 currentTopNode = (MultiNode *) extendedOpNode;
692
693 return currentTopNode;
694 }
695
696
697 /*
698 * ContainsReadIntermediateResultFunction determines whether an expresion tree contains
699 * a call to the read_intermediate_result function.
700 */
701 bool
ContainsReadIntermediateResultFunction(Node * node)702 ContainsReadIntermediateResultFunction(Node *node)
703 {
704 return FindNodeMatchingCheckFunction(node, IsReadIntermediateResultFunction);
705 }
706
707
708 /*
709 * ContainsReadIntermediateResultArrayFunction determines whether an expresion
710 * tree contains a call to the read_intermediate_results(result_ids, format)
711 * function.
712 */
713 bool
ContainsReadIntermediateResultArrayFunction(Node * node)714 ContainsReadIntermediateResultArrayFunction(Node *node)
715 {
716 return FindNodeMatchingCheckFunction(node, IsReadIntermediateResultArrayFunction);
717 }
718
719
720 /*
721 * IsReadIntermediateResultFunction determines whether a given node is a function call
722 * to the read_intermediate_result function.
723 */
724 static bool
IsReadIntermediateResultFunction(Node * node)725 IsReadIntermediateResultFunction(Node *node)
726 {
727 return IsFunctionWithOid(node, CitusReadIntermediateResultFuncId());
728 }
729
730
731 /*
732 * IsReadIntermediateResultArrayFunction determines whether a given node is a
733 * function call to the read_intermediate_results(result_ids, format) function.
734 */
735 static bool
IsReadIntermediateResultArrayFunction(Node * node)736 IsReadIntermediateResultArrayFunction(Node *node)
737 {
738 return IsFunctionWithOid(node, CitusReadIntermediateResultArrayFuncId());
739 }
740
741
742 /*
743 * IsCitusExtraDataContainerRelation determines whether a range table entry contains a
744 * call to the citus_extradata_container function.
745 */
746 bool
IsCitusExtraDataContainerRelation(RangeTblEntry * rte)747 IsCitusExtraDataContainerRelation(RangeTblEntry *rte)
748 {
749 if (rte->rtekind != RTE_FUNCTION || list_length(rte->functions) != 1)
750 {
751 /* avoid more expensive checks below for non-functions */
752 return false;
753 }
754
755 if (!CitusHasBeenLoaded() || !CheckCitusVersion(DEBUG5))
756 {
757 return false;
758 }
759
760 return FindNodeMatchingCheckFunction((Node *) rte->functions,
761 IsCitusExtraDataContainerFunc);
762 }
763
764
765 /*
766 * IsCitusExtraDataContainerFunc determines whether a given node is a function call
767 * to the citus_extradata_container function.
768 */
769 static bool
IsCitusExtraDataContainerFunc(Node * node)770 IsCitusExtraDataContainerFunc(Node *node)
771 {
772 return IsFunctionWithOid(node, CitusExtraDataContainerFuncId());
773 }
774
775
776 /*
777 * IsFunctionWithOid determines whether a given node is a function call
778 * to the read_intermediate_result function.
779 */
780 static bool
IsFunctionWithOid(Node * node,Oid funcOid)781 IsFunctionWithOid(Node *node, Oid funcOid)
782 {
783 if (IsA(node, FuncExpr))
784 {
785 FuncExpr *funcExpr = (FuncExpr *) node;
786
787 if (funcExpr->funcid == funcOid)
788 {
789 return true;
790 }
791 }
792
793 return false;
794 }
795
796
797 /*
798 * IsGroupingFunc returns whether node is a GroupingFunc.
799 */
800 static bool
IsGroupingFunc(Node * node)801 IsGroupingFunc(Node *node)
802 {
803 return IsA(node, GroupingFunc);
804 }
805
806
807 /*
808 * FindIntermediateResultIdIfExists extracts the id of the intermediate result
809 * if the given RTE contains a read_intermediate_results function, NULL otherwise
810 */
811 char *
FindIntermediateResultIdIfExists(RangeTblEntry * rte)812 FindIntermediateResultIdIfExists(RangeTblEntry *rte)
813 {
814 char *resultId = NULL;
815
816 Assert(rte->rtekind == RTE_FUNCTION);
817
818 List *functionList = rte->functions;
819 RangeTblFunction *rangeTblfunction = (RangeTblFunction *) linitial(functionList);
820 FuncExpr *funcExpr = (FuncExpr *) rangeTblfunction->funcexpr;
821
822 if (IsReadIntermediateResultFunction((Node *) funcExpr))
823 {
824 Const *resultIdConst = linitial(funcExpr->args);
825
826 if (!resultIdConst->constisnull)
827 {
828 resultId = TextDatumGetCString(resultIdConst->constvalue);
829 }
830 }
831
832 return resultId;
833 }
834
835
836 /*
837 * ErrorIfQueryNotSupported checks that we can perform distributed planning for
838 * the given query. The checks in this function will be removed as we support
839 * more functionality in our distributed planning.
840 */
841 DeferredErrorMessage *
DeferErrorIfQueryNotSupported(Query * queryTree)842 DeferErrorIfQueryNotSupported(Query *queryTree)
843 {
844 char *errorMessage = NULL;
845 bool preconditionsSatisfied = true;
846 const char *errorHint = NULL;
847 const char *joinHint = "Consider joining tables on partition column and have "
848 "equal filter on joining columns.";
849 const char *filterHint = "Consider using an equality filter on the distributed "
850 "table's partition column.";
851
852 if (queryTree->setOperations)
853 {
854 preconditionsSatisfied = false;
855 errorMessage = "could not run distributed query with UNION, INTERSECT, or "
856 "EXCEPT";
857 errorHint = filterHint;
858 }
859
860 if (queryTree->hasRecursive)
861 {
862 preconditionsSatisfied = false;
863 errorMessage = "could not run distributed query with RECURSIVE";
864 errorHint = filterHint;
865 }
866
867 if (queryTree->cteList)
868 {
869 preconditionsSatisfied = false;
870 errorMessage = "could not run distributed query with common table expressions";
871 errorHint = filterHint;
872 }
873
874 if (queryTree->hasForUpdate)
875 {
876 preconditionsSatisfied = false;
877 errorMessage = "could not run distributed query with FOR UPDATE/SHARE commands";
878 errorHint = filterHint;
879 }
880
881 if (queryTree->groupingSets)
882 {
883 preconditionsSatisfied = false;
884 errorMessage = "could not run distributed query with GROUPING SETS, CUBE, "
885 "or ROLLUP";
886 errorHint = filterHint;
887 }
888
889 if (FindNodeMatchingCheckFunction((Node *) queryTree, IsGroupingFunc))
890 {
891 preconditionsSatisfied = false;
892 errorMessage = "could not run distributed query with GROUPING";
893 errorHint = filterHint;
894 }
895
896 bool hasTablesample = HasTablesample(queryTree);
897 if (hasTablesample)
898 {
899 preconditionsSatisfied = false;
900 errorMessage = "could not run distributed query which use TABLESAMPLE";
901 errorHint = filterHint;
902 }
903
904 bool hasUnsupportedJoin = HasUnsupportedJoinWalker((Node *) queryTree->jointree,
905 NULL);
906 if (hasUnsupportedJoin)
907 {
908 preconditionsSatisfied = false;
909 errorMessage = "could not run distributed query with join types other than "
910 "INNER or OUTER JOINS";
911 errorHint = joinHint;
912 }
913
914 bool hasComplexRangeTableType = HasComplexRangeTableType(queryTree);
915 if (hasComplexRangeTableType)
916 {
917 preconditionsSatisfied = false;
918 errorMessage = "could not run distributed query with complex table expressions";
919 errorHint = filterHint;
920 }
921
922 if (FindNodeMatchingCheckFunction((Node *) queryTree->limitCount, IsNodeSubquery))
923 {
924 preconditionsSatisfied = false;
925 errorMessage = "subquery in LIMIT is not supported in multi-shard queries";
926 }
927
928 if (FindNodeMatchingCheckFunction((Node *) queryTree->limitOffset, IsNodeSubquery))
929 {
930 preconditionsSatisfied = false;
931 errorMessage = "subquery in OFFSET is not supported in multi-shard queries";
932 }
933
934 RTEListProperties *queryRteListProperties = GetRTEListPropertiesForQuery(queryTree);
935 if (queryRteListProperties->hasCitusLocalTable ||
936 queryRteListProperties->hasPostgresLocalTable)
937 {
938 preconditionsSatisfied = false;
939 errorMessage = "direct joins between distributed and local tables are "
940 "not supported";
941 errorHint = LOCAL_TABLE_SUBQUERY_CTE_HINT;
942 }
943
944 /* finally check and error out if not satisfied */
945 if (!preconditionsSatisfied)
946 {
947 bool showHint = ErrorHintRequired(errorHint, queryTree);
948 return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
949 errorMessage, NULL,
950 showHint ? errorHint : NULL);
951 }
952
953 return NULL;
954 }
955
956
957 /* HasTablesample returns tree if the query contains tablesample */
958 static bool
HasTablesample(Query * queryTree)959 HasTablesample(Query *queryTree)
960 {
961 List *rangeTableList = queryTree->rtable;
962 ListCell *rangeTableEntryCell = NULL;
963 bool hasTablesample = false;
964
965 foreach(rangeTableEntryCell, rangeTableList)
966 {
967 RangeTblEntry *rangeTableEntry = lfirst(rangeTableEntryCell);
968 if (rangeTableEntry->tablesample)
969 {
970 hasTablesample = true;
971 break;
972 }
973 }
974
975 return hasTablesample;
976 }
977
978
979 /*
980 * HasUnsupportedJoinWalker returns tree if the query contains an unsupported
981 * join type. We currently support inner, left, right, full and anti joins.
982 * Semi joins are not supported. A full description of these join types is
983 * included in nodes/nodes.h.
984 */
985 static bool
HasUnsupportedJoinWalker(Node * node,void * context)986 HasUnsupportedJoinWalker(Node *node, void *context)
987 {
988 bool hasUnsupportedJoin = false;
989
990 if (node == NULL)
991 {
992 return false;
993 }
994
995 if (IsA(node, JoinExpr))
996 {
997 JoinExpr *joinExpr = (JoinExpr *) node;
998 JoinType joinType = joinExpr->jointype;
999 bool outerJoin = IS_OUTER_JOIN(joinType);
1000 if (!outerJoin && joinType != JOIN_INNER && joinType != JOIN_SEMI)
1001 {
1002 hasUnsupportedJoin = true;
1003 }
1004 }
1005
1006 if (!hasUnsupportedJoin)
1007 {
1008 hasUnsupportedJoin = expression_tree_walker(node, HasUnsupportedJoinWalker,
1009 NULL);
1010 }
1011
1012 return hasUnsupportedJoin;
1013 }
1014
1015
1016 /*
1017 * ErrorHintRequired returns true if error hint shold be displayed with the
1018 * query error message. Error hint is valid only for queries involving reference
1019 * and hash partitioned tables. If more than one hash distributed table is
1020 * present we display the hint only if the tables are colocated. If the query
1021 * only has reference table(s), then it is handled by router planner.
1022 */
1023 static bool
ErrorHintRequired(const char * errorHint,Query * queryTree)1024 ErrorHintRequired(const char *errorHint, Query *queryTree)
1025 {
1026 List *distributedRelationIdList = DistributedRelationIdList(queryTree);
1027 ListCell *relationIdCell = NULL;
1028 List *colocationIdList = NIL;
1029
1030 if (errorHint == NULL)
1031 {
1032 return false;
1033 }
1034
1035 foreach(relationIdCell, distributedRelationIdList)
1036 {
1037 Oid relationId = lfirst_oid(relationIdCell);
1038 if (IsCitusTableType(relationId, REFERENCE_TABLE))
1039 {
1040 continue;
1041 }
1042 else if (IsCitusTableType(relationId, HASH_DISTRIBUTED))
1043 {
1044 int colocationId = TableColocationId(relationId);
1045 colocationIdList = list_append_unique_int(colocationIdList, colocationId);
1046 }
1047 else
1048 {
1049 return false;
1050 }
1051 }
1052
1053 /* do not display the hint if there are more than one colocation group */
1054 if (list_length(colocationIdList) > 1)
1055 {
1056 return false;
1057 }
1058
1059 return true;
1060 }
1061
1062
1063 /*
1064 * DeferErrorIfUnsupportedSubqueryRepartition checks that we can perform distributed planning for
1065 * the given subquery. If not, a deferred error is returned. The function recursively
1066 * does this check to all lower levels of the subquery.
1067 */
1068 DeferredErrorMessage *
DeferErrorIfUnsupportedSubqueryRepartition(Query * subqueryTree)1069 DeferErrorIfUnsupportedSubqueryRepartition(Query *subqueryTree)
1070 {
1071 char *errorDetail = NULL;
1072 bool preconditionsSatisfied = true;
1073 List *joinTreeTableIndexList = NIL;
1074
1075 if (!subqueryTree->hasAggs)
1076 {
1077 preconditionsSatisfied = false;
1078 errorDetail = "Subqueries without aggregates are not supported yet";
1079 }
1080
1081 if (subqueryTree->groupClause == NIL)
1082 {
1083 preconditionsSatisfied = false;
1084 errorDetail = "Subqueries without group by clause are not supported yet";
1085 }
1086
1087 if (subqueryTree->sortClause != NULL)
1088 {
1089 preconditionsSatisfied = false;
1090 errorDetail = "Subqueries with order by clause are not supported yet";
1091 }
1092
1093 if (subqueryTree->limitCount != NULL)
1094 {
1095 preconditionsSatisfied = false;
1096 errorDetail = "Subqueries with limit are not supported yet";
1097 }
1098
1099 if (subqueryTree->limitOffset != NULL)
1100 {
1101 preconditionsSatisfied = false;
1102 errorDetail = "Subqueries with offset are not supported yet";
1103 }
1104
1105 if (subqueryTree->hasSubLinks)
1106 {
1107 preconditionsSatisfied = false;
1108 errorDetail = "Subqueries other than from-clause subqueries are unsupported";
1109 }
1110
1111 /* finally check and return error if conditions are not satisfied */
1112 if (!preconditionsSatisfied)
1113 {
1114 return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
1115 "cannot perform distributed planning on this query",
1116 errorDetail, NULL);
1117 }
1118
1119 /*
1120 * Extract all range table indexes from the join tree. Note that sub-queries
1121 * that get pulled up by PostgreSQL don't appear in this join tree.
1122 */
1123 ExtractRangeTableIndexWalker((Node *) subqueryTree->jointree,
1124 &joinTreeTableIndexList);
1125 Assert(list_length(joinTreeTableIndexList) == 1);
1126
1127 /* continue with the inner subquery */
1128 int rangeTableIndex = linitial_int(joinTreeTableIndexList);
1129 RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableIndex, subqueryTree->rtable);
1130 if (rangeTableEntry->rtekind == RTE_RELATION)
1131 {
1132 return NULL;
1133 }
1134
1135 Assert(rangeTableEntry->rtekind == RTE_SUBQUERY);
1136 Query *innerSubquery = rangeTableEntry->subquery;
1137
1138 /* recursively continue to the inner subqueries */
1139 return DeferErrorIfUnsupportedSubqueryRepartition(innerSubquery);
1140 }
1141
1142
1143 /*
1144 * HasComplexRangeTableType checks if the given query tree contains any complex
1145 * range table types. For this, the function walks over all range tables in the
1146 * join tree, and checks if they correspond to simple relations or subqueries.
1147 * If they don't, the function assumes the query has complex range tables.
1148 */
1149 static bool
HasComplexRangeTableType(Query * queryTree)1150 HasComplexRangeTableType(Query *queryTree)
1151 {
1152 List *rangeTableList = queryTree->rtable;
1153 List *joinTreeTableIndexList = NIL;
1154 ListCell *joinTreeTableIndexCell = NULL;
1155 bool hasComplexRangeTableType = false;
1156
1157 /*
1158 * Extract all range table indexes from the join tree. Note that sub-queries
1159 * that get pulled up by PostgreSQL don't appear in this join tree.
1160 */
1161 ExtractRangeTableIndexWalker((Node *) queryTree->jointree, &joinTreeTableIndexList);
1162 foreach(joinTreeTableIndexCell, joinTreeTableIndexList)
1163 {
1164 /*
1165 * Join tree's range table index starts from 1 in the query tree. But,
1166 * list indexes start from 0.
1167 */
1168 int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell);
1169 int rangeTableListIndex = joinTreeTableIndex - 1;
1170
1171 RangeTblEntry *rangeTableEntry =
1172 (RangeTblEntry *) list_nth(rangeTableList, rangeTableListIndex);
1173
1174 /*
1175 * Check if the range table in the join tree is a simple relation or a
1176 * subquery or a function. Note that RTE_FUNCTIONs are handled via (sub)query
1177 * pushdown.
1178 */
1179 if (rangeTableEntry->rtekind != RTE_RELATION &&
1180 rangeTableEntry->rtekind != RTE_SUBQUERY &&
1181 rangeTableEntry->rtekind != RTE_FUNCTION &&
1182 rangeTableEntry->rtekind != RTE_VALUES)
1183 {
1184 hasComplexRangeTableType = true;
1185 }
1186
1187 /*
1188 * Check if the subquery range table entry includes children inheritance.
1189 *
1190 * Note that PostgreSQL flattens out simple union all queries into an
1191 * append relation, sets "inh" field of RangeTblEntry to true and deletes
1192 * set operations. Here we check this for subqueries.
1193 */
1194 if (rangeTableEntry->rtekind == RTE_SUBQUERY && rangeTableEntry->inh)
1195 {
1196 hasComplexRangeTableType = true;
1197 }
1198 }
1199
1200 return hasComplexRangeTableType;
1201 }
1202
1203
1204 /*
1205 * WhereClauseList walks over the FROM expression in the query tree, and builds
1206 * a list of all clauses from the expression tree. The function checks for both
1207 * implicitly and explicitly defined clauses, but only selects INNER join
1208 * explicit clauses, and skips any outer-join clauses. Explicit clauses are
1209 * expressed as "SELECT ... FROM R1 INNER JOIN R2 ON R1.A = R2.A". Implicit
1210 * joins differ in that they live in the WHERE clause, and are expressed as
1211 * "SELECT ... FROM ... WHERE R1.a = R2.a".
1212 */
1213 List *
WhereClauseList(FromExpr * fromExpr)1214 WhereClauseList(FromExpr *fromExpr)
1215 {
1216 FromExpr *fromExprCopy = copyObject(fromExpr);
1217 QualifierWalkerContext *walkerContext = palloc0(sizeof(QualifierWalkerContext));
1218
1219 ExtractFromExpressionWalker((Node *) fromExprCopy, walkerContext);
1220 List *whereClauseList = walkerContext->baseQualifierList;
1221
1222 return whereClauseList;
1223 }
1224
1225
1226 /*
1227 * QualifierList walks over the FROM expression in the query tree, and builds
1228 * a list of all qualifiers from the expression tree. The function checks for
1229 * both implicitly and explicitly defined qualifiers. Note that this function
1230 * is very similar to WhereClauseList(), but QualifierList() also includes
1231 * outer-join clauses.
1232 */
1233 List *
QualifierList(FromExpr * fromExpr)1234 QualifierList(FromExpr *fromExpr)
1235 {
1236 FromExpr *fromExprCopy = copyObject(fromExpr);
1237 QualifierWalkerContext *walkerContext = palloc0(sizeof(QualifierWalkerContext));
1238 List *qualifierList = NIL;
1239
1240 ExtractFromExpressionWalker((Node *) fromExprCopy, walkerContext);
1241 qualifierList = list_concat(qualifierList, walkerContext->baseQualifierList);
1242 qualifierList = list_concat(qualifierList, walkerContext->outerJoinQualifierList);
1243
1244 return qualifierList;
1245 }
1246
1247
1248 /*
1249 * DeferErrorIfUnsupportedClause walks over the given list of clauses, and
1250 * checks that we can recognize all the clauses. This function ensures that
1251 * we do not drop an unsupported clause type on the floor, and thus prevents
1252 * erroneous results.
1253 *
1254 * Returns a deferred error, caller is responsible for raising the error.
1255 */
1256 DeferredErrorMessage *
DeferErrorIfUnsupportedClause(List * clauseList)1257 DeferErrorIfUnsupportedClause(List *clauseList)
1258 {
1259 ListCell *clauseCell = NULL;
1260 foreach(clauseCell, clauseList)
1261 {
1262 Node *clause = (Node *) lfirst(clauseCell);
1263
1264 if (!(IsSelectClause(clause) || IsJoinClause(clause) || or_clause(clause)))
1265 {
1266 return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
1267 "unsupported clause type", NULL, NULL);
1268 }
1269 }
1270 return NULL;
1271 }
1272
1273
1274 /*
1275 * JoinClauseList finds the join clauses from the given where clause expression
1276 * list, and returns them. The function does not iterate into nested OR clauses
1277 * and relies on find_duplicate_ors() in the optimizer to pull up factorizable
1278 * OR clauses.
1279 */
1280 List *
JoinClauseList(List * whereClauseList)1281 JoinClauseList(List *whereClauseList)
1282 {
1283 List *joinClauseList = NIL;
1284 ListCell *whereClauseCell = NULL;
1285
1286 foreach(whereClauseCell, whereClauseList)
1287 {
1288 Node *whereClause = (Node *) lfirst(whereClauseCell);
1289 if (IsJoinClause(whereClause))
1290 {
1291 joinClauseList = lappend(joinClauseList, whereClause);
1292 }
1293 }
1294
1295 return joinClauseList;
1296 }
1297
1298
1299 /*
1300 * ExtractFromExpressionWalker walks over a FROM expression, and finds all
1301 * implicit and explicit qualifiers in the expression. The function looks at
1302 * join and from expression nodes to find qualifiers, and returns these
1303 * qualifiers.
1304 *
1305 * Note that we don't want outer join clauses in regular outer join planning,
1306 * but we need outer join clauses in subquery pushdown prerequisite checks.
1307 * Therefore, outer join qualifiers are returned in a different list than other
1308 * qualifiers inside the given walker context. For this reason, we return two
1309 * qualifier lists.
1310 *
1311 * Note that we check if the qualifier node in join and from expression nodes
1312 * is a list node. If it is not a list node which is the case for subqueries,
1313 * then we run eval_const_expressions(), canonicalize_qual() and make_ands_implicit()
1314 * on the qualifier node and get a list of flattened implicitly AND'ed qualifier
1315 * list. Actually in the planer phase of PostgreSQL these functions also run on
1316 * subqueries but differently from the outermost query, they are run on a copy
1317 * of parse tree and changes do not get persisted as modifications to the original
1318 * query tree.
1319 *
1320 * Also this function adds SubLinks to the baseQualifierList when they appear on
1321 * the query's WHERE clause. The callers of the function should consider processing
1322 * Sublinks as well.
1323 */
1324 static bool
ExtractFromExpressionWalker(Node * node,QualifierWalkerContext * walkerContext)1325 ExtractFromExpressionWalker(Node *node, QualifierWalkerContext *walkerContext)
1326 {
1327 if (node == NULL)
1328 {
1329 return false;
1330 }
1331
1332 /*
1333 * Get qualifier lists of join and from expression nodes. Note that in the
1334 * case of subqueries, PostgreSQL can skip simplifying, flattening and
1335 * making ANDs implicit. If qualifiers node is not a list, then we run these
1336 * preprocess routines on qualifiers node.
1337 */
1338 if (IsA(node, JoinExpr))
1339 {
1340 List *joinQualifierList = NIL;
1341 JoinExpr *joinExpression = (JoinExpr *) node;
1342 Node *joinQualifiersNode = joinExpression->quals;
1343 JoinType joinType = joinExpression->jointype;
1344
1345 if (joinQualifiersNode != NULL)
1346 {
1347 if (IsA(joinQualifiersNode, List))
1348 {
1349 joinQualifierList = (List *) joinQualifiersNode;
1350 }
1351 else
1352 {
1353 /* this part of code only run for subqueries */
1354 Node *joinClause = eval_const_expressions(NULL, joinQualifiersNode);
1355 joinClause = (Node *) canonicalize_qual((Expr *) joinClause, false);
1356 joinQualifierList = make_ands_implicit((Expr *) joinClause);
1357 }
1358 }
1359
1360 /* return outer join clauses in a separate list */
1361 if (joinType == JOIN_INNER || joinType == JOIN_SEMI)
1362 {
1363 walkerContext->baseQualifierList =
1364 list_concat(walkerContext->baseQualifierList, joinQualifierList);
1365 }
1366 else if (IS_OUTER_JOIN(joinType))
1367 {
1368 walkerContext->outerJoinQualifierList =
1369 list_concat(walkerContext->outerJoinQualifierList, joinQualifierList);
1370 }
1371 }
1372 else if (IsA(node, FromExpr))
1373 {
1374 List *fromQualifierList = NIL;
1375 FromExpr *fromExpression = (FromExpr *) node;
1376 Node *fromQualifiersNode = fromExpression->quals;
1377
1378 if (fromQualifiersNode != NULL)
1379 {
1380 if (IsA(fromQualifiersNode, List))
1381 {
1382 fromQualifierList = (List *) fromQualifiersNode;
1383 }
1384 else
1385 {
1386 /* this part of code only run for subqueries */
1387 Node *fromClause = eval_const_expressions(NULL, fromQualifiersNode);
1388 fromClause = (Node *) canonicalize_qual((Expr *) fromClause, false);
1389 fromQualifierList = make_ands_implicit((Expr *) fromClause);
1390 }
1391
1392 walkerContext->baseQualifierList =
1393 list_concat(walkerContext->baseQualifierList, fromQualifierList);
1394 }
1395 }
1396
1397 bool walkerResult = expression_tree_walker(node, ExtractFromExpressionWalker,
1398 (void *) walkerContext);
1399
1400 return walkerResult;
1401 }
1402
1403
1404 /*
1405 * IsJoinClause determines if the given node is a join clause according to our
1406 * criteria. Our criteria defines a join clause as an equi join operator between
1407 * two columns that belong to two different tables.
1408 */
1409 bool
IsJoinClause(Node * clause)1410 IsJoinClause(Node *clause)
1411 {
1412 Var *var = NULL;
1413
1414 /*
1415 * take all column references from the clause, if we find 2 column references from a
1416 * different relation we assume this is a join clause
1417 */
1418 List *varList = pull_var_clause_default(clause);
1419 if (list_length(varList) <= 0)
1420 {
1421 /* no column references in query, not describing a join */
1422 return false;
1423 }
1424 Var *initialVar = castNode(Var, linitial(varList));
1425
1426 foreach_ptr(var, varList)
1427 {
1428 if (var->varno != initialVar->varno)
1429 {
1430 /*
1431 * this column reference comes from a different relation, hence describing a
1432 * join
1433 */
1434 return true;
1435 }
1436 }
1437
1438 /* all column references were to the same relation, no join */
1439 return false;
1440 }
1441
1442
1443 /*
1444 * TableEntryList finds the regular relation nodes in the range table entry
1445 * list, and builds a list of table entries from these regular relation nodes.
1446 */
1447 List *
TableEntryList(List * rangeTableList)1448 TableEntryList(List *rangeTableList)
1449 {
1450 List *tableEntryList = NIL;
1451 ListCell *rangeTableCell = NULL;
1452 uint32 tableId = 1; /* range table indices start at 1 */
1453
1454 foreach(rangeTableCell, rangeTableList)
1455 {
1456 RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
1457
1458 if (rangeTableEntry->rtekind == RTE_RELATION)
1459 {
1460 TableEntry *tableEntry = (TableEntry *) palloc0(sizeof(TableEntry));
1461 tableEntry->relationId = rangeTableEntry->relid;
1462 tableEntry->rangeTableId = tableId;
1463
1464 tableEntryList = lappend(tableEntryList, tableEntry);
1465 }
1466
1467 /*
1468 * Increment tableId regardless so that table entry's tableId remains
1469 * congruent with column's range table reference (varno).
1470 */
1471 tableId++;
1472 }
1473
1474 return tableEntryList;
1475 }
1476
1477
1478 /*
1479 * UsedTableEntryList returns list of relation range table entries
1480 * that are referenced within the query. Unused entries due to query
1481 * flattening or re-rewriting are ignored.
1482 */
1483 List *
UsedTableEntryList(Query * query)1484 UsedTableEntryList(Query *query)
1485 {
1486 List *tableEntryList = NIL;
1487 List *rangeTableList = query->rtable;
1488 List *joinTreeTableIndexList = NIL;
1489 ListCell *joinTreeTableIndexCell = NULL;
1490
1491 ExtractRangeTableIndexWalker((Node *) query->jointree, &joinTreeTableIndexList);
1492 foreach(joinTreeTableIndexCell, joinTreeTableIndexList)
1493 {
1494 int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell);
1495 RangeTblEntry *rangeTableEntry = rt_fetch(joinTreeTableIndex, rangeTableList);
1496 if (rangeTableEntry->rtekind == RTE_RELATION)
1497 {
1498 TableEntry *tableEntry = (TableEntry *) palloc0(sizeof(TableEntry));
1499 tableEntry->relationId = rangeTableEntry->relid;
1500 tableEntry->rangeTableId = joinTreeTableIndex;
1501
1502 tableEntryList = lappend(tableEntryList, tableEntry);
1503 }
1504 }
1505
1506 return tableEntryList;
1507 }
1508
1509
1510 /*
1511 * MultiTableNodeList builds a list of MultiTable nodes from the given table
1512 * entry list. A multi table node represents one entry from the range table
1513 * list. These entries may belong to the same physical relation in the case of
1514 * self-joins.
1515 */
1516 static List *
MultiTableNodeList(List * tableEntryList,List * rangeTableList)1517 MultiTableNodeList(List *tableEntryList, List *rangeTableList)
1518 {
1519 List *tableNodeList = NIL;
1520 ListCell *tableEntryCell = NULL;
1521
1522 foreach(tableEntryCell, tableEntryList)
1523 {
1524 TableEntry *tableEntry = (TableEntry *) lfirst(tableEntryCell);
1525 Oid relationId = tableEntry->relationId;
1526 uint32 rangeTableId = tableEntry->rangeTableId;
1527 Var *partitionColumn = PartitionColumn(relationId, rangeTableId);
1528 RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList);
1529
1530 MultiTable *tableNode = CitusMakeNode(MultiTable);
1531 tableNode->subquery = NULL;
1532 tableNode->relationId = relationId;
1533 tableNode->rangeTableId = rangeTableId;
1534 tableNode->partitionColumn = partitionColumn;
1535 tableNode->alias = rangeTableEntry->alias;
1536 tableNode->referenceNames = rangeTableEntry->eref;
1537 tableNode->includePartitions = GetOriginalInh(rangeTableEntry);
1538
1539 tableNodeList = lappend(tableNodeList, tableNode);
1540 }
1541
1542 return tableNodeList;
1543 }
1544
1545
1546 /* Adds a MultiCollect node on top of each MultiTable node in the given list. */
1547 static List *
AddMultiCollectNodes(List * tableNodeList)1548 AddMultiCollectNodes(List *tableNodeList)
1549 {
1550 List *collectTableList = NIL;
1551 ListCell *tableNodeCell = NULL;
1552
1553 foreach(tableNodeCell, tableNodeList)
1554 {
1555 MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
1556
1557 MultiCollect *collectNode = CitusMakeNode(MultiCollect);
1558 SetChild((MultiUnaryNode *) collectNode, (MultiNode *) tableNode);
1559
1560 collectTableList = lappend(collectTableList, collectNode);
1561 }
1562
1563 return collectTableList;
1564 }
1565
1566
1567 /*
1568 * MultiJoinTree takes in the join order information and the list of tables, and
1569 * builds a join tree by applying the corresponding join rules. The function
1570 * builds a left deep tree, as expressed by the join order list.
1571 *
1572 * The function starts by setting the first table as the top node in the join
1573 * tree. Then, the function iterates over the list of tables, and builds a new
1574 * join node between the top of the join tree and the next table in the list.
1575 * At each iteration, the function sets the top of the join tree to the newly
1576 * built list. This results in a left deep join tree, and the function returns
1577 * this tree after every table in the list has been joined.
1578 */
1579 static MultiNode *
MultiJoinTree(List * joinOrderList,List * collectTableList,List * joinWhereClauseList)1580 MultiJoinTree(List *joinOrderList, List *collectTableList, List *joinWhereClauseList)
1581 {
1582 MultiNode *currentTopNode = NULL;
1583 ListCell *joinOrderCell = NULL;
1584 bool firstJoinNode = true;
1585
1586 foreach(joinOrderCell, joinOrderList)
1587 {
1588 JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderCell);
1589 uint32 joinTableId = joinOrderNode->tableEntry->rangeTableId;
1590 MultiCollect *collectNode = CollectNodeForTable(collectTableList, joinTableId);
1591
1592 if (firstJoinNode)
1593 {
1594 currentTopNode = (MultiNode *) collectNode;
1595 firstJoinNode = false;
1596 }
1597 else
1598 {
1599 JoinRuleType joinRuleType = joinOrderNode->joinRuleType;
1600 JoinType joinType = joinOrderNode->joinType;
1601 List *partitionColumnList = joinOrderNode->partitionColumnList;
1602 List *joinClauseList = joinOrderNode->joinClauseList;
1603
1604 /*
1605 * Build a join node between the top of our join tree and the next
1606 * table in the join order.
1607 */
1608 MultiNode *newJoinNode = ApplyJoinRule(currentTopNode,
1609 (MultiNode *) collectNode,
1610 joinRuleType, partitionColumnList,
1611 joinType,
1612 joinClauseList);
1613
1614 /* the new join node becomes the top of our join tree */
1615 currentTopNode = newJoinNode;
1616 }
1617 }
1618
1619 /* current top node points to the entire left deep join tree */
1620 return currentTopNode;
1621 }
1622
1623
1624 /*
1625 * CollectNodeForTable finds the MultiCollect node whose MultiTable node has the
1626 * given range table identifier. Note that this function expects each collect
1627 * node in the given list to have one table node as its child.
1628 */
1629 static MultiCollect *
CollectNodeForTable(List * collectTableList,uint32 rangeTableId)1630 CollectNodeForTable(List *collectTableList, uint32 rangeTableId)
1631 {
1632 MultiCollect *collectNodeForTable = NULL;
1633 ListCell *collectTableCell = NULL;
1634
1635 foreach(collectTableCell, collectTableList)
1636 {
1637 MultiCollect *collectNode = (MultiCollect *) lfirst(collectTableCell);
1638
1639 List *tableIdList = OutputTableIdList((MultiNode *) collectNode);
1640 uint32 tableId = (uint32) linitial_int(tableIdList);
1641 Assert(list_length(tableIdList) == 1);
1642
1643 if (tableId == rangeTableId)
1644 {
1645 collectNodeForTable = collectNode;
1646 break;
1647 }
1648 }
1649
1650 Assert(collectNodeForTable != NULL);
1651 return collectNodeForTable;
1652 }
1653
1654
1655 /*
1656 * MultiSelectNode extracts the select clauses from the given where clause list,
1657 * and builds a MultiSelect node from these clauses. If the expression tree does
1658 * not have any select clauses, the function return null.
1659 */
1660 static MultiSelect *
MultiSelectNode(List * whereClauseList)1661 MultiSelectNode(List *whereClauseList)
1662 {
1663 List *selectClauseList = NIL;
1664 MultiSelect *selectNode = NULL;
1665
1666 ListCell *whereClauseCell = NULL;
1667 foreach(whereClauseCell, whereClauseList)
1668 {
1669 Node *whereClause = (Node *) lfirst(whereClauseCell);
1670 if (IsSelectClause(whereClause))
1671 {
1672 selectClauseList = lappend(selectClauseList, whereClause);
1673 }
1674 }
1675
1676 if (list_length(selectClauseList) > 0)
1677 {
1678 selectNode = CitusMakeNode(MultiSelect);
1679 selectNode->selectClauseList = selectClauseList;
1680 }
1681
1682 return selectNode;
1683 }
1684
1685
1686 /*
1687 * IsSelectClause determines if the given node is a select clause according to
1688 * our criteria. Our criteria defines a select clause as an expression that has
1689 * zero or more columns belonging to only one table. The function assumes that
1690 * no sublinks exists in the clause.
1691 */
1692 static bool
IsSelectClause(Node * clause)1693 IsSelectClause(Node *clause)
1694 {
1695 ListCell *columnCell = NULL;
1696 bool isSelectClause = true;
1697
1698 /* extract columns from the clause */
1699 List *columnList = pull_var_clause_default(clause);
1700 if (list_length(columnList) == 0)
1701 {
1702 return true;
1703 }
1704
1705 /* get first column's tableId */
1706 Var *firstColumn = (Var *) linitial(columnList);
1707 Index firstColumnTableId = firstColumn->varno;
1708
1709 /* check if all columns are from the same table */
1710 foreach(columnCell, columnList)
1711 {
1712 Var *column = (Var *) lfirst(columnCell);
1713 if (column->varno != firstColumnTableId)
1714 {
1715 isSelectClause = false;
1716 }
1717 }
1718
1719 return isSelectClause;
1720 }
1721
1722
1723 /*
1724 * MultiProjectNode builds the project node using the target entry information
1725 * from the query tree. The project node only encapsulates projected columns,
1726 * and does not include aggregates, group clauses, or project expressions.
1727 */
1728 MultiProject *
MultiProjectNode(List * targetEntryList)1729 MultiProjectNode(List *targetEntryList)
1730 {
1731 List *uniqueColumnList = NIL;
1732 ListCell *columnCell = NULL;
1733
1734 /* extract the list of columns and remove any duplicates */
1735 List *columnList = pull_var_clause_default((Node *) targetEntryList);
1736 foreach(columnCell, columnList)
1737 {
1738 Var *column = (Var *) lfirst(columnCell);
1739
1740 uniqueColumnList = list_append_unique(uniqueColumnList, column);
1741 }
1742
1743 /* create project node with list of columns to project */
1744 MultiProject *projectNode = CitusMakeNode(MultiProject);
1745 projectNode->columnList = uniqueColumnList;
1746
1747 return projectNode;
1748 }
1749
1750
1751 /* Builds the extended operator node using fields from the given query tree. */
1752 MultiExtendedOp *
MultiExtendedOpNode(Query * queryTree,Query * originalQuery)1753 MultiExtendedOpNode(Query *queryTree, Query *originalQuery)
1754 {
1755 MultiExtendedOp *extendedOpNode = CitusMakeNode(MultiExtendedOp);
1756 extendedOpNode->targetList = queryTree->targetList;
1757 extendedOpNode->groupClauseList = queryTree->groupClause;
1758 extendedOpNode->sortClauseList = queryTree->sortClause;
1759 extendedOpNode->limitCount = queryTree->limitCount;
1760 extendedOpNode->limitOffset = queryTree->limitOffset;
1761 #if PG_VERSION_NUM >= PG_VERSION_13
1762 extendedOpNode->limitOption = queryTree->limitOption;
1763 #endif
1764 extendedOpNode->havingQual = queryTree->havingQual;
1765 extendedOpNode->distinctClause = queryTree->distinctClause;
1766 extendedOpNode->hasDistinctOn = queryTree->hasDistinctOn;
1767 extendedOpNode->hasWindowFuncs = queryTree->hasWindowFuncs;
1768 extendedOpNode->windowClause = queryTree->windowClause;
1769 extendedOpNode->onlyPushableWindowFunctions =
1770 !queryTree->hasWindowFuncs ||
1771 SafeToPushdownWindowFunction(originalQuery, NULL);
1772
1773 return extendedOpNode;
1774 }
1775
1776
1777 /* Helper function to return the parent node of the given node. */
1778 MultiNode *
ParentNode(MultiNode * multiNode)1779 ParentNode(MultiNode *multiNode)
1780 {
1781 MultiNode *parentNode = multiNode->parentNode;
1782 return parentNode;
1783 }
1784
1785
1786 /* Helper function to return the child of the given unary node. */
1787 MultiNode *
ChildNode(MultiUnaryNode * multiNode)1788 ChildNode(MultiUnaryNode *multiNode)
1789 {
1790 MultiNode *childNode = multiNode->childNode;
1791 return childNode;
1792 }
1793
1794
1795 /* Helper function to return the grand child of the given unary node. */
1796 MultiNode *
GrandChildNode(MultiUnaryNode * multiNode)1797 GrandChildNode(MultiUnaryNode *multiNode)
1798 {
1799 MultiNode *childNode = ChildNode(multiNode);
1800 MultiNode *grandChildNode = ChildNode((MultiUnaryNode *) childNode);
1801
1802 return grandChildNode;
1803 }
1804
1805
1806 /* Sets the given child node as a child of the given unary parent node. */
1807 void
SetChild(MultiUnaryNode * parent,MultiNode * child)1808 SetChild(MultiUnaryNode *parent, MultiNode *child)
1809 {
1810 parent->childNode = child;
1811 child->parentNode = (MultiNode *) parent;
1812 }
1813
1814
1815 /* Sets the given child node as a left child of the given parent node. */
1816 void
SetLeftChild(MultiBinaryNode * parent,MultiNode * leftChild)1817 SetLeftChild(MultiBinaryNode *parent, MultiNode *leftChild)
1818 {
1819 parent->leftChildNode = leftChild;
1820 leftChild->parentNode = (MultiNode *) parent;
1821 }
1822
1823
1824 /* Sets the given child node as a right child of the given parent node. */
1825 void
SetRightChild(MultiBinaryNode * parent,MultiNode * rightChild)1826 SetRightChild(MultiBinaryNode *parent, MultiNode *rightChild)
1827 {
1828 parent->rightChildNode = rightChild;
1829 rightChild->parentNode = (MultiNode *) parent;
1830 }
1831
1832
1833 /* Returns true if the given node is a unary operator. */
1834 bool
UnaryOperator(MultiNode * node)1835 UnaryOperator(MultiNode *node)
1836 {
1837 bool unaryOperator = false;
1838
1839 if (CitusIsA(node, MultiTreeRoot) || CitusIsA(node, MultiTable) ||
1840 CitusIsA(node, MultiCollect) || CitusIsA(node, MultiSelect) ||
1841 CitusIsA(node, MultiProject) || CitusIsA(node, MultiPartition) ||
1842 CitusIsA(node, MultiExtendedOp))
1843 {
1844 unaryOperator = true;
1845 }
1846
1847 return unaryOperator;
1848 }
1849
1850
1851 /* Returns true if the given node is a binary operator. */
1852 bool
BinaryOperator(MultiNode * node)1853 BinaryOperator(MultiNode *node)
1854 {
1855 bool binaryOperator = false;
1856
1857 if (CitusIsA(node, MultiJoin) || CitusIsA(node, MultiCartesianProduct))
1858 {
1859 binaryOperator = true;
1860 }
1861
1862 return binaryOperator;
1863 }
1864
1865
1866 /*
1867 * OutputTableIdList finds all table identifiers that are output by the given
1868 * multi node, and returns these identifiers in a new list.
1869 */
1870 List *
OutputTableIdList(MultiNode * multiNode)1871 OutputTableIdList(MultiNode *multiNode)
1872 {
1873 List *tableIdList = NIL;
1874 List *tableNodeList = FindNodesOfType(multiNode, T_MultiTable);
1875 ListCell *tableNodeCell = NULL;
1876
1877 foreach(tableNodeCell, tableNodeList)
1878 {
1879 MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
1880 int tableId = (int) tableNode->rangeTableId;
1881
1882 if (tableId != SUBQUERY_RANGE_TABLE_ID)
1883 {
1884 tableIdList = lappend_int(tableIdList, tableId);
1885 }
1886 }
1887
1888 return tableIdList;
1889 }
1890
1891
1892 /*
1893 * FindNodesOfType takes in a given logical plan tree, and recursively traverses
1894 * the tree in preorder. The function finds all nodes of requested type during
1895 * the traversal, and returns them in a list.
1896 */
1897 List *
FindNodesOfType(MultiNode * node,int type)1898 FindNodesOfType(MultiNode *node, int type)
1899 {
1900 List *nodeList = NIL;
1901
1902 /* terminal condition for recursion */
1903 if (node == NULL)
1904 {
1905 return NIL;
1906 }
1907
1908 /* current node has expected node type */
1909 int nodeType = CitusNodeTag(node);
1910 if (nodeType == type)
1911 {
1912 nodeList = lappend(nodeList, node);
1913 }
1914
1915 if (UnaryOperator(node))
1916 {
1917 MultiNode *childNode = ((MultiUnaryNode *) node)->childNode;
1918 List *childNodeList = FindNodesOfType(childNode, type);
1919
1920 nodeList = list_concat(nodeList, childNodeList);
1921 }
1922 else if (BinaryOperator(node))
1923 {
1924 MultiNode *leftChildNode = ((MultiBinaryNode *) node)->leftChildNode;
1925 MultiNode *rightChildNode = ((MultiBinaryNode *) node)->rightChildNode;
1926
1927 List *leftChildNodeList = FindNodesOfType(leftChildNode, type);
1928 List *rightChildNodeList = FindNodesOfType(rightChildNode, type);
1929
1930 nodeList = list_concat(nodeList, leftChildNodeList);
1931 nodeList = list_concat(nodeList, rightChildNodeList);
1932 }
1933
1934 return nodeList;
1935 }
1936
1937
1938 /*
1939 * pull_var_clause_default calls pull_var_clause with the most commonly used
1940 * arguments for distributed planning.
1941 */
1942 List *
pull_var_clause_default(Node * node)1943 pull_var_clause_default(Node *node)
1944 {
1945 /*
1946 * PVC_REJECT_PLACEHOLDERS is implicit if PVC_INCLUDE_PLACEHOLDERS
1947 * isn't specified.
1948 */
1949 List *columnList = pull_var_clause(node, PVC_RECURSE_AGGREGATES |
1950 PVC_RECURSE_WINDOWFUNCS);
1951
1952 return columnList;
1953 }
1954
1955
1956 /*
1957 * ApplyJoinRule finds the join rule application function that corresponds to
1958 * the given join rule, and calls this function to create a new join node that
1959 * joins the left and right nodes together.
1960 */
1961 static MultiNode *
ApplyJoinRule(MultiNode * leftNode,MultiNode * rightNode,JoinRuleType ruleType,List * partitionColumnList,JoinType joinType,List * joinClauseList)1962 ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType,
1963 List *partitionColumnList, JoinType joinType, List *joinClauseList)
1964 {
1965 List *leftTableIdList = OutputTableIdList(leftNode);
1966 List *rightTableIdList = OutputTableIdList(rightNode);
1967 int rightTableIdCount PG_USED_FOR_ASSERTS_ONLY = 0;
1968
1969 rightTableIdCount = list_length(rightTableIdList);
1970 Assert(rightTableIdCount == 1);
1971
1972 /* find applicable join clauses between the left and right data sources */
1973 uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
1974 List *applicableJoinClauses = ApplicableJoinClauses(leftTableIdList, rightTableId,
1975 joinClauseList);
1976
1977 /* call the join rule application function to create the new join node */
1978 RuleApplyFunction ruleApplyFunction = JoinRuleApplyFunction(ruleType);
1979 MultiNode *multiNode = (*ruleApplyFunction)(leftNode, rightNode, partitionColumnList,
1980 joinType, applicableJoinClauses);
1981
1982 if (joinType != JOIN_INNER && CitusIsA(multiNode, MultiJoin))
1983 {
1984 MultiJoin *joinNode = (MultiJoin *) multiNode;
1985
1986 /* preserve non-join clauses for OUTER joins */
1987 joinNode->joinClauseList = list_copy(joinClauseList);
1988 }
1989
1990 return multiNode;
1991 }
1992
1993
1994 /*
1995 * JoinRuleApplyFunction returns a function pointer for the rule application
1996 * function; this rule application function corresponds to the given rule type.
1997 * This function also initializes the rule application function array in a
1998 * static code block, if the array has not been initialized.
1999 */
2000 static RuleApplyFunction
JoinRuleApplyFunction(JoinRuleType ruleType)2001 JoinRuleApplyFunction(JoinRuleType ruleType)
2002 {
2003 static bool ruleApplyFunctionInitialized = false;
2004
2005 if (!ruleApplyFunctionInitialized)
2006 {
2007 RuleApplyFunctionArray[REFERENCE_JOIN] = &ApplyReferenceJoin;
2008 RuleApplyFunctionArray[LOCAL_PARTITION_JOIN] = &ApplyLocalJoin;
2009 RuleApplyFunctionArray[SINGLE_HASH_PARTITION_JOIN] =
2010 &ApplySingleHashPartitionJoin;
2011 RuleApplyFunctionArray[SINGLE_RANGE_PARTITION_JOIN] =
2012 &ApplySingleRangePartitionJoin;
2013 RuleApplyFunctionArray[DUAL_PARTITION_JOIN] = &ApplyDualPartitionJoin;
2014 RuleApplyFunctionArray[CARTESIAN_PRODUCT_REFERENCE_JOIN] =
2015 &ApplyCartesianProductReferenceJoin;
2016 RuleApplyFunctionArray[CARTESIAN_PRODUCT] = &ApplyCartesianProduct;
2017
2018 ruleApplyFunctionInitialized = true;
2019 }
2020
2021 RuleApplyFunction ruleApplyFunction = RuleApplyFunctionArray[ruleType];
2022 Assert(ruleApplyFunction != NULL);
2023
2024 return ruleApplyFunction;
2025 }
2026
2027
2028 /*
2029 * ApplyBroadcastJoin creates a new MultiJoin node that joins the left and the
2030 * right node. The new node uses the broadcast join rule to perform the join.
2031 */
2032 static MultiNode *
ApplyReferenceJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2033 ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
2034 List *partitionColumnList, JoinType joinType,
2035 List *applicableJoinClauses)
2036 {
2037 MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2038 joinNode->joinRuleType = REFERENCE_JOIN;
2039 joinNode->joinType = joinType;
2040 joinNode->joinClauseList = applicableJoinClauses;
2041
2042 SetLeftChild((MultiBinaryNode *) joinNode, leftNode);
2043 SetRightChild((MultiBinaryNode *) joinNode, rightNode);
2044
2045 return (MultiNode *) joinNode;
2046 }
2047
2048
2049 /*
2050 * ApplyCartesianProductReferenceJoin creates a new MultiJoin node that joins
2051 * the left and the right node. The new node uses the broadcast join rule to
2052 * perform the join.
2053 */
2054 static MultiNode *
ApplyCartesianProductReferenceJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2055 ApplyCartesianProductReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
2056 List *partitionColumnList, JoinType joinType,
2057 List *applicableJoinClauses)
2058 {
2059 MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2060 joinNode->joinRuleType = CARTESIAN_PRODUCT_REFERENCE_JOIN;
2061 joinNode->joinType = joinType;
2062 joinNode->joinClauseList = applicableJoinClauses;
2063
2064 SetLeftChild((MultiBinaryNode *) joinNode, leftNode);
2065 SetRightChild((MultiBinaryNode *) joinNode, rightNode);
2066
2067 return (MultiNode *) joinNode;
2068 }
2069
2070
2071 /*
2072 * ApplyLocalJoin creates a new MultiJoin node that joins the left and the right
2073 * node. The new node uses the local join rule to perform the join.
2074 */
2075 static MultiNode *
ApplyLocalJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2076 ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode,
2077 List *partitionColumnList, JoinType joinType,
2078 List *applicableJoinClauses)
2079 {
2080 MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2081 joinNode->joinRuleType = LOCAL_PARTITION_JOIN;
2082 joinNode->joinType = joinType;
2083 joinNode->joinClauseList = applicableJoinClauses;
2084
2085 SetLeftChild((MultiBinaryNode *) joinNode, leftNode);
2086 SetRightChild((MultiBinaryNode *) joinNode, rightNode);
2087
2088 return (MultiNode *) joinNode;
2089 }
2090
2091
2092 /*
2093 * ApplySingleRangePartitionJoin is a wrapper around ApplySinglePartitionJoin()
2094 * which sets the joinRuleType properly.
2095 */
2096 static MultiNode *
ApplySingleRangePartitionJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2097 ApplySingleRangePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
2098 List *partitionColumnList, JoinType joinType,
2099 List *applicableJoinClauses)
2100 {
2101 MultiJoin *joinNode =
2102 ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnList, joinType,
2103 applicableJoinClauses);
2104
2105 joinNode->joinRuleType = SINGLE_RANGE_PARTITION_JOIN;
2106
2107 return (MultiNode *) joinNode;
2108 }
2109
2110
2111 /*
2112 * ApplySingleHashPartitionJoin is a wrapper around ApplySinglePartitionJoin()
2113 * which sets the joinRuleType properly.
2114 */
2115 static MultiNode *
ApplySingleHashPartitionJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2116 ApplySingleHashPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
2117 List *partitionColumnList, JoinType joinType,
2118 List *applicableJoinClauses)
2119 {
2120 MultiJoin *joinNode =
2121 ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnList, joinType,
2122 applicableJoinClauses);
2123
2124 joinNode->joinRuleType = SINGLE_HASH_PARTITION_JOIN;
2125
2126 return (MultiNode *) joinNode;
2127 }
2128
2129
2130 /*
2131 * ApplySinglePartitionJoin creates a new MultiJoin node that joins the left and
2132 * right node. The function also adds a MultiPartition node on top of the node
2133 * (left or right) that is not partitioned on the join column.
2134 */
2135 static MultiJoin *
ApplySinglePartitionJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2136 ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
2137 List *partitionColumnList, JoinType joinType,
2138 List *applicableJoinClauses)
2139 {
2140 Var *partitionColumn = linitial(partitionColumnList);
2141 uint32 partitionTableId = partitionColumn->varno;
2142
2143 /* create all operator structures up front */
2144 MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2145 MultiCollect *collectNode = CitusMakeNode(MultiCollect);
2146 MultiPartition *partitionNode = CitusMakeNode(MultiPartition);
2147
2148 /*
2149 * We first find the appropriate join clause. Then, we compare the partition
2150 * column against the join clause's columns. If one of the columns matches,
2151 * we introduce a (re-)partition operator for the other column.
2152 */
2153 OpExpr *joinClause = SinglePartitionJoinClause(partitionColumnList,
2154 applicableJoinClauses);
2155 Assert(joinClause != NULL);
2156
2157 /* both are verified in SinglePartitionJoinClause to not be NULL, assert is to guard */
2158 Var *leftColumn = LeftColumnOrNULL(joinClause);
2159 Var *rightColumn = RightColumnOrNULL(joinClause);
2160
2161 Assert(leftColumn != NULL);
2162 Assert(rightColumn != NULL);
2163
2164 if (equal(partitionColumn, leftColumn))
2165 {
2166 partitionNode->partitionColumn = rightColumn;
2167 partitionNode->splitPointTableId = partitionTableId;
2168 }
2169 else if (equal(partitionColumn, rightColumn))
2170 {
2171 partitionNode->partitionColumn = leftColumn;
2172 partitionNode->splitPointTableId = partitionTableId;
2173 }
2174
2175 /* determine the node the partition operator goes on top of */
2176 List *rightTableIdList = OutputTableIdList(rightNode);
2177 uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
2178 Assert(list_length(rightTableIdList) == 1);
2179
2180 /*
2181 * If the right child node is partitioned on the partition key column, we
2182 * add the partition operator on the left child node; and vice versa. Then,
2183 * we add a collect operator on top of the partition operator, and always
2184 * make sure that we have at most one relation on the right-hand side.
2185 */
2186 if (partitionTableId == rightTableId)
2187 {
2188 SetChild((MultiUnaryNode *) partitionNode, leftNode);
2189 SetChild((MultiUnaryNode *) collectNode, (MultiNode *) partitionNode);
2190
2191 SetLeftChild((MultiBinaryNode *) joinNode, (MultiNode *) collectNode);
2192 SetRightChild((MultiBinaryNode *) joinNode, rightNode);
2193 }
2194 else
2195 {
2196 SetChild((MultiUnaryNode *) partitionNode, rightNode);
2197 SetChild((MultiUnaryNode *) collectNode, (MultiNode *) partitionNode);
2198
2199 SetLeftChild((MultiBinaryNode *) joinNode, leftNode);
2200 SetRightChild((MultiBinaryNode *) joinNode, (MultiNode *) collectNode);
2201 }
2202
2203 /* finally set join operator fields */
2204 joinNode->joinType = joinType;
2205 joinNode->joinClauseList = applicableJoinClauses;
2206
2207 return joinNode;
2208 }
2209
2210
2211 /*
2212 * ApplyDualPartitionJoin creates a new MultiJoin node that joins the left and
2213 * right node. The function also adds two MultiPartition operators on top of
2214 * both nodes to repartition these nodes' data on the join clause columns.
2215 */
2216 static MultiNode *
ApplyDualPartitionJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2217 ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
2218 List *partitionColumnList, JoinType joinType,
2219 List *applicableJoinClauses)
2220 {
2221 /* find the appropriate join clause */
2222 OpExpr *joinClause = DualPartitionJoinClause(applicableJoinClauses);
2223 Assert(joinClause != NULL);
2224
2225 /* both are verified in DualPartitionJoinClause to not be NULL, assert is to guard */
2226 Var *leftColumn = LeftColumnOrNULL(joinClause);
2227 Var *rightColumn = RightColumnOrNULL(joinClause);
2228 Assert(leftColumn != NULL);
2229 Assert(rightColumn != NULL);
2230
2231 List *rightTableIdList = OutputTableIdList(rightNode);
2232 uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
2233 Assert(list_length(rightTableIdList) == 1);
2234
2235 MultiPartition *leftPartitionNode = CitusMakeNode(MultiPartition);
2236 MultiPartition *rightPartitionNode = CitusMakeNode(MultiPartition);
2237
2238 /* find the partition node each join clause column belongs to */
2239 if (leftColumn->varno == rightTableId)
2240 {
2241 leftPartitionNode->partitionColumn = rightColumn;
2242 rightPartitionNode->partitionColumn = leftColumn;
2243 }
2244 else
2245 {
2246 leftPartitionNode->partitionColumn = leftColumn;
2247 rightPartitionNode->partitionColumn = rightColumn;
2248 }
2249
2250 /* add partition operators on top of left and right nodes */
2251 SetChild((MultiUnaryNode *) leftPartitionNode, leftNode);
2252 SetChild((MultiUnaryNode *) rightPartitionNode, rightNode);
2253
2254 /* add collect operators on top of the two partition operators */
2255 MultiCollect *leftCollectNode = CitusMakeNode(MultiCollect);
2256 MultiCollect *rightCollectNode = CitusMakeNode(MultiCollect);
2257
2258 SetChild((MultiUnaryNode *) leftCollectNode, (MultiNode *) leftPartitionNode);
2259 SetChild((MultiUnaryNode *) rightCollectNode, (MultiNode *) rightPartitionNode);
2260
2261 /* add join operator on top of the two collect operators */
2262 MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2263 joinNode->joinRuleType = DUAL_PARTITION_JOIN;
2264 joinNode->joinType = joinType;
2265 joinNode->joinClauseList = applicableJoinClauses;
2266
2267 SetLeftChild((MultiBinaryNode *) joinNode, (MultiNode *) leftCollectNode);
2268 SetRightChild((MultiBinaryNode *) joinNode, (MultiNode *) rightCollectNode);
2269
2270 return (MultiNode *) joinNode;
2271 }
2272
2273
2274 /* Creates a cartesian product node that joins the left and the right node. */
2275 static MultiNode *
ApplyCartesianProduct(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2276 ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode,
2277 List *partitionColumnList, JoinType joinType,
2278 List *applicableJoinClauses)
2279 {
2280 MultiCartesianProduct *cartesianNode = CitusMakeNode(MultiCartesianProduct);
2281
2282 SetLeftChild((MultiBinaryNode *) cartesianNode, leftNode);
2283 SetRightChild((MultiBinaryNode *) cartesianNode, rightNode);
2284
2285 return (MultiNode *) cartesianNode;
2286 }
2287
2288
2289 /*
2290 * OperatorImplementsEquality returns true if the given opno represents an
2291 * equality operator. The function retrieves btree interpretation list for this
2292 * opno and check if BTEqualStrategyNumber strategy is present.
2293 */
2294 bool
OperatorImplementsEquality(Oid opno)2295 OperatorImplementsEquality(Oid opno)
2296 {
2297 bool equalityOperator = false;
2298 List *btreeIntepretationList = get_op_btree_interpretation(opno);
2299 ListCell *btreeInterpretationCell = NULL;
2300 foreach(btreeInterpretationCell, btreeIntepretationList)
2301 {
2302 OpBtreeInterpretation *btreeIntepretation = (OpBtreeInterpretation *)
2303 lfirst(btreeInterpretationCell);
2304 if (btreeIntepretation->strategy == BTEqualStrategyNumber)
2305 {
2306 equalityOperator = true;
2307 break;
2308 }
2309 }
2310
2311 return equalityOperator;
2312 }
2313