1 /*-------------------------------------------------------------------------
2  *
3  * multi_join_order.c
4  *
5  * Routines for constructing the join order list using a rule-based approach.
6  *
7  * Copyright (c) Citus Data, Inc.
8  *
9  * $Id$
10  *
11  *-------------------------------------------------------------------------
12  */
13 
14 #include "postgres.h"
15 
16 #include "distributed/pg_version_constants.h"
17 
18 #include <limits.h>
19 
20 #include "access/nbtree.h"
21 #include "access/heapam.h"
22 #include "access/htup_details.h"
23 #include "catalog/pg_am.h"
24 #include "distributed/listutils.h"
25 #include "distributed/metadata_cache.h"
26 #include "distributed/multi_join_order.h"
27 #include "distributed/multi_physical_planner.h"
28 #include "distributed/pg_dist_partition.h"
29 #include "distributed/worker_protocol.h"
30 #include "lib/stringinfo.h"
31 #include "optimizer/optimizer.h"
32 #include "utils/builtins.h"
33 #include "nodes/nodeFuncs.h"
34 #include "utils/builtins.h"
35 #include "utils/datum.h"
36 #include "utils/lsyscache.h"
37 #include "utils/rel.h"
38 #include "utils/syscache.h"
39 
40 
/*
 * Config variables managed via guc.c. Both default to false; they are plain
 * globals so that the GUC machinery can assign to them directly.
 */
bool LogMultiJoinOrder = false; /* print join order as a debugging aid */

/* presumably enables single hash repartition joins; confirm against SinglePartitionJoin */
bool EnableSingleHashRepartitioning = false;
44 
/*
 * Function pointer type definition for join rule evaluation functions. A rule
 * function returns a join order node when its rule applies to joining the
 * candidate table onto the current join node, and NULL otherwise.
 */
typedef JoinOrderNode *(*RuleEvalFunction) (JoinOrderNode *currentJoinNode,
											TableEntry *candidateTable,
											List *applicableJoinClauses,
											JoinType joinType);

/*
 * Both arrays are indexed by JoinRuleType and start out zeroed; they are filled
 * in lazily on first use by JoinRuleName() and JoinRuleEvalFunction().
 */
static char *RuleNameArray[JOIN_RULE_LAST] = { 0 }; /* ordered join rule names */
static RuleEvalFunction RuleEvalFunctionArray[JOIN_RULE_LAST] = { 0 }; /* join rules */
53 
54 
/*
 * Local functions forward declarations: helpers that flatten the join tree,
 * followed by the routines that build candidate join orders, pick the best
 * one among them, and print it.
 */
static bool JoinExprListWalker(Node *node, List **joinList);
static bool ExtractLeftMostRangeTableIndex(Node *node, int *rangeTableIndex);
static List * JoinOrderForTable(TableEntry *firstTable, List *tableEntryList,
								List *joinClauseList);
static List * BestJoinOrder(List *candidateJoinOrders);
static List * FewestOfJoinRuleType(List *candidateJoinOrders, JoinRuleType ruleType);
static uint32 JoinRuleTypeCount(List *joinOrder, JoinRuleType ruleTypeToCount);
static List * LatestLargeDataTransfer(List *candidateJoinOrders);
static void PrintJoinOrderList(List *joinOrder);
static uint32 LargeDataTransferLocation(List *joinOrder);
static List * TableEntryListDifference(List *lhsTableList, List *rhsTableList);

/*
 * Local functions forward declarations for join evaluations. The rule functions
 * below (ReferenceJoin, CartesianProductReferenceJoin, LocalJoin,
 * SinglePartitionJoin, DualPartitionJoin and CartesianProduct) all share the
 * RuleEvalFunction signature, which lets JoinRuleEvalFunction() store them in
 * RuleEvalFunctionArray.
 */
static JoinOrderNode * EvaluateJoinRules(List *joinedTableList,
										 JoinOrderNode *currentJoinNode,
										 TableEntry *candidateTable,
										 List *joinClauseList, JoinType joinType);
static List * RangeTableIdList(List *tableList);
static RuleEvalFunction JoinRuleEvalFunction(JoinRuleType ruleType);
static char * JoinRuleName(JoinRuleType ruleType);
static JoinOrderNode * ReferenceJoin(JoinOrderNode *joinNode, TableEntry *candidateTable,
									 List *applicableJoinClauses, JoinType joinType);
static JoinOrderNode * CartesianProductReferenceJoin(JoinOrderNode *joinNode,
													 TableEntry *candidateTable,
													 List *applicableJoinClauses,
													 JoinType joinType);
static JoinOrderNode * LocalJoin(JoinOrderNode *joinNode, TableEntry *candidateTable,
								 List *applicableJoinClauses, JoinType joinType);
static bool JoinOnColumns(List *currentPartitionColumnList, Var *candidatePartitionColumn,
						  List *joinClauseList);
static JoinOrderNode * SinglePartitionJoin(JoinOrderNode *joinNode,
										   TableEntry *candidateTable,
										   List *applicableJoinClauses,
										   JoinType joinType);
static JoinOrderNode * DualPartitionJoin(JoinOrderNode *joinNode,
										 TableEntry *candidateTable,
										 List *applicableJoinClauses,
										 JoinType joinType);
static JoinOrderNode * CartesianProduct(JoinOrderNode *joinNode,
										TableEntry *candidateTable,
										List *applicableJoinClauses,
										JoinType joinType);
static JoinOrderNode * MakeJoinOrderNode(TableEntry *tableEntry,
										 JoinRuleType joinRuleType,
										 List *partitionColumnList, char partitionMethod,
										 TableEntry *anchorTable);
