1 /*-------------------------------------------------------------------------
2 *
3 * multi_join_order.c
4 *
5 * Routines for constructing the join order list using a rule-based approach.
6 *
7 * Copyright (c) Citus Data, Inc.
8 *
9 * $Id$
10 *
11 *-------------------------------------------------------------------------
12 */
13
14 #include "postgres.h"
15
16 #include "distributed/pg_version_constants.h"
17
18 #include <limits.h>
19
20 #include "access/nbtree.h"
21 #include "access/heapam.h"
22 #include "access/htup_details.h"
23 #include "catalog/pg_am.h"
24 #include "distributed/listutils.h"
25 #include "distributed/metadata_cache.h"
26 #include "distributed/multi_join_order.h"
27 #include "distributed/multi_physical_planner.h"
28 #include "distributed/pg_dist_partition.h"
29 #include "distributed/worker_protocol.h"
30 #include "lib/stringinfo.h"
31 #include "optimizer/optimizer.h"
32 #include "utils/builtins.h"
33 #include "nodes/nodeFuncs.h"
34 #include "utils/builtins.h"
35 #include "utils/datum.h"
36 #include "utils/lsyscache.h"
37 #include "utils/rel.h"
38 #include "utils/syscache.h"
39
40
41 /* Config variables managed via guc.c */
42 bool LogMultiJoinOrder = false; /* print join order as a debugging aid */
43 bool EnableSingleHashRepartitioning = false;
44
45 /* Function pointer type definition for join rule evaluation functions */
46 typedef JoinOrderNode *(*RuleEvalFunction) (JoinOrderNode *currentJoinNode,
47 TableEntry *candidateTable,
48 List *applicableJoinClauses,
49 JoinType joinType);
50
51 static char *RuleNameArray[JOIN_RULE_LAST] = { 0 }; /* ordered join rule names */
52 static RuleEvalFunction RuleEvalFunctionArray[JOIN_RULE_LAST] = { 0 }; /* join rules */
53
54
55 /* Local functions forward declarations */
56 static bool JoinExprListWalker(Node *node, List **joinList);
57 static bool ExtractLeftMostRangeTableIndex(Node *node, int *rangeTableIndex);
58 static List * JoinOrderForTable(TableEntry *firstTable, List *tableEntryList,
59 List *joinClauseList);
60 static List * BestJoinOrder(List *candidateJoinOrders);
61 static List * FewestOfJoinRuleType(List *candidateJoinOrders, JoinRuleType ruleType);
62 static uint32 JoinRuleTypeCount(List *joinOrder, JoinRuleType ruleTypeToCount);
63 static List * LatestLargeDataTransfer(List *candidateJoinOrders);
64 static void PrintJoinOrderList(List *joinOrder);
65 static uint32 LargeDataTransferLocation(List *joinOrder);
66 static List * TableEntryListDifference(List *lhsTableList, List *rhsTableList);
67
68 /* Local functions forward declarations for join evaluations */
69 static JoinOrderNode * EvaluateJoinRules(List *joinedTableList,
70 JoinOrderNode *currentJoinNode,
71 TableEntry *candidateTable,
72 List *joinClauseList, JoinType joinType);
73 static List * RangeTableIdList(List *tableList);
74 static RuleEvalFunction JoinRuleEvalFunction(JoinRuleType ruleType);
75 static char * JoinRuleName(JoinRuleType ruleType);
76 static JoinOrderNode * ReferenceJoin(JoinOrderNode *joinNode, TableEntry *candidateTable,
77 List *applicableJoinClauses, JoinType joinType);
78 static JoinOrderNode * CartesianProductReferenceJoin(JoinOrderNode *joinNode,
79 TableEntry *candidateTable,
80 List *applicableJoinClauses,
81 JoinType joinType);
82 static JoinOrderNode * LocalJoin(JoinOrderNode *joinNode, TableEntry *candidateTable,
83 List *applicableJoinClauses, JoinType joinType);
84 static bool JoinOnColumns(List *currentPartitionColumnList, Var *candidatePartitionColumn,
85 List *joinClauseList);
86 static JoinOrderNode * SinglePartitionJoin(JoinOrderNode *joinNode,
87 TableEntry *candidateTable,
88 List *applicableJoinClauses,
89 JoinType joinType);
90 static JoinOrderNode * DualPartitionJoin(JoinOrderNode *joinNode,
91 TableEntry *candidateTable,
92 List *applicableJoinClauses,
93 JoinType joinType);
94 static JoinOrderNode * CartesianProduct(JoinOrderNode *joinNode,
95 TableEntry *candidateTable,
96 List *applicableJoinClauses,
97 JoinType joinType);
98 static JoinOrderNode * MakeJoinOrderNode(TableEntry *tableEntry,
99 JoinRuleType joinRuleType,
100 List *partitionColumnList, char partitionMethod,
101 TableEntry *anchorTable);
102
103
104 /*
105 * JoinExprList flattens the JoinExpr nodes in the FROM expression and translate implicit
106 * joins to inner joins. This function does not consider (right-)nested joins.
107 */
108 List *
JoinExprList(FromExpr * fromExpr)109 JoinExprList(FromExpr *fromExpr)
110 {
111 List *joinList = NIL;
112 List *fromList = fromExpr->fromlist;
113 ListCell *fromCell = NULL;
114
115 foreach(fromCell, fromList)
116 {
117 Node *nextNode = (Node *) lfirst(fromCell);
118
119 if (joinList != NIL)
120 {
121 /* multiple nodes in from clause, add an explicit join between them */
122 int nextRangeTableIndex = 0;
123
124 /* find the left most range table in this node */
125 ExtractLeftMostRangeTableIndex((Node *) fromExpr, &nextRangeTableIndex);
126
127 RangeTblRef *nextRangeTableRef = makeNode(RangeTblRef);
128 nextRangeTableRef->rtindex = nextRangeTableIndex;
129
130 /* join the previous node with nextRangeTableRef */
131 JoinExpr *newJoinExpr = makeNode(JoinExpr);
132 newJoinExpr->jointype = JOIN_INNER;
133 newJoinExpr->rarg = (Node *) nextRangeTableRef;
134 newJoinExpr->quals = NULL;
135
136 joinList = lappend(joinList, newJoinExpr);
137 }
138
139 JoinExprListWalker(nextNode, &joinList);
140 }
141
142 return joinList;
143 }
144
145
146 /*
147 * JoinExprListWalker the JoinExpr nodes in a join tree in the order in which joins are
148 * to be executed. If there are no joins then no elements are added to joinList.
149 */
150 static bool
JoinExprListWalker(Node * node,List ** joinList)151 JoinExprListWalker(Node *node, List **joinList)
152 {
153 bool walkerResult = false;
154
155 if (node == NULL)
156 {
157 return false;
158 }
159
160 if (IsA(node, JoinExpr))
161 {
162 JoinExpr *joinExpr = (JoinExpr *) node;
163
164 walkerResult = JoinExprListWalker(joinExpr->larg, joinList);
165
166 (*joinList) = lappend(*joinList, joinExpr);
167 }
168 else
169 {
170 walkerResult = expression_tree_walker(node, JoinExprListWalker,
171 joinList);
172 }
173
174 return walkerResult;
175 }
176
177
178 /*
179 * ExtractLeftMostRangeTableIndex extracts the range table index of the left-most
180 * leaf in a join tree.
181 */
182 static bool
ExtractLeftMostRangeTableIndex(Node * node,int * rangeTableIndex)183 ExtractLeftMostRangeTableIndex(Node *node, int *rangeTableIndex)
184 {
185 bool walkerResult = false;
186
187 Assert(node != NULL);
188
189 if (IsA(node, JoinExpr))
190 {
191 JoinExpr *joinExpr = (JoinExpr *) node;
192
193 walkerResult = ExtractLeftMostRangeTableIndex(joinExpr->larg, rangeTableIndex);
194 }
195 else if (IsA(node, RangeTblRef))
196 {
197 RangeTblRef *rangeTableRef = (RangeTblRef *) node;
198
199 *rangeTableIndex = rangeTableRef->rtindex;
200 walkerResult = true;
201 }
202 else
203 {
204 walkerResult = expression_tree_walker(node, ExtractLeftMostRangeTableIndex,
205 rangeTableIndex);
206 }
207
208 return walkerResult;
209 }
210
211
212 /*
213 * JoinOnColumns determines whether two columns are joined by a given join clause list.
214 */
215 static bool
JoinOnColumns(List * currentPartitionColumnList,Var * candidateColumn,List * joinClauseList)216 JoinOnColumns(List *currentPartitionColumnList, Var *candidateColumn,
217 List *joinClauseList)
218 {
219 if (candidateColumn == NULL || list_length(currentPartitionColumnList) == 0)
220 {
221 /*
222 * LocalJoin can only be happening if we have both a current column and a target
223 * column, otherwise we are not joining two local tables
224 */
225 return false;
226 }
227
228 Var *currentColumn = NULL;
229 foreach_ptr(currentColumn, currentPartitionColumnList)
230 {
231 Node *joinClause = NULL;
232 foreach_ptr(joinClause, joinClauseList)
233 {
234 if (!NodeIsEqualsOpExpr(joinClause))
235 {
236 continue;
237 }
238 OpExpr *joinClauseOpExpr = castNode(OpExpr, joinClause);
239 Var *leftColumn = LeftColumnOrNULL(joinClauseOpExpr);
240 Var *rightColumn = RightColumnOrNULL(joinClauseOpExpr);
241
242 /*
243 * Check if both join columns and both partition key columns match, since the
244 * current and candidate column's can't be NULL we know they won't match if either
245 * of the columns resolved to NULL above.
246 */
247 if (equal(leftColumn, currentColumn) &&
248 equal(rightColumn, candidateColumn))
249 {
250 return true;
251 }
252 if (equal(leftColumn, candidateColumn) &&
253 equal(rightColumn, currentColumn))
254 {
255 return true;
256 }
257 }
258 }
259
260 return false;
261 }
262
263
264 /*
265 * NodeIsEqualsOpExpr checks if the node is an OpExpr, where the operator
266 * matches OperatorImplementsEquality.
267 */
268 bool
NodeIsEqualsOpExpr(Node * node)269 NodeIsEqualsOpExpr(Node *node)
270 {
271 if (!IsA(node, OpExpr))
272 {
273 return false;
274 }
275 OpExpr *opExpr = castNode(OpExpr, node);
276 return OperatorImplementsEquality(opExpr->opno);
277 }
278
279
280 /*
281 * JoinOrderList calculates the best join order and join rules that apply given
282 * the list of tables and join clauses. First, the function generates a set of
283 * candidate join orders, each with a different table as its first table. Then,
284 * the function chooses among these candidates the join order that transfers the
285 * least amount of data across the network, and returns this join order.
286 */
287 List *
JoinOrderList(List * tableEntryList,List * joinClauseList)288 JoinOrderList(List *tableEntryList, List *joinClauseList)
289 {
290 List *candidateJoinOrderList = NIL;
291 ListCell *tableEntryCell = NULL;
292
293 foreach(tableEntryCell, tableEntryList)
294 {
295 TableEntry *startingTable = (TableEntry *) lfirst(tableEntryCell);
296
297 /* each candidate join order starts with a different table */
298 List *candidateJoinOrder = JoinOrderForTable(startingTable, tableEntryList,
299 joinClauseList);
300
301 if (candidateJoinOrder != NULL)
302 {
303 candidateJoinOrderList = lappend(candidateJoinOrderList, candidateJoinOrder);
304 }
305 }
306
307 if (list_length(candidateJoinOrderList) == 0)
308 {
309 /* there are no plans that we can create, time to error */
310 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
311 errmsg("complex joins are only supported when all distributed "
312 "tables are joined on their distribution columns with "
313 "equal operator")));
314 }
315
316 List *bestJoinOrder = BestJoinOrder(candidateJoinOrderList);
317
318 /* if logging is enabled, print join order */
319 if (LogMultiJoinOrder)
320 {
321 PrintJoinOrderList(bestJoinOrder);
322 }
323
324 return bestJoinOrder;
325 }
326
327
328 /*
329 * JoinOrderForTable creates a join order whose first element is the given first
330 * table. To determine each subsequent element in the join order, the function
331 * then chooses the table that has the lowest ranking join rule, and with which
332 * it can join the table to the previous table in the join order. The function
333 * repeats this until it determines all elements in the join order list, and
334 * returns this list.
335 */
336 static List *
JoinOrderForTable(TableEntry * firstTable,List * tableEntryList,List * joinClauseList)337 JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClauseList)
338 {
339 JoinRuleType firstJoinRule = JOIN_RULE_INVALID_FIRST;
340 int joinedTableCount = 1;
341 int totalTableCount = list_length(tableEntryList);
342
343 /* create join node for the first table */
344 Oid firstRelationId = firstTable->relationId;
345 uint32 firstTableId = firstTable->rangeTableId;
346 Var *firstPartitionColumn = PartitionColumn(firstRelationId, firstTableId);
347 char firstPartitionMethod = PartitionMethod(firstRelationId);
348
349 JoinOrderNode *firstJoinNode = MakeJoinOrderNode(firstTable, firstJoinRule,
350 list_make1(firstPartitionColumn),
351 firstPartitionMethod,
352 firstTable);
353
354 /* add first node to the join order */
355 List *joinOrderList = list_make1(firstJoinNode);
356 List *joinedTableList = list_make1(firstTable);
357 JoinOrderNode *currentJoinNode = firstJoinNode;
358
359 /* loop until we join all remaining tables */
360 while (joinedTableCount < totalTableCount)
361 {
362 ListCell *pendingTableCell = NULL;
363 JoinOrderNode *nextJoinNode = NULL;
364 JoinRuleType nextJoinRuleType = JOIN_RULE_LAST;
365
366 List *pendingTableList = TableEntryListDifference(tableEntryList,
367 joinedTableList);
368
369 /*
370 * Iterate over all pending tables, and find the next best table to
371 * join. The best table is the one whose join rule requires the least
372 * amount of data transfer.
373 */
374 foreach(pendingTableCell, pendingTableList)
375 {
376 TableEntry *pendingTable = (TableEntry *) lfirst(pendingTableCell);
377 JoinType joinType = JOIN_INNER;
378
379 /* evaluate all join rules for this pending table */
380 JoinOrderNode *pendingJoinNode = EvaluateJoinRules(joinedTableList,
381 currentJoinNode,
382 pendingTable,
383 joinClauseList, joinType);
384
385 if (pendingJoinNode == NULL)
386 {
387 /* no join order could be generated, we try our next pending table */
388 continue;
389 }
390
391 /* if this rule is better than previous ones, keep it */
392 JoinRuleType pendingJoinRuleType = pendingJoinNode->joinRuleType;
393 if (pendingJoinRuleType < nextJoinRuleType)
394 {
395 nextJoinNode = pendingJoinNode;
396 nextJoinRuleType = pendingJoinRuleType;
397 }
398 }
399
400 if (nextJoinNode == NULL)
401 {
402 /*
403 * There is no next join node found, this will repeat indefinitely hence we
404 * bail and let JoinOrderList try a new initial table
405 */
406 return NULL;
407 }
408
409 Assert(nextJoinNode != NULL);
410 TableEntry *nextJoinedTable = nextJoinNode->tableEntry;
411
412 /* add next node to the join order */
413 joinOrderList = lappend(joinOrderList, nextJoinNode);
414 joinedTableList = lappend(joinedTableList, nextJoinedTable);
415 currentJoinNode = nextJoinNode;
416
417 joinedTableCount++;
418 }
419
420 return joinOrderList;
421 }
422
423
424 /*
425 * BestJoinOrder takes in a list of candidate join orders, and determines the
426 * best join order among these candidates. The function uses two heuristics for
427 * this. First, the function chooses join orders that have the fewest number of
428 * join operators that cause large data transfers. Second, the function chooses
429 * join orders where large data transfers occur later in the execution.
430 */
431 static List *
BestJoinOrder(List * candidateJoinOrders)432 BestJoinOrder(List *candidateJoinOrders)
433 {
434 uint32 highestValidIndex = JOIN_RULE_LAST - 1;
435 uint32 candidateCount PG_USED_FOR_ASSERTS_ONLY = 0;
436
437 /*
438 * We start with the highest ranking rule type (cartesian product), and walk
439 * over these rules in reverse order. For each rule type, we then keep join
440 * orders that only contain the fewest number of join rules of that type.
441 *
442 * For example, the algorithm chooses join orders like the following:
443 * (a) The algorithm prefers join orders with 2 cartesian products (CP) to
444 * those that have 3 or more, if there isn't a join order with fewer CPs.
445 * (b) Assuming that all join orders have the same number of CPs, the
446 * algorithm prefers join orders with 2 dual partitions (DP) to those that
447 * have 3 or more, if there isn't a join order with fewer DPs; and so
448 * forth.
449 */
450 for (uint32 ruleTypeIndex = highestValidIndex; ruleTypeIndex > 0; ruleTypeIndex--)
451 {
452 JoinRuleType ruleType = (JoinRuleType) ruleTypeIndex;
453
454 candidateJoinOrders = FewestOfJoinRuleType(candidateJoinOrders, ruleType);
455 }
456
457 /*
458 * If there is a tie, we pick candidate join orders where large data
459 * transfers happen at later stages of query execution. This results in more
460 * data being filtered via joins, selections, and projections earlier on.
461 */
462 candidateJoinOrders = LatestLargeDataTransfer(candidateJoinOrders);
463
464 /* we should have at least one join order left after optimizations */
465 candidateCount = list_length(candidateJoinOrders);
466 Assert(candidateCount > 0);
467
468 /*
469 * If there still is a tie, we pick the join order whose relation appeared
470 * earliest in the query's range table entry list.
471 */
472 List *bestJoinOrder = (List *) linitial(candidateJoinOrders);
473
474 return bestJoinOrder;
475 }
476
477
478 /*
479 * FewestOfJoinRuleType finds join orders that have the fewest number of times
480 * the given join rule occurs in the candidate join orders, and filters all
481 * other join orders. For example, if four candidate join orders have a join
482 * rule appearing 3, 5, 3, and 6 times, only two join orders that have the join
483 * rule appearing 3 times will be returned.
484 */
485 static List *
FewestOfJoinRuleType(List * candidateJoinOrders,JoinRuleType ruleType)486 FewestOfJoinRuleType(List *candidateJoinOrders, JoinRuleType ruleType)
487 {
488 List *fewestJoinOrders = NULL;
489 uint32 fewestRuleCount = INT_MAX;
490 ListCell *joinOrderCell = NULL;
491
492 foreach(joinOrderCell, candidateJoinOrders)
493 {
494 List *joinOrder = (List *) lfirst(joinOrderCell);
495 uint32 ruleTypeCount = JoinRuleTypeCount(joinOrder, ruleType);
496
497 if (ruleTypeCount == fewestRuleCount)
498 {
499 fewestJoinOrders = lappend(fewestJoinOrders, joinOrder);
500 }
501 else if (ruleTypeCount < fewestRuleCount)
502 {
503 fewestJoinOrders = list_make1(joinOrder);
504 fewestRuleCount = ruleTypeCount;
505 }
506 }
507
508 return fewestJoinOrders;
509 }
510
511
512 /* Counts the number of times the given join rule occurs in the join order. */
513 static uint32
JoinRuleTypeCount(List * joinOrder,JoinRuleType ruleTypeToCount)514 JoinRuleTypeCount(List *joinOrder, JoinRuleType ruleTypeToCount)
515 {
516 uint32 ruleTypeCount = 0;
517 ListCell *joinOrderNodeCell = NULL;
518
519 foreach(joinOrderNodeCell, joinOrder)
520 {
521 JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderNodeCell);
522
523 JoinRuleType ruleType = joinOrderNode->joinRuleType;
524 if (ruleType == ruleTypeToCount)
525 {
526 ruleTypeCount++;
527 }
528 }
529
530 return ruleTypeCount;
531 }
532
533
534 /*
535 * LatestLargeDataTransfer finds and returns join orders where a large data
536 * transfer join rule occurs as late as possible in the join order. Late large
537 * data transfers result in more data being filtered before data gets shuffled
538 * in the network.
539 */
540 static List *
LatestLargeDataTransfer(List * candidateJoinOrders)541 LatestLargeDataTransfer(List *candidateJoinOrders)
542 {
543 List *latestJoinOrders = NIL;
544 uint32 latestJoinLocation = 0;
545 ListCell *joinOrderCell = NULL;
546
547 foreach(joinOrderCell, candidateJoinOrders)
548 {
549 List *joinOrder = (List *) lfirst(joinOrderCell);
550 uint32 joinRuleLocation = LargeDataTransferLocation(joinOrder);
551
552 if (joinRuleLocation == latestJoinLocation)
553 {
554 latestJoinOrders = lappend(latestJoinOrders, joinOrder);
555 }
556 else if (joinRuleLocation > latestJoinLocation)
557 {
558 latestJoinOrders = list_make1(joinOrder);
559 latestJoinLocation = joinRuleLocation;
560 }
561 }
562
563 return latestJoinOrders;
564 }
565
566
567 /*
568 * LargeDataTransferLocation finds the first location of a large data transfer
569 * join rule, and returns that location. If the join order does not have any
570 * large data transfer rules, the function returns one location past the end of
571 * the join order list.
572 */
573 static uint32
LargeDataTransferLocation(List * joinOrder)574 LargeDataTransferLocation(List *joinOrder)
575 {
576 uint32 joinRuleLocation = 0;
577 ListCell *joinOrderNodeCell = NULL;
578
579 foreach(joinOrderNodeCell, joinOrder)
580 {
581 JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderNodeCell);
582 JoinRuleType joinRuleType = joinOrderNode->joinRuleType;
583
584 /* we consider the following join rules to cause large data transfers */
585 if (joinRuleType == SINGLE_HASH_PARTITION_JOIN ||
586 joinRuleType == SINGLE_RANGE_PARTITION_JOIN ||
587 joinRuleType == DUAL_PARTITION_JOIN ||
588 joinRuleType == CARTESIAN_PRODUCT)
589 {
590 break;
591 }
592
593 joinRuleLocation++;
594 }
595
596 return joinRuleLocation;
597 }
598
599
600 /* Prints the join order list and join rules for debugging purposes. */
601 static void
PrintJoinOrderList(List * joinOrder)602 PrintJoinOrderList(List *joinOrder)
603 {
604 StringInfo printBuffer = makeStringInfo();
605 ListCell *joinOrderNodeCell = NULL;
606 bool firstJoinNode = true;
607
608 foreach(joinOrderNodeCell, joinOrder)
609 {
610 JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderNodeCell);
611 Oid relationId = joinOrderNode->tableEntry->relationId;
612 char *relationName = get_rel_name(relationId);
613
614 if (firstJoinNode)
615 {
616 appendStringInfo(printBuffer, "[ \"%s\" ]", relationName);
617 firstJoinNode = false;
618 }
619 else
620 {
621 JoinRuleType ruleType = (JoinRuleType) joinOrderNode->joinRuleType;
622 char *ruleName = JoinRuleName(ruleType);
623
624 appendStringInfo(printBuffer, "[ %s ", ruleName);
625 appendStringInfo(printBuffer, "\"%s\" ]", relationName);
626 }
627 }
628
629 ereport(LOG, (errmsg("join order: %s",
630 ApplyLogRedaction(printBuffer->data))));
631 }
632
633
634 /*
635 * TableEntryListDifference returns a list containing table entries that are in
636 * the left-hand side table list, but not in the right-hand side table list.
637 */
638 static List *
TableEntryListDifference(List * lhsTableList,List * rhsTableList)639 TableEntryListDifference(List *lhsTableList, List *rhsTableList)
640 {
641 List *tableListDifference = NIL;
642 ListCell *lhsTableCell = NULL;
643
644 foreach(lhsTableCell, lhsTableList)
645 {
646 TableEntry *lhsTableEntry = (TableEntry *) lfirst(lhsTableCell);
647 ListCell *rhsTableCell = NULL;
648 bool lhsTableEntryExists = false;
649
650 foreach(rhsTableCell, rhsTableList)
651 {
652 TableEntry *rhsTableEntry = (TableEntry *) lfirst(rhsTableCell);
653
654 if ((lhsTableEntry->relationId == rhsTableEntry->relationId) &&
655 (lhsTableEntry->rangeTableId == rhsTableEntry->rangeTableId))
656 {
657 lhsTableEntryExists = true;
658 }
659 }
660
661 if (!lhsTableEntryExists)
662 {
663 tableListDifference = lappend(tableListDifference, lhsTableEntry);
664 }
665 }
666
667 return tableListDifference;
668 }
669
670
671 /*
672 * EvaluateJoinRules takes in a list of already joined tables and a candidate
673 * next table, evaluates different join rules between the two tables, and finds
674 * the best join rule that applies. The function returns the applicable join
675 * order node which includes the join rule and the partition information.
676 */
677 static JoinOrderNode *
EvaluateJoinRules(List * joinedTableList,JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * joinClauseList,JoinType joinType)678 EvaluateJoinRules(List *joinedTableList, JoinOrderNode *currentJoinNode,
679 TableEntry *candidateTable, List *joinClauseList,
680 JoinType joinType)
681 {
682 JoinOrderNode *nextJoinNode = NULL;
683 uint32 lowestValidIndex = JOIN_RULE_INVALID_FIRST + 1;
684 uint32 highestValidIndex = JOIN_RULE_LAST - 1;
685
686 /*
687 * We first find all applicable join clauses between already joined tables
688 * and the candidate table.
689 */
690 List *joinedTableIdList = RangeTableIdList(joinedTableList);
691 uint32 candidateTableId = candidateTable->rangeTableId;
692 List *applicableJoinClauses = ApplicableJoinClauses(joinedTableIdList,
693 candidateTableId,
694 joinClauseList);
695
696 /* we then evaluate all join rules in order */
697 for (uint32 ruleIndex = lowestValidIndex; ruleIndex <= highestValidIndex; ruleIndex++)
698 {
699 JoinRuleType ruleType = (JoinRuleType) ruleIndex;
700 RuleEvalFunction ruleEvalFunction = JoinRuleEvalFunction(ruleType);
701
702 nextJoinNode = (*ruleEvalFunction)(currentJoinNode,
703 candidateTable,
704 applicableJoinClauses,
705 joinType);
706
707 /* break after finding the first join rule that applies */
708 if (nextJoinNode != NULL)
709 {
710 break;
711 }
712 }
713
714 if (nextJoinNode == NULL)
715 {
716 return NULL;
717 }
718
719 Assert(nextJoinNode != NULL);
720 nextJoinNode->joinType = joinType;
721 nextJoinNode->joinClauseList = applicableJoinClauses;
722 return nextJoinNode;
723 }
724
725
726 /* Extracts range table identifiers from the given table list, and returns them. */
727 static List *
RangeTableIdList(List * tableList)728 RangeTableIdList(List *tableList)
729 {
730 List *rangeTableIdList = NIL;
731 ListCell *tableCell = NULL;
732
733 foreach(tableCell, tableList)
734 {
735 TableEntry *tableEntry = (TableEntry *) lfirst(tableCell);
736
737 uint32 rangeTableId = tableEntry->rangeTableId;
738 rangeTableIdList = lappend_int(rangeTableIdList, rangeTableId);
739 }
740
741 return rangeTableIdList;
742 }
743
744
745 /*
746 * JoinRuleEvalFunction returns a function pointer for the rule evaluation
747 * function; this rule evaluation function corresponds to the given rule type.
748 * The function also initializes the rule evaluation function array in a static
749 * code block, if the array has not been initialized.
750 */
751 static RuleEvalFunction
JoinRuleEvalFunction(JoinRuleType ruleType)752 JoinRuleEvalFunction(JoinRuleType ruleType)
753 {
754 static bool ruleEvalFunctionsInitialized = false;
755
756 if (!ruleEvalFunctionsInitialized)
757 {
758 RuleEvalFunctionArray[REFERENCE_JOIN] = &ReferenceJoin;
759 RuleEvalFunctionArray[LOCAL_PARTITION_JOIN] = &LocalJoin;
760 RuleEvalFunctionArray[SINGLE_RANGE_PARTITION_JOIN] = &SinglePartitionJoin;
761 RuleEvalFunctionArray[SINGLE_HASH_PARTITION_JOIN] = &SinglePartitionJoin;
762 RuleEvalFunctionArray[DUAL_PARTITION_JOIN] = &DualPartitionJoin;
763 RuleEvalFunctionArray[CARTESIAN_PRODUCT_REFERENCE_JOIN] =
764 &CartesianProductReferenceJoin;
765 RuleEvalFunctionArray[CARTESIAN_PRODUCT] = &CartesianProduct;
766
767 ruleEvalFunctionsInitialized = true;
768 }
769
770 RuleEvalFunction ruleEvalFunction = RuleEvalFunctionArray[ruleType];
771 Assert(ruleEvalFunction != NULL);
772
773 return ruleEvalFunction;
774 }
775
776
777 /* Returns a string name for the given join rule type. */
778 static char *
JoinRuleName(JoinRuleType ruleType)779 JoinRuleName(JoinRuleType ruleType)
780 {
781 static bool ruleNamesInitialized = false;
782
783 if (!ruleNamesInitialized)
784 {
785 /* use strdup() to be independent of memory contexts */
786 RuleNameArray[REFERENCE_JOIN] = strdup("reference join");
787 RuleNameArray[LOCAL_PARTITION_JOIN] = strdup("local partition join");
788 RuleNameArray[SINGLE_HASH_PARTITION_JOIN] =
789 strdup("single hash partition join");
790 RuleNameArray[SINGLE_RANGE_PARTITION_JOIN] =
791 strdup("single range partition join");
792 RuleNameArray[DUAL_PARTITION_JOIN] = strdup("dual partition join");
793 RuleNameArray[CARTESIAN_PRODUCT_REFERENCE_JOIN] = strdup(
794 "cartesian product reference join");
795 RuleNameArray[CARTESIAN_PRODUCT] = strdup("cartesian product");
796
797 ruleNamesInitialized = true;
798 }
799
800 char *ruleName = RuleNameArray[ruleType];
801 Assert(ruleName != NULL);
802
803 return ruleName;
804 }
805
806
807 /*
808 * ReferenceJoin evaluates if the candidate table is a reference table for inner,
809 * left and anti join. For right join, current join node must be represented by
810 * a reference table. For full join, both of them must be a reference table.
811 */
812 static JoinOrderNode *
ReferenceJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)813 ReferenceJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
814 List *applicableJoinClauses, JoinType joinType)
815 {
816 int applicableJoinCount = list_length(applicableJoinClauses);
817 if (applicableJoinCount <= 0)
818 {
819 return NULL;
820 }
821
822 bool leftIsReferenceTable = IsCitusTableType(
823 currentJoinNode->tableEntry->relationId,
824 REFERENCE_TABLE);
825 bool rightIsReferenceTable = IsCitusTableType(candidateTable->relationId,
826 REFERENCE_TABLE);
827 if (!IsSupportedReferenceJoin(joinType, leftIsReferenceTable, rightIsReferenceTable))
828 {
829 return NULL;
830 }
831 return MakeJoinOrderNode(candidateTable, REFERENCE_JOIN,
832 currentJoinNode->partitionColumnList,
833 currentJoinNode->partitionMethod,
834 currentJoinNode->anchorTable);
835 }
836
837
838 /*
839 * IsSupportedReferenceJoin checks if with this join type we can safely do a simple join
840 * on the reference table on all the workers.
841 */
842 bool
IsSupportedReferenceJoin(JoinType joinType,bool leftIsReferenceTable,bool rightIsReferenceTable)843 IsSupportedReferenceJoin(JoinType joinType, bool leftIsReferenceTable,
844 bool rightIsReferenceTable)
845 {
846 if ((joinType == JOIN_INNER || joinType == JOIN_LEFT || joinType == JOIN_ANTI) &&
847 rightIsReferenceTable)
848 {
849 return true;
850 }
851 else if ((joinType == JOIN_RIGHT) &&
852 leftIsReferenceTable)
853 {
854 return true;
855 }
856 else if (joinType == JOIN_FULL && leftIsReferenceTable && rightIsReferenceTable)
857 {
858 return true;
859 }
860 return false;
861 }
862
863
864 /*
865 * ReferenceJoin evaluates if the candidate table is a reference table for inner,
866 * left and anti join. For right join, current join node must be represented by
867 * a reference table. For full join, both of them must be a reference table.
868 */
869 static JoinOrderNode *
CartesianProductReferenceJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)870 CartesianProductReferenceJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
871 List *applicableJoinClauses, JoinType joinType)
872 {
873 bool leftIsReferenceTable = IsCitusTableType(
874 currentJoinNode->tableEntry->relationId,
875 REFERENCE_TABLE);
876 bool rightIsReferenceTable = IsCitusTableType(candidateTable->relationId,
877 REFERENCE_TABLE);
878
879 if (!IsSupportedReferenceJoin(joinType, leftIsReferenceTable, rightIsReferenceTable))
880 {
881 return NULL;
882 }
883 return MakeJoinOrderNode(candidateTable, CARTESIAN_PRODUCT_REFERENCE_JOIN,
884 currentJoinNode->partitionColumnList,
885 currentJoinNode->partitionMethod,
886 currentJoinNode->anchorTable);
887 }
888
889
890 /*
891 * LocalJoin takes the current partition key column and the candidate table's
892 * partition key column and the partition method for each table. The function
893 * then evaluates if tables in the join order and the candidate table can be
894 * joined locally, without any data transfers. If they can, the function returns
895 * a join order node for a local join. Otherwise, the function returns null.
896 *
897 * Anchor table is used to decide whether the JoinOrderNode can be joined
898 * locally with the candidate table. That table is updated by each join type
899 * applied over JoinOrderNode. Note that, we lost the anchor table after
900 * dual partitioning and cartesian product.
901 */
902 static JoinOrderNode *
LocalJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)903 LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
904 List *applicableJoinClauses, JoinType joinType)
905 {
906 Oid relationId = candidateTable->relationId;
907 uint32 tableId = candidateTable->rangeTableId;
908 Var *candidatePartitionColumn = PartitionColumn(relationId, tableId);
909 List *currentPartitionColumnList = currentJoinNode->partitionColumnList;
910 char candidatePartitionMethod = PartitionMethod(relationId);
911 char currentPartitionMethod = currentJoinNode->partitionMethod;
912 TableEntry *currentAnchorTable = currentJoinNode->anchorTable;
913
914 /*
915 * If we previously dual-hash re-partitioned the tables for a join or made cartesian
916 * product, there is no anchor table anymore. In that case we don't allow local join.
917 */
918 if (currentAnchorTable == NULL)
919 {
920 return NULL;
921 }
922
923 /* the partition method should be the same for a local join */
924 if (currentPartitionMethod != candidatePartitionMethod)
925 {
926 return NULL;
927 }
928
929 bool joinOnPartitionColumns = JoinOnColumns(currentPartitionColumnList,
930 candidatePartitionColumn,
931 applicableJoinClauses);
932 if (!joinOnPartitionColumns)
933 {
934 return NULL;
935 }
936
937 /* shard interval lists must have 1-1 matching for local joins */
938 bool coPartitionedTables = CoPartitionedTables(currentAnchorTable->relationId,
939 relationId);
940
941 if (!coPartitionedTables)
942 {
943 return NULL;
944 }
945
946 /*
947 * Since we are applying a local join to the candidate table we need to keep track of
948 * the partition column of the candidate table on the MultiJoinNode. This will allow
949 * subsequent joins colocated with this candidate table to correctly be recognized as
950 * a local join as well.
951 */
952 currentPartitionColumnList = list_append_unique(currentPartitionColumnList,
953 candidatePartitionColumn);
954
955 JoinOrderNode *nextJoinNode = MakeJoinOrderNode(candidateTable, LOCAL_PARTITION_JOIN,
956 currentPartitionColumnList,
957 currentPartitionMethod,
958 currentAnchorTable);
959
960
961 return nextJoinNode;
962 }
963
964
965 /*
966 * SinglePartitionJoin takes the current and the candidate table's partition keys
967 * and methods. The function then evaluates if either "tables in the join order"
968 * or the candidate table is already partitioned on a join column. If they are,
969 * the function returns a join order node with the already partitioned column as
970 * the next partition key. Otherwise, the function returns null.
971 */
972 static JoinOrderNode *
SinglePartitionJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)973 SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
974 List *applicableJoinClauses, JoinType joinType)
975 {
976 List *currentPartitionColumnList = currentJoinNode->partitionColumnList;
977 char currentPartitionMethod = currentJoinNode->partitionMethod;
978 TableEntry *currentAnchorTable = currentJoinNode->anchorTable;
979 JoinRuleType currentJoinRuleType = currentJoinNode->joinRuleType;
980
981
982 Oid relationId = candidateTable->relationId;
983 uint32 tableId = candidateTable->rangeTableId;
984 Var *candidatePartitionColumn = PartitionColumn(relationId, tableId);
985 char candidatePartitionMethod = PartitionMethod(relationId);
986
987 /* outer joins are not supported yet */
988 if (IS_OUTER_JOIN(joinType))
989 {
990 return NULL;
991 }
992
993 /*
994 * If we previously dual-hash re-partitioned the tables for a join or made
995 * cartesian product, we currently don't allow a single-repartition join.
996 */
997 if (currentJoinRuleType == DUAL_PARTITION_JOIN ||
998 currentJoinRuleType == CARTESIAN_PRODUCT)
999 {
1000 return NULL;
1001 }
1002
1003 OpExpr *joinClause =
1004 SinglePartitionJoinClause(currentPartitionColumnList, applicableJoinClauses);
1005 if (joinClause != NULL)
1006 {
1007 if (currentPartitionMethod == DISTRIBUTE_BY_HASH)
1008 {
1009 /*
1010 * Single hash repartitioning may perform worse than dual hash
1011 * repartitioning. Thus, we control it via a guc.
1012 */
1013 if (!EnableSingleHashRepartitioning)
1014 {
1015 return NULL;
1016 }
1017
1018 return MakeJoinOrderNode(candidateTable, SINGLE_HASH_PARTITION_JOIN,
1019 currentPartitionColumnList,
1020 currentPartitionMethod,
1021 currentAnchorTable);
1022 }
1023 else
1024 {
1025 return MakeJoinOrderNode(candidateTable, SINGLE_RANGE_PARTITION_JOIN,
1026 currentPartitionColumnList,
1027 currentPartitionMethod,
1028 currentAnchorTable);
1029 }
1030 }
1031
1032 /* evaluate re-partitioning the current table only if the rule didn't apply above */
1033 if (candidatePartitionMethod != DISTRIBUTE_BY_NONE)
1034 {
1035 /*
1036 * Create a new unique list (set) with the partition column of the candidate table
1037 * to check if a single repartition join will work for this table. When it works
1038 * the set is retained on the MultiJoinNode for later local join verification.
1039 */
1040 List *candidatePartitionColumnList = list_make1(candidatePartitionColumn);
1041 joinClause = SinglePartitionJoinClause(candidatePartitionColumnList,
1042 applicableJoinClauses);
1043 if (joinClause != NULL)
1044 {
1045 if (candidatePartitionMethod == DISTRIBUTE_BY_HASH)
1046 {
1047 /*
1048 * Single hash repartitioning may perform worse than dual hash
1049 * repartitioning. Thus, we control it via a guc.
1050 */
1051 if (!EnableSingleHashRepartitioning)
1052 {
1053 return NULL;
1054 }
1055
1056 return MakeJoinOrderNode(candidateTable,
1057 SINGLE_HASH_PARTITION_JOIN,
1058 candidatePartitionColumnList,
1059 candidatePartitionMethod,
1060 candidateTable);
1061 }
1062 else
1063 {
1064 return MakeJoinOrderNode(candidateTable,
1065 SINGLE_RANGE_PARTITION_JOIN,
1066 candidatePartitionColumnList,
1067 candidatePartitionMethod,
1068 candidateTable);
1069 }
1070 }
1071 }
1072
1073 return NULL;
1074 }
1075
1076
1077 /*
1078 * SinglePartitionJoinClause walks over the applicable join clause list, and
1079 * finds an applicable join clause for the given partition column. If no such
1080 * clause exists, the function returns NULL.
1081 */
1082 OpExpr *
SinglePartitionJoinClause(List * partitionColumnList,List * applicableJoinClauses)1083 SinglePartitionJoinClause(List *partitionColumnList, List *applicableJoinClauses)
1084 {
1085 if (list_length(partitionColumnList) == 0)
1086 {
1087 return NULL;
1088 }
1089
1090 Var *partitionColumn = NULL;
1091 foreach_ptr(partitionColumn, partitionColumnList)
1092 {
1093 Node *applicableJoinClause = NULL;
1094 foreach_ptr(applicableJoinClause, applicableJoinClauses)
1095 {
1096 if (!NodeIsEqualsOpExpr(applicableJoinClause))
1097 {
1098 continue;
1099 }
1100 OpExpr *applicableJoinOpExpr = castNode(OpExpr, applicableJoinClause);
1101 Var *leftColumn = LeftColumnOrNULL(applicableJoinOpExpr);
1102 Var *rightColumn = RightColumnOrNULL(applicableJoinOpExpr);
1103 if (leftColumn == NULL || rightColumn == NULL)
1104 {
1105 /* not a simple partition column join */
1106 continue;
1107 }
1108
1109
1110 /*
1111 * We first check if partition column matches either of the join columns
1112 * and if it does, we then check if the join column types match. If the
1113 * types are different, we will use different hash functions for the two
1114 * column types, and will incorrectly repartition the data.
1115 */
1116 if (equal(leftColumn, partitionColumn) || equal(rightColumn, partitionColumn))
1117 {
1118 if (leftColumn->vartype == rightColumn->vartype)
1119 {
1120 return applicableJoinOpExpr;
1121 }
1122 else
1123 {
1124 ereport(DEBUG1, (errmsg("single partition column types do not "
1125 "match")));
1126 }
1127 }
1128 }
1129 }
1130
1131 return NULL;
1132 }
1133
1134
1135 /*
1136 * DualPartitionJoin evaluates if a join clause exists between "tables in the
1137 * join order" and the candidate table. If such a clause exists, both tables can
1138 * be repartitioned on the join column; and the function returns a join order
1139 * node with the join column as the next partition key. Otherwise, the function
1140 * returns null.
1141 */
1142 static JoinOrderNode *
DualPartitionJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)1143 DualPartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
1144 List *applicableJoinClauses, JoinType joinType)
1145 {
1146 OpExpr *joinClause = DualPartitionJoinClause(applicableJoinClauses);
1147 if (joinClause)
1148 {
1149 /* because of the dual partition, anchor table and partition column get lost */
1150 return MakeJoinOrderNode(candidateTable,
1151 DUAL_PARTITION_JOIN,
1152 NIL,
1153 REDISTRIBUTE_BY_HASH,
1154 NULL);
1155 }
1156
1157 return NULL;
1158 }
1159
1160
1161 /*
1162 * DualPartitionJoinClause walks over the applicable join clause list, and finds
1163 * an applicable join clause for dual re-partitioning. If no such clause exists,
1164 * the function returns NULL.
1165 */
1166 OpExpr *
DualPartitionJoinClause(List * applicableJoinClauses)1167 DualPartitionJoinClause(List *applicableJoinClauses)
1168 {
1169 Node *applicableJoinClause = NULL;
1170 foreach_ptr(applicableJoinClause, applicableJoinClauses)
1171 {
1172 if (!NodeIsEqualsOpExpr(applicableJoinClause))
1173 {
1174 continue;
1175 }
1176 OpExpr *applicableJoinOpExpr = castNode(OpExpr, applicableJoinClause);
1177 Var *leftColumn = LeftColumnOrNULL(applicableJoinOpExpr);
1178 Var *rightColumn = RightColumnOrNULL(applicableJoinOpExpr);
1179
1180 if (leftColumn == NULL || rightColumn == NULL)
1181 {
1182 continue;
1183 }
1184
1185 /* we only need to check that the join column types match */
1186 if (leftColumn->vartype == rightColumn->vartype)
1187 {
1188 return applicableJoinOpExpr;
1189 }
1190 else
1191 {
1192 ereport(DEBUG1, (errmsg("dual partition column types do not match")));
1193 }
1194 }
1195
1196 return NULL;
1197 }
1198
1199
1200 /*
1201 * CartesianProduct always evaluates to true since all tables can be combined
1202 * using a cartesian product operator. This function acts as a catch-all rule,
1203 * in case none of the join rules apply.
1204 */
1205 static JoinOrderNode *
CartesianProduct(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)1206 CartesianProduct(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
1207 List *applicableJoinClauses, JoinType joinType)
1208 {
1209 if (list_length(applicableJoinClauses) == 0)
1210 {
1211 /* Because of the cartesian product, anchor table information got lost */
1212 return MakeJoinOrderNode(candidateTable, CARTESIAN_PRODUCT,
1213 currentJoinNode->partitionColumnList,
1214 currentJoinNode->partitionMethod,
1215 NULL);
1216 }
1217
1218 return NULL;
1219 }
1220
1221
1222 /* Constructs and returns a join-order node with the given arguments */
1223 JoinOrderNode *
MakeJoinOrderNode(TableEntry * tableEntry,JoinRuleType joinRuleType,List * partitionColumnList,char partitionMethod,TableEntry * anchorTable)1224 MakeJoinOrderNode(TableEntry *tableEntry, JoinRuleType joinRuleType,
1225 List *partitionColumnList, char partitionMethod,
1226 TableEntry *anchorTable)
1227 {
1228 JoinOrderNode *joinOrderNode = palloc0(sizeof(JoinOrderNode));
1229 joinOrderNode->tableEntry = tableEntry;
1230 joinOrderNode->joinRuleType = joinRuleType;
1231 joinOrderNode->joinType = JOIN_INNER;
1232 joinOrderNode->partitionColumnList = partitionColumnList;
1233 joinOrderNode->partitionMethod = partitionMethod;
1234 joinOrderNode->joinClauseList = NIL;
1235 joinOrderNode->anchorTable = anchorTable;
1236
1237 return joinOrderNode;
1238 }
1239
1240
1241 /*
1242 * IsApplicableJoinClause tests if the current joinClause is applicable to the join at
1243 * hand.
1244 *
1245 * Given a list of left hand tables and a candidate right hand table the join clause is
1246 * valid if atleast 1 column is from the right hand table AND all columns can be found
1247 * in either the list of tables on the left *or* in the right hand table.
1248 */
1249 bool
IsApplicableJoinClause(List * leftTableIdList,uint32 rightTableId,Node * joinClause)1250 IsApplicableJoinClause(List *leftTableIdList, uint32 rightTableId, Node *joinClause)
1251 {
1252 List *varList = pull_var_clause_default(joinClause);
1253 Var *var = NULL;
1254 bool joinContainsRightTable = false;
1255 foreach_ptr(var, varList)
1256 {
1257 uint32 columnTableId = var->varno;
1258 if (rightTableId == columnTableId)
1259 {
1260 joinContainsRightTable = true;
1261 }
1262 else if (!list_member_int(leftTableIdList, columnTableId))
1263 {
1264 /*
1265 * We couldn't find this column either on the right hand side (first if
1266 * statement), nor in the list on the left. This join clause involves a table
1267 * not yet available during the candidate join.
1268 */
1269 return false;
1270 }
1271 }
1272
1273 /*
1274 * All columns referenced in this clause are available during this join, now the join
1275 * is applicable if we found our candidate table as well
1276 */
1277 return joinContainsRightTable;
1278 }
1279
1280
1281 /*
1282 * ApplicableJoinClauses finds all join clauses that apply between the given
1283 * left table list and the right table, and returns these found join clauses.
1284 */
1285 List *
ApplicableJoinClauses(List * leftTableIdList,uint32 rightTableId,List * joinClauseList)1286 ApplicableJoinClauses(List *leftTableIdList, uint32 rightTableId, List *joinClauseList)
1287 {
1288 List *applicableJoinClauses = NIL;
1289
1290 /* make sure joinClauseList contains only join clauses */
1291 joinClauseList = JoinClauseList(joinClauseList);
1292
1293 Node *joinClause = NULL;
1294 foreach_ptr(joinClause, joinClauseList)
1295 {
1296 if (IsApplicableJoinClause(leftTableIdList, rightTableId, joinClause))
1297 {
1298 applicableJoinClauses = lappend(applicableJoinClauses, joinClause);
1299 }
1300 }
1301
1302 return applicableJoinClauses;
1303 }
1304
1305
1306 /*
1307 * Returns the left column only when directly referenced in the given join clause,
1308 * otherwise NULL is returned.
1309 */
1310 Var *
LeftColumnOrNULL(OpExpr * joinClause)1311 LeftColumnOrNULL(OpExpr *joinClause)
1312 {
1313 List *argumentList = joinClause->args;
1314 Node *leftArgument = (Node *) linitial(argumentList);
1315
1316 leftArgument = strip_implicit_coercions(leftArgument);
1317 if (!IsA(leftArgument, Var))
1318 {
1319 return NULL;
1320 }
1321 return castNode(Var, leftArgument);
1322 }
1323
1324
1325 /*
1326 * Returns the right column only when directly referenced in the given join clause,
1327 * otherwise NULL is returned.
1328 * */
1329 Var *
RightColumnOrNULL(OpExpr * joinClause)1330 RightColumnOrNULL(OpExpr *joinClause)
1331 {
1332 List *argumentList = joinClause->args;
1333 Node *rightArgument = (Node *) lsecond(argumentList);
1334
1335 rightArgument = strip_implicit_coercions(rightArgument);
1336 if (!IsA(rightArgument, Var))
1337 {
1338 return NULL;
1339 }
1340 return castNode(Var, rightArgument);
1341 }
1342
1343
1344 /*
1345 * PartitionColumn builds the partition column for the given relation, and sets
1346 * the partition column's range table references to the given table identifier.
1347 *
1348 * Note that reference tables do not have partition column. Thus, this function
1349 * returns NULL when called for reference tables.
1350 */
1351 Var *
PartitionColumn(Oid relationId,uint32 rangeTableId)1352 PartitionColumn(Oid relationId, uint32 rangeTableId)
1353 {
1354 Var *partitionKey = DistPartitionKey(relationId);
1355 Var *partitionColumn = NULL;
1356
1357 /* short circuit for reference tables */
1358 if (partitionKey == NULL)
1359 {
1360 return partitionColumn;
1361 }
1362
1363 partitionColumn = partitionKey;
1364 partitionColumn->varno = rangeTableId;
1365 partitionColumn->varnosyn = rangeTableId;
1366
1367 return partitionColumn;
1368 }
1369
1370
1371 /*
1372 * DistPartitionKey returns the partition key column for the given relation. Note
1373 * that in the context of distributed join and query planning, the callers of
1374 * this function *must* set the partition key column's range table reference
1375 * (varno) to match the table's location in the query range table list.
1376 *
1377 * Note that reference tables do not have partition column. Thus, this function
1378 * returns NULL when called for reference tables.
1379 */
1380 Var *
DistPartitionKey(Oid relationId)1381 DistPartitionKey(Oid relationId)
1382 {
1383 CitusTableCacheEntry *partitionEntry = GetCitusTableCacheEntry(relationId);
1384
1385 /* non-distributed tables do not have partition column */
1386 if (IsCitusTableTypeCacheEntry(partitionEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
1387 {
1388 return NULL;
1389 }
1390
1391 return copyObject(partitionEntry->partitionColumn);
1392 }
1393
1394
1395 /*
1396 * DistPartitionKeyOrError is the same as DistPartitionKey but errors out instead
1397 * of returning NULL if this is called with a relationId of a reference table.
1398 */
1399 Var *
DistPartitionKeyOrError(Oid relationId)1400 DistPartitionKeyOrError(Oid relationId)
1401 {
1402 Var *partitionKey = DistPartitionKey(relationId);
1403
1404 if (partitionKey == NULL)
1405 {
1406 ereport(ERROR, (errmsg(
1407 "no distribution column found for relation %d, because it is a reference table",
1408 relationId)));
1409 }
1410
1411 return partitionKey;
1412 }
1413
1414
1415 /* Returns the partition method for the given relation. */
1416 char
PartitionMethod(Oid relationId)1417 PartitionMethod(Oid relationId)
1418 {
1419 /* errors out if not a distributed table */
1420 CitusTableCacheEntry *partitionEntry = GetCitusTableCacheEntry(relationId);
1421
1422 char partitionMethod = partitionEntry->partitionMethod;
1423
1424 return partitionMethod;
1425 }
1426
1427
1428 /* Returns the replication model for the given relation. */
1429 char
TableReplicationModel(Oid relationId)1430 TableReplicationModel(Oid relationId)
1431 {
1432 /* errors out if not a distributed table */
1433 CitusTableCacheEntry *partitionEntry = GetCitusTableCacheEntry(relationId);
1434
1435 char replicationModel = partitionEntry->replicationModel;
1436
1437 return replicationModel;
1438 }
1439