102 
103 
104 /*
105  * JoinExprList flattens the JoinExpr nodes in the FROM expression and translate implicit
106  * joins to inner joins. This function does not consider (right-)nested joins.
107  */
108 List *
JoinExprList(FromExpr * fromExpr)109 JoinExprList(FromExpr *fromExpr)
110 {
111 	List *joinList = NIL;
112 	List *fromList = fromExpr->fromlist;
113 	ListCell *fromCell = NULL;
114 
115 	foreach(fromCell, fromList)
116 	{
117 		Node *nextNode = (Node *) lfirst(fromCell);
118 
119 		if (joinList != NIL)
120 		{
121 			/* multiple nodes in from clause, add an explicit join between them */
122 			int nextRangeTableIndex = 0;
123 
124 			/* find the left most range table in this node */
125 			ExtractLeftMostRangeTableIndex((Node *) fromExpr, &nextRangeTableIndex);
126 
127 			RangeTblRef *nextRangeTableRef = makeNode(RangeTblRef);
128 			nextRangeTableRef->rtindex = nextRangeTableIndex;
129 
130 			/* join the previous node with nextRangeTableRef */
131 			JoinExpr *newJoinExpr = makeNode(JoinExpr);
132 			newJoinExpr->jointype = JOIN_INNER;
133 			newJoinExpr->rarg = (Node *) nextRangeTableRef;
134 			newJoinExpr->quals = NULL;
135 
136 			joinList = lappend(joinList, newJoinExpr);
137 		}
138 
139 		JoinExprListWalker(nextNode, &joinList);
140 	}
141 
142 	return joinList;
143 }
144 
145 
146 /*
147  * JoinExprListWalker the JoinExpr nodes in a join tree in the order in which joins are
148  * to be executed. If there are no joins then no elements are added to joinList.
149  */
150 static bool
JoinExprListWalker(Node * node,List ** joinList)151 JoinExprListWalker(Node *node, List **joinList)
152 {
153 	bool walkerResult = false;
154 
155 	if (node == NULL)
156 	{
157 		return false;
158 	}
159 
160 	if (IsA(node, JoinExpr))
161 	{
162 		JoinExpr *joinExpr = (JoinExpr *) node;
163 
164 		walkerResult = JoinExprListWalker(joinExpr->larg, joinList);
165 
166 		(*joinList) = lappend(*joinList, joinExpr);
167 	}
168 	else
169 	{
170 		walkerResult = expression_tree_walker(node, JoinExprListWalker,
171 											  joinList);
172 	}
173 
174 	return walkerResult;
175 }
176 
177 
178 /*
179  * ExtractLeftMostRangeTableIndex extracts the range table index of the left-most
180  * leaf in a join tree.
181  */
182 static bool
ExtractLeftMostRangeTableIndex(Node * node,int * rangeTableIndex)183 ExtractLeftMostRangeTableIndex(Node *node, int *rangeTableIndex)
184 {
185 	bool walkerResult = false;
186 
187 	Assert(node != NULL);
188 
189 	if (IsA(node, JoinExpr))
190 	{
191 		JoinExpr *joinExpr = (JoinExpr *) node;
192 
193 		walkerResult = ExtractLeftMostRangeTableIndex(joinExpr->larg, rangeTableIndex);
194 	}
195 	else if (IsA(node, RangeTblRef))
196 	{
197 		RangeTblRef *rangeTableRef = (RangeTblRef *) node;
198 
199 		*rangeTableIndex = rangeTableRef->rtindex;
200 		walkerResult = true;
201 	}
202 	else
203 	{
204 		walkerResult = expression_tree_walker(node, ExtractLeftMostRangeTableIndex,
205 											  rangeTableIndex);
206 	}
207 
208 	return walkerResult;
209 }
210 
211 
212 /*
213  * JoinOnColumns determines whether two columns are joined by a given join clause list.
214  */
215 static bool
JoinOnColumns(List * currentPartitionColumnList,Var * candidateColumn,List * joinClauseList)216 JoinOnColumns(List *currentPartitionColumnList, Var *candidateColumn,
217 			  List *joinClauseList)
218 {
219 	if (candidateColumn == NULL || list_length(currentPartitionColumnList) == 0)
220 	{
221 		/*
222 		 * LocalJoin can only be happening if we have both a current column and a target
223 		 * column, otherwise we are not joining two local tables
224 		 */
225 		return false;
226 	}
227 
228 	Var *currentColumn = NULL;
229 	foreach_ptr(currentColumn, currentPartitionColumnList)
230 	{
231 		Node *joinClause = NULL;
232 		foreach_ptr(joinClause, joinClauseList)
233 		{
234 			if (!NodeIsEqualsOpExpr(joinClause))
235 			{
236 				continue;
237 			}
238 			OpExpr *joinClauseOpExpr = castNode(OpExpr, joinClause);
239 			Var *leftColumn = LeftColumnOrNULL(joinClauseOpExpr);
240 			Var *rightColumn = RightColumnOrNULL(joinClauseOpExpr);
241 
242 			/*
243 			 * Check if both join columns and both partition key columns match, since the
244 			 * current and candidate column's can't be NULL we know they won't match if either
245 			 * of the columns resolved to NULL above.
246 			 */
247 			if (equal(leftColumn, currentColumn) &&
248 				equal(rightColumn, candidateColumn))
249 			{
250 				return true;
251 			}
252 			if (equal(leftColumn, candidateColumn) &&
253 				equal(rightColumn, currentColumn))
254 			{
255 				return true;
256 			}
257 		}
258 	}
259 
260 	return false;
261 }
262 
263 
264 /*
265  * NodeIsEqualsOpExpr checks if the node is an OpExpr, where the operator
266  * matches OperatorImplementsEquality.
267  */
268 bool
NodeIsEqualsOpExpr(Node * node)269 NodeIsEqualsOpExpr(Node *node)
270 {
271 	if (!IsA(node, OpExpr))
272 	{
273 		return false;
274 	}
275 	OpExpr *opExpr = castNode(OpExpr, node);
276 	return OperatorImplementsEquality(opExpr->opno);
277 }
278 
279 
280 /*
281  * JoinOrderList calculates the best join order and join rules that apply given
282  * the list of tables and join clauses. First, the function generates a set of
283  * candidate join orders, each with a different table as its first table. Then,
284  * the function chooses among these candidates the join order that transfers the
285  * least amount of data across the network, and returns this join order.
286  */
287 List *
JoinOrderList(List * tableEntryList,List * joinClauseList)288 JoinOrderList(List *tableEntryList, List *joinClauseList)
289 {
290 	List *candidateJoinOrderList = NIL;
291 	ListCell *tableEntryCell = NULL;
292 
293 	foreach(tableEntryCell, tableEntryList)
294 	{
295 		TableEntry *startingTable = (TableEntry *) lfirst(tableEntryCell);
296 
297 		/* each candidate join order starts with a different table */
298 		List *candidateJoinOrder = JoinOrderForTable(startingTable, tableEntryList,
299 													 joinClauseList);
300 
301 		if (candidateJoinOrder != NULL)
302 		{
303 			candidateJoinOrderList = lappend(candidateJoinOrderList, candidateJoinOrder);
304 		}
305 	}
306 
307 	if (list_length(candidateJoinOrderList) == 0)
308 	{
309 		/* there are no plans that we can create, time to error */
310 		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
311 						errmsg("complex joins are only supported when all distributed "
312 							   "tables are joined on their distribution columns with "
313 							   "equal operator")));
314 	}
315 
316 	List *bestJoinOrder = BestJoinOrder(candidateJoinOrderList);
317 
318 	/* if logging is enabled, print join order */
319 	if (LogMultiJoinOrder)
320 	{
321 		PrintJoinOrderList(bestJoinOrder);
322 	}
323 
324 	return bestJoinOrder;
325 }
326 
327 
328 /*
329  * JoinOrderForTable creates a join order whose first element is the given first
330  * table. To determine each subsequent element in the join order, the function
331  * then chooses the table that has the lowest ranking join rule, and with which
332  * it can join the table to the previous table in the join order. The function
333  * repeats this until it determines all elements in the join order list, and
334  * returns this list.
335  */
336 static List *
JoinOrderForTable(TableEntry * firstTable,List * tableEntryList,List * joinClauseList)337 JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClauseList)
338 {
339 	JoinRuleType firstJoinRule = JOIN_RULE_INVALID_FIRST;
340 	int joinedTableCount = 1;
341 	int totalTableCount = list_length(tableEntryList);
342 
343 	/* create join node for the first table */
344 	Oid firstRelationId = firstTable->relationId;
345 	uint32 firstTableId = firstTable->rangeTableId;
346 	Var *firstPartitionColumn = PartitionColumn(firstRelationId, firstTableId);
347 	char firstPartitionMethod = PartitionMethod(firstRelationId);
348 
349 	JoinOrderNode *firstJoinNode = MakeJoinOrderNode(firstTable, firstJoinRule,
350 													 list_make1(firstPartitionColumn),
351 													 firstPartitionMethod,
352 													 firstTable);
353 
354 	/* add first node to the join order */
355 	List *joinOrderList = list_make1(firstJoinNode);
356 	List *joinedTableList = list_make1(firstTable);
357 	JoinOrderNode *currentJoinNode = firstJoinNode;
358 
359 	/* loop until we join all remaining tables */
360 	while (joinedTableCount < totalTableCount)
361 	{
362 		ListCell *pendingTableCell = NULL;
363 		JoinOrderNode *nextJoinNode = NULL;
364 		JoinRuleType nextJoinRuleType = JOIN_RULE_LAST;
365 
366 		List *pendingTableList = TableEntryListDifference(tableEntryList,
367 														  joinedTableList);
368 
369 		/*
370 		 * Iterate over all pending tables, and find the next best table to
371 		 * join. The best table is the one whose join rule requires the least
372 		 * amount of data transfer.
373 		 */
374 		foreach(pendingTableCell, pendingTableList)
375 		{
376 			TableEntry *pendingTable = (TableEntry *) lfirst(pendingTableCell);
377 			JoinType joinType = JOIN_INNER;
378 
379 			/* evaluate all join rules for this pending table */
380 			JoinOrderNode *pendingJoinNode = EvaluateJoinRules(joinedTableList,
381 															   currentJoinNode,
382 															   pendingTable,
383 															   joinClauseList, joinType);
384 
385 			if (pendingJoinNode == NULL)
386 			{
387 				/* no join order could be generated, we try our next pending table */
388 				continue;
389 			}
390 
391 			/* if this rule is better than previous ones, keep it */
392 			JoinRuleType pendingJoinRuleType = pendingJoinNode->joinRuleType;
393 			if (pendingJoinRuleType < nextJoinRuleType)
394 			{
395 				nextJoinNode = pendingJoinNode;
396 				nextJoinRuleType = pendingJoinRuleType;
397 			}
398 		}
399 
400 		if (nextJoinNode == NULL)
401 		{
402 			/*
403 			 * There is no next join node found, this will repeat indefinitely hence we
404 			 * bail and let JoinOrderList try a new initial table
405 			 */
406 			return NULL;
407 		}
408 
409 		Assert(nextJoinNode != NULL);
410 		TableEntry *nextJoinedTable = nextJoinNode->tableEntry;
411 
412 		/* add next node to the join order */
413 		joinOrderList = lappend(joinOrderList, nextJoinNode);
414 		joinedTableList = lappend(joinedTableList, nextJoinedTable);
415 		currentJoinNode = nextJoinNode;
416 
417 		joinedTableCount++;
418 	}
419 
420 	return joinOrderList;
421 }
422 
423 
424 /*
425  * BestJoinOrder takes in a list of candidate join orders, and determines the
426  * best join order among these candidates. The function uses two heuristics for
427  * this. First, the function chooses join orders that have the fewest number of
428  * join operators that cause large data transfers. Second, the function chooses
429  * join orders where large data transfers occur later in the execution.
430  */
431 static List *
BestJoinOrder(List * candidateJoinOrders)432 BestJoinOrder(List *candidateJoinOrders)
433 {
434 	uint32 highestValidIndex = JOIN_RULE_LAST - 1;
435 	uint32 candidateCount PG_USED_FOR_ASSERTS_ONLY = 0;
436 
437 	/*
438 	 * We start with the highest ranking rule type (cartesian product), and walk
439 	 * over these rules in reverse order. For each rule type, we then keep join
440 	 * orders that only contain the fewest number of join rules of that type.
441 	 *
442 	 * For example, the algorithm chooses join orders like the following:
443 	 * (a) The algorithm prefers join orders with 2 cartesian products (CP) to
444 	 * those that have 3 or more, if there isn't a join order with fewer CPs.
445 	 * (b) Assuming that all join orders have the same number of CPs, the
446 	 * algorithm prefers join orders with 2 dual partitions (DP) to those that
447 	 * have 3 or more, if there isn't a join order with fewer DPs; and so
448 	 * forth.
449 	 */
450 	for (uint32 ruleTypeIndex = highestValidIndex; ruleTypeIndex > 0; ruleTypeIndex--)
451 	{
452 		JoinRuleType ruleType = (JoinRuleType) ruleTypeIndex;
453 
454 		candidateJoinOrders = FewestOfJoinRuleType(candidateJoinOrders, ruleType);
455 	}
456 
457 	/*
458 	 * If there is a tie, we pick candidate join orders where large data
459 	 * transfers happen at later stages of query execution. This results in more
460 	 * data being filtered via joins, selections, and projections earlier on.
461 	 */
462 	candidateJoinOrders = LatestLargeDataTransfer(candidateJoinOrders);
463 
464 	/* we should have at least one join order left after optimizations */
465 	candidateCount = list_length(candidateJoinOrders);
466 	Assert(candidateCount > 0);
467 
468 	/*
469 	 * If there still is a tie, we pick the join order whose relation appeared
470 	 * earliest in the query's range table entry list.
471 	 */
472 	List *bestJoinOrder = (List *) linitial(candidateJoinOrders);
473 
474 	return bestJoinOrder;
475 }
476 
477 
478 /*
479  * FewestOfJoinRuleType finds join orders that have the fewest number of times
480  * the given join rule occurs in the candidate join orders, and filters all
481  * other join orders. For example, if four candidate join orders have a join
482  * rule appearing 3, 5, 3, and 6 times, only two join orders that have the join
483  * rule appearing 3 times will be returned.
484  */
485 static List *
FewestOfJoinRuleType(List * candidateJoinOrders,JoinRuleType ruleType)486 FewestOfJoinRuleType(List *candidateJoinOrders, JoinRuleType ruleType)
487 {
488 	List *fewestJoinOrders = NULL;
489 	uint32 fewestRuleCount = INT_MAX;
490 	ListCell *joinOrderCell = NULL;
491 
492 	foreach(joinOrderCell, candidateJoinOrders)
493 	{
494 		List *joinOrder = (List *) lfirst(joinOrderCell);
495 		uint32 ruleTypeCount = JoinRuleTypeCount(joinOrder, ruleType);
496 
497 		if (ruleTypeCount == fewestRuleCount)
498 		{
499 			fewestJoinOrders = lappend(fewestJoinOrders, joinOrder);
500 		}
501 		else if (ruleTypeCount < fewestRuleCount)
502 		{
503 			fewestJoinOrders = list_make1(joinOrder);
504 			fewestRuleCount = ruleTypeCount;
505 		}
506 	}
507 
508 	return fewestJoinOrders;
509 }
510 
511 
512 /* Counts the number of times the given join rule occurs in the join order. */
513 static uint32
JoinRuleTypeCount(List * joinOrder,JoinRuleType ruleTypeToCount)514 JoinRuleTypeCount(List *joinOrder, JoinRuleType ruleTypeToCount)
515 {
516 	uint32 ruleTypeCount = 0;
517 	ListCell *joinOrderNodeCell = NULL;
518 
519 	foreach(joinOrderNodeCell, joinOrder)
520 	{
521 		JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderNodeCell);
522 
523 		JoinRuleType ruleType = joinOrderNode->joinRuleType;
524 		if (ruleType == ruleTypeToCount)
525 		{
526 			ruleTypeCount++;
527 		}
528 	}
529 
530 	return ruleTypeCount;
531 }
532 
533 
534 /*
535  * LatestLargeDataTransfer finds and returns join orders where a large data
536  * transfer join rule occurs as late as possible in the join order. Late large
537  * data transfers result in more data being filtered before data gets shuffled
538  * in the network.
539  */
540 static List *
LatestLargeDataTransfer(List * candidateJoinOrders)541 LatestLargeDataTransfer(List *candidateJoinOrders)
542 {
543 	List *latestJoinOrders = NIL;
544 	uint32 latestJoinLocation = 0;
545 	ListCell *joinOrderCell = NULL;
546 
547 	foreach(joinOrderCell, candidateJoinOrders)
548 	{
549 		List *joinOrder = (List *) lfirst(joinOrderCell);
550 		uint32 joinRuleLocation = LargeDataTransferLocation(joinOrder);
551 
552 		if (joinRuleLocation == latestJoinLocation)
553 		{
554 			latestJoinOrders = lappend(latestJoinOrders, joinOrder);
555 		}
556 		else if (joinRuleLocation > latestJoinLocation)
557 		{
558 			latestJoinOrders = list_make1(joinOrder);
559 			latestJoinLocation = joinRuleLocation;
560 		}
561 	}
562 
563 	return latestJoinOrders;
564 }
565 
566 
567 /*
568  * LargeDataTransferLocation finds the first location of a large data transfer
569  * join rule, and returns that location. If the join order does not have any
570  * large data transfer rules, the function returns one location past the end of
571  * the join order list.
572  */
573 static uint32
LargeDataTransferLocation(List * joinOrder)574 LargeDataTransferLocation(List *joinOrder)
575 {
576 	uint32 joinRuleLocation = 0;
577 	ListCell *joinOrderNodeCell = NULL;
578 
579 	foreach(joinOrderNodeCell, joinOrder)
580 	{
581 		JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderNodeCell);
582 		JoinRuleType joinRuleType = joinOrderNode->joinRuleType;
583 
584 		/* we consider the following join rules to cause large data transfers */
585 		if (joinRuleType == SINGLE_HASH_PARTITION_JOIN ||
586 			joinRuleType == SINGLE_RANGE_PARTITION_JOIN ||
587 			joinRuleType == DUAL_PARTITION_JOIN ||
588 			joinRuleType == CARTESIAN_PRODUCT)
589 		{
590 			break;
591 		}
592 
593 		joinRuleLocation++;
594 	}
595 
596 	return joinRuleLocation;
597 }
598 
599 
600 /* Prints the join order list and join rules for debugging purposes. */
601 static void
PrintJoinOrderList(List * joinOrder)602 PrintJoinOrderList(List *joinOrder)
603 {
604 	StringInfo printBuffer = makeStringInfo();
605 	ListCell *joinOrderNodeCell = NULL;
606 	bool firstJoinNode = true;
607 
608 	foreach(joinOrderNodeCell, joinOrder)
609 	{
610 		JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderNodeCell);
611 		Oid relationId = joinOrderNode->tableEntry->relationId;
612 		char *relationName = get_rel_name(relationId);
613 
614 		if (firstJoinNode)
615 		{
616 			appendStringInfo(printBuffer, "[ \"%s\" ]", relationName);
617 			firstJoinNode = false;
618 		}
619 		else
620 		{
621 			JoinRuleType ruleType = (JoinRuleType) joinOrderNode->joinRuleType;
622 			char *ruleName = JoinRuleName(ruleType);
623 
624 			appendStringInfo(printBuffer, "[ %s ", ruleName);
625 			appendStringInfo(printBuffer, "\"%s\" ]", relationName);
626 		}
627 	}
628 
629 	ereport(LOG, (errmsg("join order: %s",
630 						 ApplyLogRedaction(printBuffer->data))));
631 }
632 
633 
634 /*
635  * TableEntryListDifference returns a list containing table entries that are in
636  * the left-hand side table list, but not in the right-hand side table list.
637  */
638 static List *
TableEntryListDifference(List * lhsTableList,List * rhsTableList)639 TableEntryListDifference(List *lhsTableList, List *rhsTableList)
640 {
641 	List *tableListDifference = NIL;
642 	ListCell *lhsTableCell = NULL;
643 
644 	foreach(lhsTableCell, lhsTableList)
645 	{
646 		TableEntry *lhsTableEntry = (TableEntry *) lfirst(lhsTableCell);
647 		ListCell *rhsTableCell = NULL;
648 		bool lhsTableEntryExists = false;
649 
650 		foreach(rhsTableCell, rhsTableList)
651 		{
652 			TableEntry *rhsTableEntry = (TableEntry *) lfirst(rhsTableCell);
653 
654 			if ((lhsTableEntry->relationId == rhsTableEntry->relationId) &&
655 				(lhsTableEntry->rangeTableId == rhsTableEntry->rangeTableId))
656 			{
657 				lhsTableEntryExists = true;
658 			}
659 		}
660 
661 		if (!lhsTableEntryExists)
662 		{
663 			tableListDifference = lappend(tableListDifference, lhsTableEntry);
664 		}
665 	}
666 
667 	return tableListDifference;
668 }
669 
670 
671 /*
672  * EvaluateJoinRules takes in a list of already joined tables and a candidate
673  * next table, evaluates different join rules between the two tables, and finds
674  * the best join rule that applies. The function returns the applicable join
675  * order node which includes the join rule and the partition information.
676  */
677 static JoinOrderNode *
EvaluateJoinRules(List * joinedTableList,JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * joinClauseList,JoinType joinType)678 EvaluateJoinRules(List *joinedTableList, JoinOrderNode *currentJoinNode,
679 				  TableEntry *candidateTable, List *joinClauseList,
680 				  JoinType joinType)
681 {
682 	JoinOrderNode *nextJoinNode = NULL;
683 	uint32 lowestValidIndex = JOIN_RULE_INVALID_FIRST + 1;
684 	uint32 highestValidIndex = JOIN_RULE_LAST - 1;
685 
686 	/*
687 	 * We first find all applicable join clauses between already joined tables
688 	 * and the candidate table.
689 	 */
690 	List *joinedTableIdList = RangeTableIdList(joinedTableList);
691 	uint32 candidateTableId = candidateTable->rangeTableId;
692 	List *applicableJoinClauses = ApplicableJoinClauses(joinedTableIdList,
693 														candidateTableId,
694 														joinClauseList);
695 
696 	/* we then evaluate all join rules in order */
697 	for (uint32 ruleIndex = lowestValidIndex; ruleIndex <= highestValidIndex; ruleIndex++)
698 	{
699 		JoinRuleType ruleType = (JoinRuleType) ruleIndex;
700 		RuleEvalFunction ruleEvalFunction = JoinRuleEvalFunction(ruleType);
701 
702 		nextJoinNode = (*ruleEvalFunction)(currentJoinNode,
703 										   candidateTable,
704 										   applicableJoinClauses,
705 										   joinType);
706 
707 		/* break after finding the first join rule that applies */
708 		if (nextJoinNode != NULL)
709 		{
710 			break;
711 		}
712 	}
713 
714 	if (nextJoinNode == NULL)
715 	{
716 		return NULL;
717 	}
718 
719 	Assert(nextJoinNode != NULL);
720 	nextJoinNode->joinType = joinType;
721 	nextJoinNode->joinClauseList = applicableJoinClauses;
722 	return nextJoinNode;
723 }
724 
725 
726 /* Extracts range table identifiers from the given table list, and returns them. */
727 static List *
RangeTableIdList(List * tableList)728 RangeTableIdList(List *tableList)
729 {
730 	List *rangeTableIdList = NIL;
731 	ListCell *tableCell = NULL;
732 
733 	foreach(tableCell, tableList)
734 	{
735 		TableEntry *tableEntry = (TableEntry *) lfirst(tableCell);
736 
737 		uint32 rangeTableId = tableEntry->rangeTableId;
738 		rangeTableIdList = lappend_int(rangeTableIdList, rangeTableId);
739 	}
740 
741 	return rangeTableIdList;
742 }
743 
744 
745 /*
746  * JoinRuleEvalFunction returns a function pointer for the rule evaluation
747  * function; this rule evaluation function corresponds to the given rule type.
748  * The function also initializes the rule evaluation function array in a static
749  * code block, if the array has not been initialized.
750  */
751 static RuleEvalFunction
JoinRuleEvalFunction(JoinRuleType ruleType)752 JoinRuleEvalFunction(JoinRuleType ruleType)
753 {
754 	static bool ruleEvalFunctionsInitialized = false;
755 
756 	if (!ruleEvalFunctionsInitialized)
757 	{
758 		RuleEvalFunctionArray[REFERENCE_JOIN] = &ReferenceJoin;
759 		RuleEvalFunctionArray[LOCAL_PARTITION_JOIN] = &LocalJoin;
760 		RuleEvalFunctionArray[SINGLE_RANGE_PARTITION_JOIN] = &SinglePartitionJoin;
761 		RuleEvalFunctionArray[SINGLE_HASH_PARTITION_JOIN] = &SinglePartitionJoin;
762 		RuleEvalFunctionArray[DUAL_PARTITION_JOIN] = &DualPartitionJoin;
763 		RuleEvalFunctionArray[CARTESIAN_PRODUCT_REFERENCE_JOIN] =
764 			&CartesianProductReferenceJoin;
765 		RuleEvalFunctionArray[CARTESIAN_PRODUCT] = &CartesianProduct;
766 
767 		ruleEvalFunctionsInitialized = true;
768 	}
769 
770 	RuleEvalFunction ruleEvalFunction = RuleEvalFunctionArray[ruleType];
771 	Assert(ruleEvalFunction != NULL);
772 
773 	return ruleEvalFunction;
774 }
775 
776 
777 /* Returns a string name for the given join rule type. */
778 static char *
JoinRuleName(JoinRuleType ruleType)779 JoinRuleName(JoinRuleType ruleType)
780 {
781 	static bool ruleNamesInitialized = false;
782 
783 	if (!ruleNamesInitialized)
784 	{
785 		/* use strdup() to be independent of memory contexts */
786 		RuleNameArray[REFERENCE_JOIN] = strdup("reference join");
787 		RuleNameArray[LOCAL_PARTITION_JOIN] = strdup("local partition join");
788 		RuleNameArray[SINGLE_HASH_PARTITION_JOIN] =
789 			strdup("single hash partition join");
790 		RuleNameArray[SINGLE_RANGE_PARTITION_JOIN] =
791 			strdup("single range partition join");
792 		RuleNameArray[DUAL_PARTITION_JOIN] = strdup("dual partition join");
793 		RuleNameArray[CARTESIAN_PRODUCT_REFERENCE_JOIN] = strdup(
794 			"cartesian product reference join");
795 		RuleNameArray[CARTESIAN_PRODUCT] = strdup("cartesian product");
796 
797 		ruleNamesInitialized = true;
798 	}
799 
800 	char *ruleName = RuleNameArray[ruleType];
801 	Assert(ruleName != NULL);
802 
803 	return ruleName;
804 }
805 
806 
807 /*
808  * ReferenceJoin evaluates if the candidate table is a reference table for inner,
809  * left and anti join. For right join, current join node must be represented by
810  * a reference table. For full join, both of them must be a reference table.
811  */
812 static JoinOrderNode *
ReferenceJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)813 ReferenceJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
814 			  List *applicableJoinClauses, JoinType joinType)
815 {
816 	int applicableJoinCount = list_length(applicableJoinClauses);
817 	if (applicableJoinCount <= 0)
818 	{
819 		return NULL;
820 	}
821 
822 	bool leftIsReferenceTable = IsCitusTableType(
823 		currentJoinNode->tableEntry->relationId,
824 		REFERENCE_TABLE);
825 	bool rightIsReferenceTable = IsCitusTableType(candidateTable->relationId,
826 												  REFERENCE_TABLE);
827 	if (!IsSupportedReferenceJoin(joinType, leftIsReferenceTable, rightIsReferenceTable))
828 	{
829 		return NULL;
830 	}
831 	return MakeJoinOrderNode(candidateTable, REFERENCE_JOIN,
832 							 currentJoinNode->partitionColumnList,
833 							 currentJoinNode->partitionMethod,
834 							 currentJoinNode->anchorTable);
835 }
836 
837 
838 /*
839  * IsSupportedReferenceJoin checks if with this join type we can safely do a simple join
840  * on the reference table on all the workers.
841  */
842 bool
IsSupportedReferenceJoin(JoinType joinType,bool leftIsReferenceTable,bool rightIsReferenceTable)843 IsSupportedReferenceJoin(JoinType joinType, bool leftIsReferenceTable,
844 						 bool rightIsReferenceTable)
845 {
846 	if ((joinType == JOIN_INNER || joinType == JOIN_LEFT || joinType == JOIN_ANTI) &&
847 		rightIsReferenceTable)
848 	{
849 		return true;
850 	}
851 	else if ((joinType == JOIN_RIGHT) &&
852 			 leftIsReferenceTable)
853 	{
854 		return true;
855 	}
856 	else if (joinType == JOIN_FULL && leftIsReferenceTable && rightIsReferenceTable)
857 	{
858 		return true;
859 	}
860 	return false;
861 }
862 
863 
864 /*
865  * ReferenceJoin evaluates if the candidate table is a reference table for inner,
866  * left and anti join. For right join, current join node must be represented by
867  * a reference table. For full join, both of them must be a reference table.
868  */
869 static JoinOrderNode *
CartesianProductReferenceJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)870 CartesianProductReferenceJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
871 							  List *applicableJoinClauses, JoinType joinType)
872 {
873 	bool leftIsReferenceTable = IsCitusTableType(
874 		currentJoinNode->tableEntry->relationId,
875 		REFERENCE_TABLE);
876 	bool rightIsReferenceTable = IsCitusTableType(candidateTable->relationId,
877 												  REFERENCE_TABLE);
878 
879 	if (!IsSupportedReferenceJoin(joinType, leftIsReferenceTable, rightIsReferenceTable))
880 	{
881 		return NULL;
882 	}
883 	return MakeJoinOrderNode(candidateTable, CARTESIAN_PRODUCT_REFERENCE_JOIN,
884 							 currentJoinNode->partitionColumnList,
885 							 currentJoinNode->partitionMethod,
886 							 currentJoinNode->anchorTable);
887 }
888 
889 
890 /*
891  * LocalJoin takes the current partition key column and the candidate table's
892  * partition key column and the partition method for each table. The function
893  * then evaluates if tables in the join order and the candidate table can be
894  * joined locally, without any data transfers. If they can, the function returns
895  * a join order node for a local join. Otherwise, the function returns null.
896  *
897  * Anchor table is used to decide whether the JoinOrderNode can be joined
898  * locally with the candidate table. That table is updated by each join type
899  * applied over JoinOrderNode. Note that, we lost the anchor table after
900  * dual partitioning and cartesian product.
901  */
902 static JoinOrderNode *
LocalJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)903 LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
904 		  List *applicableJoinClauses, JoinType joinType)
905 {
906 	Oid relationId = candidateTable->relationId;
907 	uint32 tableId = candidateTable->rangeTableId;
908 	Var *candidatePartitionColumn = PartitionColumn(relationId, tableId);
909 	List *currentPartitionColumnList = currentJoinNode->partitionColumnList;
910 	char candidatePartitionMethod = PartitionMethod(relationId);
911 	char currentPartitionMethod = currentJoinNode->partitionMethod;
912 	TableEntry *currentAnchorTable = currentJoinNode->anchorTable;
913 
914 	/*
915 	 * If we previously dual-hash re-partitioned the tables for a join or made cartesian
916 	 * product, there is no anchor table anymore. In that case we don't allow local join.
917 	 */
918 	if (currentAnchorTable == NULL)
919 	{
920 		return NULL;
921 	}
922 
923 	/* the partition method should be the same for a local join */
924 	if (currentPartitionMethod != candidatePartitionMethod)
925 	{
926 		return NULL;
927 	}
928 
929 	bool joinOnPartitionColumns = JoinOnColumns(currentPartitionColumnList,
930 												candidatePartitionColumn,
931 												applicableJoinClauses);
932 	if (!joinOnPartitionColumns)
933 	{
934 		return NULL;
935 	}
936 
937 	/* shard interval lists must have 1-1 matching for local joins */
938 	bool coPartitionedTables = CoPartitionedTables(currentAnchorTable->relationId,
939 												   relationId);
940 
941 	if (!coPartitionedTables)
942 	{
943 		return NULL;
944 	}
945 
946 	/*
947 	 * Since we are applying a local join to the candidate table we need to keep track of
948 	 * the partition column of the candidate table on the MultiJoinNode. This will allow
949 	 * subsequent joins colocated with this candidate table to correctly be recognized as
950 	 * a local join as well.
951 	 */
952 	currentPartitionColumnList = list_append_unique(currentPartitionColumnList,
953 													candidatePartitionColumn);
954 
955 	JoinOrderNode *nextJoinNode = MakeJoinOrderNode(candidateTable, LOCAL_PARTITION_JOIN,
956 													currentPartitionColumnList,
957 													currentPartitionMethod,
958 													currentAnchorTable);
959 
960 
961 	return nextJoinNode;
962 }
963 
964 
965 /*
966  * SinglePartitionJoin takes the current and the candidate table's partition keys
967  * and methods. The function then evaluates if either "tables in the join order"
968  * or the candidate table is already partitioned on a join column. If they are,
969  * the function returns a join order node with the already partitioned column as
970  * the next partition key. Otherwise, the function returns null.
971  */
972 static JoinOrderNode *
SinglePartitionJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)973 SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
974 					List *applicableJoinClauses, JoinType joinType)
975 {
976 	List *currentPartitionColumnList = currentJoinNode->partitionColumnList;
977 	char currentPartitionMethod = currentJoinNode->partitionMethod;
978 	TableEntry *currentAnchorTable = currentJoinNode->anchorTable;
979 	JoinRuleType currentJoinRuleType = currentJoinNode->joinRuleType;
980 
981 
982 	Oid relationId = candidateTable->relationId;
983 	uint32 tableId = candidateTable->rangeTableId;
984 	Var *candidatePartitionColumn = PartitionColumn(relationId, tableId);
985 	char candidatePartitionMethod = PartitionMethod(relationId);
986 
987 	/* outer joins are not supported yet */
988 	if (IS_OUTER_JOIN(joinType))
989 	{
990 		return NULL;
991 	}
992 
993 	/*
994 	 * If we previously dual-hash re-partitioned the tables for a join or made
995 	 * cartesian product, we currently don't allow a single-repartition join.
996 	 */
997 	if (currentJoinRuleType == DUAL_PARTITION_JOIN ||
998 		currentJoinRuleType == CARTESIAN_PRODUCT)
999 	{
1000 		return NULL;
1001 	}
1002 
1003 	OpExpr *joinClause =
1004 		SinglePartitionJoinClause(currentPartitionColumnList, applicableJoinClauses);
1005 	if (joinClause != NULL)
1006 	{
1007 		if (currentPartitionMethod == DISTRIBUTE_BY_HASH)
1008 		{
1009 			/*
1010 			 * Single hash repartitioning may perform worse than dual hash
1011 			 * repartitioning. Thus, we control it via a guc.
1012 			 */
1013 			if (!EnableSingleHashRepartitioning)
1014 			{
1015 				return NULL;
1016 			}
1017 
1018 			return MakeJoinOrderNode(candidateTable, SINGLE_HASH_PARTITION_JOIN,
1019 									 currentPartitionColumnList,
1020 									 currentPartitionMethod,
1021 									 currentAnchorTable);
1022 		}
1023 		else
1024 		{
1025 			return MakeJoinOrderNode(candidateTable, SINGLE_RANGE_PARTITION_JOIN,
1026 									 currentPartitionColumnList,
1027 									 currentPartitionMethod,
1028 									 currentAnchorTable);
1029 		}
1030 	}
1031 
1032 	/* evaluate re-partitioning the current table only if the rule didn't apply above */
1033 	if (candidatePartitionMethod != DISTRIBUTE_BY_NONE)
1034 	{
1035 		/*
1036 		 * Create a new unique list (set) with the partition column of the candidate table
1037 		 * to check if a single repartition join will work for this table. When it works
1038 		 * the set is retained on the MultiJoinNode for later local join verification.
1039 		 */
1040 		List *candidatePartitionColumnList = list_make1(candidatePartitionColumn);
1041 		joinClause = SinglePartitionJoinClause(candidatePartitionColumnList,
1042 											   applicableJoinClauses);
1043 		if (joinClause != NULL)
1044 		{
1045 			if (candidatePartitionMethod == DISTRIBUTE_BY_HASH)
1046 			{
1047 				/*
1048 				 * Single hash repartitioning may perform worse than dual hash
1049 				 * repartitioning. Thus, we control it via a guc.
1050 				 */
1051 				if (!EnableSingleHashRepartitioning)
1052 				{
1053 					return NULL;
1054 				}
1055 
1056 				return MakeJoinOrderNode(candidateTable,
1057 										 SINGLE_HASH_PARTITION_JOIN,
1058 										 candidatePartitionColumnList,
1059 										 candidatePartitionMethod,
1060 										 candidateTable);
1061 			}
1062 			else
1063 			{
1064 				return MakeJoinOrderNode(candidateTable,
1065 										 SINGLE_RANGE_PARTITION_JOIN,
1066 										 candidatePartitionColumnList,
1067 										 candidatePartitionMethod,
1068 										 candidateTable);
1069 			}
1070 		}
1071 	}
1072 
1073 	return NULL;
1074 }
1075 
1076 
1077 /*
1078  * SinglePartitionJoinClause walks over the applicable join clause list, and
1079  * finds an applicable join clause for the given partition column. If no such
1080  * clause exists, the function returns NULL.
1081  */
1082 OpExpr *
SinglePartitionJoinClause(List * partitionColumnList,List * applicableJoinClauses)1083 SinglePartitionJoinClause(List *partitionColumnList, List *applicableJoinClauses)
1084 {
1085 	if (list_length(partitionColumnList) == 0)
1086 	{
1087 		return NULL;
1088 	}
1089 
1090 	Var *partitionColumn = NULL;
1091 	foreach_ptr(partitionColumn, partitionColumnList)
1092 	{
1093 		Node *applicableJoinClause = NULL;
1094 		foreach_ptr(applicableJoinClause, applicableJoinClauses)
1095 		{
1096 			if (!NodeIsEqualsOpExpr(applicableJoinClause))
1097 			{
1098 				continue;
1099 			}
1100 			OpExpr *applicableJoinOpExpr = castNode(OpExpr, applicableJoinClause);
1101 			Var *leftColumn = LeftColumnOrNULL(applicableJoinOpExpr);
1102 			Var *rightColumn = RightColumnOrNULL(applicableJoinOpExpr);
1103 			if (leftColumn == NULL || rightColumn == NULL)
1104 			{
1105 				/* not a simple partition column join */
1106 				continue;
1107 			}
1108 
1109 
1110 			/*
1111 			 * We first check if partition column matches either of the join columns
1112 			 * and if it does, we then check if the join column types match. If the
1113 			 * types are different, we will use different hash functions for the two
1114 			 * column types, and will incorrectly repartition the data.
1115 			 */
1116 			if (equal(leftColumn, partitionColumn) || equal(rightColumn, partitionColumn))
1117 			{
1118 				if (leftColumn->vartype == rightColumn->vartype)
1119 				{
1120 					return applicableJoinOpExpr;
1121 				}
1122 				else
1123 				{
1124 					ereport(DEBUG1, (errmsg("single partition column types do not "
1125 											"match")));
1126 				}
1127 			}
1128 		}
1129 	}
1130 
1131 	return NULL;
1132 }
1133 
1134 
1135 /*
1136  * DualPartitionJoin evaluates if a join clause exists between "tables in the
1137  * join order" and the candidate table. If such a clause exists, both tables can
1138  * be repartitioned on the join column; and the function returns a join order
1139  * node with the join column as the next partition key. Otherwise, the function
1140  * returns null.
1141  */
1142 static JoinOrderNode *
DualPartitionJoin(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)1143 DualPartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
1144 				  List *applicableJoinClauses, JoinType joinType)
1145 {
1146 	OpExpr *joinClause = DualPartitionJoinClause(applicableJoinClauses);
1147 	if (joinClause)
1148 	{
1149 		/* because of the dual partition, anchor table and partition column get lost */
1150 		return MakeJoinOrderNode(candidateTable,
1151 								 DUAL_PARTITION_JOIN,
1152 								 NIL,
1153 								 REDISTRIBUTE_BY_HASH,
1154 								 NULL);
1155 	}
1156 
1157 	return NULL;
1158 }
1159 
1160 
1161 /*
1162  * DualPartitionJoinClause walks over the applicable join clause list, and finds
1163  * an applicable join clause for dual re-partitioning. If no such clause exists,
1164  * the function returns NULL.
1165  */
1166 OpExpr *
DualPartitionJoinClause(List * applicableJoinClauses)1167 DualPartitionJoinClause(List *applicableJoinClauses)
1168 {
1169 	Node *applicableJoinClause = NULL;
1170 	foreach_ptr(applicableJoinClause, applicableJoinClauses)
1171 	{
1172 		if (!NodeIsEqualsOpExpr(applicableJoinClause))
1173 		{
1174 			continue;
1175 		}
1176 		OpExpr *applicableJoinOpExpr = castNode(OpExpr, applicableJoinClause);
1177 		Var *leftColumn = LeftColumnOrNULL(applicableJoinOpExpr);
1178 		Var *rightColumn = RightColumnOrNULL(applicableJoinOpExpr);
1179 
1180 		if (leftColumn == NULL || rightColumn == NULL)
1181 		{
1182 			continue;
1183 		}
1184 
1185 		/* we only need to check that the join column types match */
1186 		if (leftColumn->vartype == rightColumn->vartype)
1187 		{
1188 			return applicableJoinOpExpr;
1189 		}
1190 		else
1191 		{
1192 			ereport(DEBUG1, (errmsg("dual partition column types do not match")));
1193 		}
1194 	}
1195 
1196 	return NULL;
1197 }
1198 
1199 
1200 /*
1201  * CartesianProduct always evaluates to true since all tables can be combined
1202  * using a cartesian product operator. This function acts as a catch-all rule,
1203  * in case none of the join rules apply.
1204  */
1205 static JoinOrderNode *
CartesianProduct(JoinOrderNode * currentJoinNode,TableEntry * candidateTable,List * applicableJoinClauses,JoinType joinType)1206 CartesianProduct(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
1207 				 List *applicableJoinClauses, JoinType joinType)
1208 {
1209 	if (list_length(applicableJoinClauses) == 0)
1210 	{
1211 		/* Because of the cartesian product, anchor table information got lost */
1212 		return MakeJoinOrderNode(candidateTable, CARTESIAN_PRODUCT,
1213 								 currentJoinNode->partitionColumnList,
1214 								 currentJoinNode->partitionMethod,
1215 								 NULL);
1216 	}
1217 
1218 	return NULL;
1219 }
1220 
1221 
1222 /* Constructs and returns a join-order node with the given arguments */
1223 JoinOrderNode *
MakeJoinOrderNode(TableEntry * tableEntry,JoinRuleType joinRuleType,List * partitionColumnList,char partitionMethod,TableEntry * anchorTable)1224 MakeJoinOrderNode(TableEntry *tableEntry, JoinRuleType joinRuleType,
1225 				  List *partitionColumnList, char partitionMethod,
1226 				  TableEntry *anchorTable)
1227 {
1228 	JoinOrderNode *joinOrderNode = palloc0(sizeof(JoinOrderNode));
1229 	joinOrderNode->tableEntry = tableEntry;
1230 	joinOrderNode->joinRuleType = joinRuleType;
1231 	joinOrderNode->joinType = JOIN_INNER;
1232 	joinOrderNode->partitionColumnList = partitionColumnList;
1233 	joinOrderNode->partitionMethod = partitionMethod;
1234 	joinOrderNode->joinClauseList = NIL;
1235 	joinOrderNode->anchorTable = anchorTable;
1236 
1237 	return joinOrderNode;
1238 }
1239 
1240 
1241 /*
1242  * IsApplicableJoinClause tests if the current joinClause is applicable to the join at
1243  * hand.
1244  *
1245  * Given a list of left hand tables and a candidate right hand table the join clause is
1246  * valid if atleast 1 column is from the right hand table AND all columns can be found
1247  * in either the list of tables on the left *or* in the right hand table.
1248  */
1249 bool
IsApplicableJoinClause(List * leftTableIdList,uint32 rightTableId,Node * joinClause)1250 IsApplicableJoinClause(List *leftTableIdList, uint32 rightTableId, Node *joinClause)
1251 {
1252 	List *varList = pull_var_clause_default(joinClause);
1253 	Var *var = NULL;
1254 	bool joinContainsRightTable = false;
1255 	foreach_ptr(var, varList)
1256 	{
1257 		uint32 columnTableId = var->varno;
1258 		if (rightTableId == columnTableId)
1259 		{
1260 			joinContainsRightTable = true;
1261 		}
1262 		else if (!list_member_int(leftTableIdList, columnTableId))
1263 		{
1264 			/*
1265 			 * We couldn't find this column either on the right hand side (first if
1266 			 * statement), nor in the list on the left. This join clause involves a table
1267 			 * not yet available during the candidate join.
1268 			 */
1269 			return false;
1270 		}
1271 	}
1272 
1273 	/*
1274 	 * All columns referenced in this clause are available during this join, now the join
1275 	 * is applicable if we found our candidate table as well
1276 	 */
1277 	return joinContainsRightTable;
1278 }
1279 
1280 
1281 /*
1282  * ApplicableJoinClauses finds all join clauses that apply between the given
1283  * left table list and the right table, and returns these found join clauses.
1284  */
1285 List *
ApplicableJoinClauses(List * leftTableIdList,uint32 rightTableId,List * joinClauseList)1286 ApplicableJoinClauses(List *leftTableIdList, uint32 rightTableId, List *joinClauseList)
1287 {
1288 	List *applicableJoinClauses = NIL;
1289 
1290 	/* make sure joinClauseList contains only join clauses */
1291 	joinClauseList = JoinClauseList(joinClauseList);
1292 
1293 	Node *joinClause = NULL;
1294 	foreach_ptr(joinClause, joinClauseList)
1295 	{
1296 		if (IsApplicableJoinClause(leftTableIdList, rightTableId, joinClause))
1297 		{
1298 			applicableJoinClauses = lappend(applicableJoinClauses, joinClause);
1299 		}
1300 	}
1301 
1302 	return applicableJoinClauses;
1303 }
1304 
1305 
1306 /*
1307  * Returns the left column only when directly referenced in the given join clause,
1308  * otherwise NULL is returned.
1309  */
1310 Var *
LeftColumnOrNULL(OpExpr * joinClause)1311 LeftColumnOrNULL(OpExpr *joinClause)
1312 {
1313 	List *argumentList = joinClause->args;
1314 	Node *leftArgument = (Node *) linitial(argumentList);
1315 
1316 	leftArgument = strip_implicit_coercions(leftArgument);
1317 	if (!IsA(leftArgument, Var))
1318 	{
1319 		return NULL;
1320 	}
1321 	return castNode(Var, leftArgument);
1322 }
1323 
1324 
1325 /*
1326  * Returns the right column only when directly referenced in the given join clause,
1327  * otherwise NULL is returned.
1328  * */
1329 Var *
RightColumnOrNULL(OpExpr * joinClause)1330 RightColumnOrNULL(OpExpr *joinClause)
1331 {
1332 	List *argumentList = joinClause->args;
1333 	Node *rightArgument = (Node *) lsecond(argumentList);
1334 
1335 	rightArgument = strip_implicit_coercions(rightArgument);
1336 	if (!IsA(rightArgument, Var))
1337 	{
1338 		return NULL;
1339 	}
1340 	return castNode(Var, rightArgument);
1341 }
1342 
1343 
1344 /*
1345  * PartitionColumn builds the partition column for the given relation, and sets
1346  * the partition column's range table references to the given table identifier.
1347  *
1348  * Note that reference tables do not have partition column. Thus, this function
1349  * returns NULL when called for reference tables.
1350  */
1351 Var *
PartitionColumn(Oid relationId,uint32 rangeTableId)1352 PartitionColumn(Oid relationId, uint32 rangeTableId)
1353 {
1354 	Var *partitionKey = DistPartitionKey(relationId);
1355 	Var *partitionColumn = NULL;
1356 
1357 	/* short circuit for reference tables */
1358 	if (partitionKey == NULL)
1359 	{
1360 		return partitionColumn;
1361 	}
1362 
1363 	partitionColumn = partitionKey;
1364 	partitionColumn->varno = rangeTableId;
1365 	partitionColumn->varnosyn = rangeTableId;
1366 
1367 	return partitionColumn;
1368 }
1369 
1370 
1371 /*
1372  * DistPartitionKey returns the partition key column for the given relation. Note
1373  * that in the context of distributed join and query planning, the callers of
1374  * this function *must* set the partition key column's range table reference
1375  * (varno) to match the table's location in the query range table list.
1376  *
1377  * Note that reference tables do not have partition column. Thus, this function
1378  * returns NULL when called for reference tables.
1379  */
1380 Var *
DistPartitionKey(Oid relationId)1381 DistPartitionKey(Oid relationId)
1382 {
1383 	CitusTableCacheEntry *partitionEntry = GetCitusTableCacheEntry(relationId);
1384 
1385 	/* non-distributed tables do not have partition column */
1386 	if (IsCitusTableTypeCacheEntry(partitionEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
1387 	{
1388 		return NULL;
1389 	}
1390 
1391 	return copyObject(partitionEntry->partitionColumn);
1392 }
1393 
1394 
1395 /*
1396  * DistPartitionKeyOrError is the same as DistPartitionKey but errors out instead
1397  * of returning NULL if this is called with a relationId of a reference table.
1398  */
1399 Var *
DistPartitionKeyOrError(Oid relationId)1400 DistPartitionKeyOrError(Oid relationId)
1401 {
1402 	Var *partitionKey = DistPartitionKey(relationId);
1403 
1404 	if (partitionKey == NULL)
1405 	{
1406 		ereport(ERROR, (errmsg(
1407 							"no distribution column found for relation %d, because it is a reference table",
1408 							relationId)));
1409 	}
1410 
1411 	return partitionKey;
1412 }
1413 
1414 
1415 /* Returns the partition method for the given relation. */
1416 char
PartitionMethod(Oid relationId)1417 PartitionMethod(Oid relationId)
1418 {
1419 	/* errors out if not a distributed table */
1420 	CitusTableCacheEntry *partitionEntry = GetCitusTableCacheEntry(relationId);
1421 
1422 	char partitionMethod = partitionEntry->partitionMethod;
1423 
1424 	return partitionMethod;
1425 }
1426 
1427 
1428 /* Returns the replication model for the given relation. */
1429 char
TableReplicationModel(Oid relationId)1430 TableReplicationModel(Oid relationId)
1431 {
1432 	/* errors out if not a distributed table */
1433 	CitusTableCacheEntry *partitionEntry = GetCitusTableCacheEntry(relationId);
1434 
1435 	char replicationModel = partitionEntry->replicationModel;
1436 
1437 	return replicationModel;
1438 }
1439