1 /*-------------------------------------------------------------------------
2  *
3  * multi_logical_planner.c
4  *
5  * Routines for constructing a logical plan tree from the given Query tree
6  * structure. This new logical plan is based on multi-relational algebra rules.
7  *
8  * Copyright (c) Citus Data, Inc.
9  *
10  * $Id$
11  *
12  *-------------------------------------------------------------------------
13  */
14 
15 #include "postgres.h"
16 
17 #include "distributed/pg_version_constants.h"
18 
19 #include "access/heapam.h"
20 #include "access/nbtree.h"
21 #include "catalog/pg_am.h"
22 #include "catalog/pg_class.h"
23 #include "commands/defrem.h"
24 #include "distributed/citus_clauses.h"
25 #include "distributed/colocation_utils.h"
26 #include "distributed/metadata_cache.h"
27 #include "distributed/insert_select_planner.h"
28 #include "distributed/listutils.h"
29 #include "distributed/multi_logical_optimizer.h"
30 #include "distributed/multi_logical_planner.h"
31 #include "distributed/multi_physical_planner.h"
32 #include "distributed/reference_table_utils.h"
33 #include "distributed/relation_restriction_equivalence.h"
34 #include "distributed/query_pushdown_planning.h"
35 #include "distributed/query_utils.h"
36 #include "distributed/multi_router_planner.h"
37 #include "distributed/worker_protocol.h"
38 #include "distributed/version_compat.h"
39 #include "nodes/makefuncs.h"
40 #include "nodes/nodeFuncs.h"
41 #include "nodes/pathnodes.h"
42 #include "optimizer/optimizer.h"
43 #include "optimizer/clauses.h"
44 #include "optimizer/prep.h"
45 #include "optimizer/tlist.h"
46 #include "parser/parsetree.h"
47 #include "utils/builtins.h"
48 #include "utils/datum.h"
49 #include "utils/lsyscache.h"
50 #include "utils/syscache.h"
51 #include "utils/rel.h"
52 #include "utils/relcache.h"
53 
54 
/*
 * Context struct handed to the qualifier-extracting expression tree walker
 * (see ExtractFromExpressionWalker below). It keeps two separate qualifier
 * lists so that the walker can tell the qualifier types apart.
 */
typedef struct QualifierWalkerContext
{
	List *baseQualifierList;      /* presumably plain WHERE / inner join quals - confirm */
	List *outerJoinQualifierList; /* presumably quals attached to outer joins - confirm */
} QualifierWalkerContext;


/*
 * Function pointer type for the functions that apply a single join rule. Each
 * such function receives the left and right input nodes, the partition
 * column list, the join type and the join clauses, and returns a MultiNode.
 */
typedef MultiNode *(*RuleApplyFunction) (MultiNode *leftNode, MultiNode *rightNode,
										 List *partitionColumnList, JoinType joinType,
										 List *joinClauses);

/*
 * Predicate type used by FindNodeMatchingCheckFunction(): returns true for a
 * node that matches, which makes the search stop and report success.
 */
typedef bool (*CheckNodeFunc)(Node *);

/* join rule apply functions; sized by JOIN_RULE_LAST and zero-initialized */
static RuleApplyFunction RuleApplyFunctionArray[JOIN_RULE_LAST] = { 0 }; /* join rules */

/* Local functions forward declarations */
static FieldSelect * CompositeFieldRecursive(Expr *expression, Query *query);
static Oid NodeTryGetRteRelid(Node *node);
static bool FullCompositeFieldList(List *compositeFieldList);
static bool HasUnsupportedJoinWalker(Node *node, void *context);
static bool ErrorHintRequired(const char *errorHint, Query *queryTree);
static bool HasTablesample(Query *queryTree);
static bool HasComplexRangeTableType(Query *queryTree);
static bool IsReadIntermediateResultFunction(Node *node);
static bool IsReadIntermediateResultArrayFunction(Node *node);
static bool IsCitusExtraDataContainerFunc(Node *node);
static bool IsFunctionWithOid(Node *node, Oid funcOid);
static bool IsGroupingFunc(Node *node);
static bool ExtractFromExpressionWalker(Node *node,
										QualifierWalkerContext *walkerContext);
static List * MultiTableNodeList(List *tableEntryList, List *rangeTableList);
static List * AddMultiCollectNodes(List *tableNodeList);
static MultiNode * MultiJoinTree(List *joinOrderList, List *collectTableList,
								 List *joinClauseList);
static MultiCollect * CollectNodeForTable(List *collectTableList, uint32 rangeTableId);
static MultiSelect * MultiSelectNode(List *whereClauseList);
static bool IsSelectClause(Node *clause);

/*
 * Local functions forward declarations for applying joins. Apart from
 * ApplyJoinRule, JoinRuleApplyFunction and ApplySinglePartitionJoin, the
 * functions below share the RuleApplyFunction signature.
 */
static MultiNode * ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode,
								 JoinRuleType ruleType, List *partitionColumnList,
								 JoinType joinType, List *joinClauseList);
static RuleApplyFunction JoinRuleApplyFunction(JoinRuleType ruleType);
static MultiNode * ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
									  List *partitionColumnList, JoinType joinType,
									  List *joinClauses);
static MultiNode * ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode,
								  List *partitionColumnList, JoinType joinType,
								  List *joinClauses);
static MultiNode * ApplySingleRangePartitionJoin(MultiNode *leftNode,
												 MultiNode *rightNode,
												 List *partitionColumnList,
												 JoinType joinType,
												 List *applicableJoinClauses);
static MultiNode * ApplySingleHashPartitionJoin(MultiNode *leftNode,
												MultiNode *rightNode,
												List *partitionColumnList,
												JoinType joinType,
												List *applicableJoinClauses);
static MultiJoin * ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
											List *partitionColumnList, JoinType joinType,
											List *joinClauses);
static MultiNode * ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
										  List *partitionColumnList, JoinType joinType,
										  List *joinClauses);
static MultiNode * ApplyCartesianProductReferenceJoin(MultiNode *leftNode,
													  MultiNode *rightNode,
													  List *partitionColumnList,
													  JoinType joinType,
													  List *joinClauses);
static MultiNode * ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode,
										 List *partitionColumnList, JoinType joinType,
										 List *joinClauses);
130 
131 
132 /*
133  * MultiLogicalPlanCreate takes in both the original query and its corresponding modified
134  * query tree yield by the standard planner. It uses helper functions to create logical
135  * plan and adds a root node to top of it. The original query is only used for subquery
136  * pushdown planning.
137  *
138  * We also pass queryTree and plannerRestrictionContext to the planner. They
139  * are primarily used to decide whether the subquery is safe to pushdown.
140  * If not, it helps to produce meaningful error messages for subquery
141  * pushdown planning.
142  */
143 MultiTreeRoot *
MultiLogicalPlanCreate(Query * originalQuery,Query * queryTree,PlannerRestrictionContext * plannerRestrictionContext)144 MultiLogicalPlanCreate(Query *originalQuery, Query *queryTree,
145 					   PlannerRestrictionContext *plannerRestrictionContext)
146 {
147 	MultiNode *multiQueryNode = NULL;
148 
149 
150 	if (ShouldUseSubqueryPushDown(originalQuery, queryTree, plannerRestrictionContext))
151 	{
152 		multiQueryNode = SubqueryMultiNodeTree(originalQuery, queryTree,
153 											   plannerRestrictionContext);
154 	}
155 	else
156 	{
157 		multiQueryNode = MultiNodeTree(queryTree);
158 	}
159 
160 	/* add a root node to serve as the permanent handle to the tree */
161 	MultiTreeRoot *rootNode = CitusMakeNode(MultiTreeRoot);
162 	SetChild((MultiUnaryNode *) rootNode, multiQueryNode);
163 
164 	return rootNode;
165 }
166 
167 
168 /*
169  * FindNodeMatchingCheckFunction finds a node for which the checker function returns true.
170  *
171  * To call this function directly with an RTE, use:
172  * range_table_walker(rte, FindNodeMatchingCheckFunction, checker, QTW_EXAMINE_RTES_BEFORE)
173  */
174 bool
FindNodeMatchingCheckFunction(Node * node,CheckNodeFunc checker)175 FindNodeMatchingCheckFunction(Node *node, CheckNodeFunc checker)
176 {
177 	if (node == NULL)
178 	{
179 		return false;
180 	}
181 
182 	if (checker(node))
183 	{
184 		return true;
185 	}
186 
187 	if (IsA(node, RangeTblEntry))
188 	{
189 		/* query_tree_walker descends into RTEs */
190 		return false;
191 	}
192 	else if (IsA(node, Query))
193 	{
194 		return query_tree_walker((Query *) node, FindNodeMatchingCheckFunction, checker,
195 								 QTW_EXAMINE_RTES_BEFORE);
196 	}
197 
198 	return expression_tree_walker(node, FindNodeMatchingCheckFunction, checker);
199 }
200 
201 
202 /*
203  * TargetListOnPartitionColumn checks if at least one target list entry is on
204  * partition column.
205  */
206 bool
TargetListOnPartitionColumn(Query * query,List * targetEntryList)207 TargetListOnPartitionColumn(Query *query, List *targetEntryList)
208 {
209 	bool targetListOnPartitionColumn = false;
210 	List *compositeFieldList = NIL;
211 
212 	ListCell *targetEntryCell = NULL;
213 	foreach(targetEntryCell, targetEntryList)
214 	{
215 		TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
216 		Expr *targetExpression = targetEntry->expr;
217 
218 		bool skipOuterVars = true;
219 		bool isPartitionColumn = IsPartitionColumn(targetExpression, query,
220 												   skipOuterVars);
221 		Var *column = NULL;
222 		RangeTblEntry *rte = NULL;
223 
224 		FindReferencedTableColumn(targetExpression, NIL, query, &column, &rte,
225 								  skipOuterVars);
226 		Oid relationId = rte ? rte->relid : InvalidOid;
227 
228 		/*
229 		 * If the expression belongs to a non-distributed table continue searching for
230 		 * other partition keys.
231 		 */
232 		if (IsCitusTableType(relationId, CITUS_TABLE_WITH_NO_DIST_KEY))
233 		{
234 			continue;
235 		}
236 
237 		if (isPartitionColumn)
238 		{
239 			FieldSelect *compositeField = CompositeFieldRecursive(targetExpression,
240 																  query);
241 			if (compositeField)
242 			{
243 				compositeFieldList = lappend(compositeFieldList, compositeField);
244 			}
245 			else
246 			{
247 				targetListOnPartitionColumn = true;
248 				break;
249 			}
250 		}
251 	}
252 
253 	/* check composite fields */
254 	if (!targetListOnPartitionColumn)
255 	{
256 		bool fullCompositeFieldList = FullCompositeFieldList(compositeFieldList);
257 		if (fullCompositeFieldList)
258 		{
259 			targetListOnPartitionColumn = true;
260 		}
261 	}
262 
263 	/*
264 	 * We could still behave as if the target list is on partition column if
265 	 * range table entries don't contain a distributed table.
266 	 */
267 	if (!targetListOnPartitionColumn)
268 	{
269 		if (!FindNodeMatchingCheckFunctionInRangeTableList(query->rtable,
270 														   IsDistributedTableRTE))
271 		{
272 			targetListOnPartitionColumn = true;
273 		}
274 	}
275 
276 	return targetListOnPartitionColumn;
277 }
278 
279 
280 /*
281  * FindNodeMatchingCheckFunctionInRangeTableList finds a node for which the checker
282  * function returns true.
283  *
284  * FindNodeMatchingCheckFunctionInRangeTableList relies on
285  * FindNodeMatchingCheckFunction() but only considers the range table entries.
286  */
287 bool
FindNodeMatchingCheckFunctionInRangeTableList(List * rtable,CheckNodeFunc checker)288 FindNodeMatchingCheckFunctionInRangeTableList(List *rtable, CheckNodeFunc checker)
289 {
290 	return range_table_walker(rtable, FindNodeMatchingCheckFunction, checker,
291 							  QTW_EXAMINE_RTES_BEFORE);
292 }
293 
294 
295 /*
296  * NodeTryGetRteRelid returns the relid of the given RTE_RELATION RangeTableEntry.
297  * Returns InvalidOid if any of these assumptions fail for given node.
298  */
299 static Oid
NodeTryGetRteRelid(Node * node)300 NodeTryGetRteRelid(Node *node)
301 {
302 	if (node == NULL)
303 	{
304 		return InvalidOid;
305 	}
306 
307 	if (!IsA(node, RangeTblEntry))
308 	{
309 		return InvalidOid;
310 	}
311 
312 	RangeTblEntry *rangeTableEntry = (RangeTblEntry *) node;
313 
314 	if (rangeTableEntry->rtekind != RTE_RELATION)
315 	{
316 		return InvalidOid;
317 	}
318 
319 	return rangeTableEntry->relid;
320 }
321 
322 
323 /*
324  * IsCitusTableRTE gets a node and returns true if the node is a
325  * range table relation entry that points to a distributed relation.
326  */
327 bool
IsCitusTableRTE(Node * node)328 IsCitusTableRTE(Node *node)
329 {
330 	Oid relationId = NodeTryGetRteRelid(node);
331 	return relationId != InvalidOid && IsCitusTable(relationId);
332 }
333 
334 
335 /*
336  * IsDistributedOrReferenceTableRTE returns true if the given node
337  * is eeither a distributed(hash/range/append) or reference table.
338  */
339 bool
IsDistributedOrReferenceTableRTE(Node * node)340 IsDistributedOrReferenceTableRTE(Node *node)
341 {
342 	Oid relationId = NodeTryGetRteRelid(node);
343 	if (!OidIsValid(relationId))
344 	{
345 		return false;
346 	}
347 	return IsCitusTableType(relationId, DISTRIBUTED_TABLE) ||
348 		   IsCitusTableType(relationId, REFERENCE_TABLE);
349 }
350 
351 
352 /*
353  * IsDistributedTableRTE gets a node and returns true if the node
354  * is a range table relation entry that points to a distributed relation,
355  * returning false still if the relation is a reference table.
356  */
357 bool
IsDistributedTableRTE(Node * node)358 IsDistributedTableRTE(Node *node)
359 {
360 	Oid relationId = NodeTryGetRteRelid(node);
361 	return relationId != InvalidOid && IsCitusTableType(relationId, DISTRIBUTED_TABLE);
362 }
363 
364 
365 /*
366  * IsReferenceTableRTE gets a node and returns true if the node
367  * is a range table relation entry that points to a reference table.
368  */
369 bool
IsReferenceTableRTE(Node * node)370 IsReferenceTableRTE(Node *node)
371 {
372 	Oid relationId = NodeTryGetRteRelid(node);
373 	return relationId != InvalidOid && IsCitusTableType(relationId, REFERENCE_TABLE);
374 }
375 
376 
377 /*
378  * FullCompositeFieldList gets a composite field list, and checks if all fields
379  * of composite type are used in the list.
380  */
381 static bool
FullCompositeFieldList(List * compositeFieldList)382 FullCompositeFieldList(List *compositeFieldList)
383 {
384 	bool fullCompositeFieldList = true;
385 	bool *compositeFieldArray = NULL;
386 	uint32 compositeFieldCount = 0;
387 
388 	ListCell *fieldSelectCell = NULL;
389 	foreach(fieldSelectCell, compositeFieldList)
390 	{
391 		FieldSelect *fieldSelect = (FieldSelect *) lfirst(fieldSelectCell);
392 
393 		Expr *fieldExpression = fieldSelect->arg;
394 		if (!IsA(fieldExpression, Var))
395 		{
396 			continue;
397 		}
398 
399 		if (compositeFieldArray == NULL)
400 		{
401 			Var *compositeColumn = (Var *) fieldExpression;
402 			Oid compositeTypeId = compositeColumn->vartype;
403 			Oid compositeRelationId = get_typ_typrelid(compositeTypeId);
404 
405 			/* get composite type attribute count */
406 			Relation relation = relation_open(compositeRelationId, AccessShareLock);
407 			compositeFieldCount = relation->rd_att->natts;
408 			compositeFieldArray = palloc0(compositeFieldCount * sizeof(bool));
409 			relation_close(relation, AccessShareLock);
410 
411 			for (uint32 compositeFieldIndex = 0;
412 				 compositeFieldIndex < compositeFieldCount;
413 				 compositeFieldIndex++)
414 			{
415 				compositeFieldArray[compositeFieldIndex] = false;
416 			}
417 		}
418 
419 		uint32 compositeFieldIndex = fieldSelect->fieldnum - 1;
420 		compositeFieldArray[compositeFieldIndex] = true;
421 	}
422 
423 	for (uint32 fieldIndex = 0; fieldIndex < compositeFieldCount; fieldIndex++)
424 	{
425 		if (!compositeFieldArray[fieldIndex])
426 		{
427 			fullCompositeFieldList = false;
428 		}
429 	}
430 
431 	if (compositeFieldCount == 0)
432 	{
433 		fullCompositeFieldList = false;
434 	}
435 
436 	return fullCompositeFieldList;
437 }
438 
439 
440 /*
441  * CompositeFieldRecursive recursively finds composite field in the query tree
442  * referred by given expression. If expression does not refer to a composite
443  * field, then it returns NULL.
444  *
445  * If expression is a field select we directly return composite field. If it is
446  * a column is referenced from a subquery, then we recursively check that subquery
447  * until we reach the source of that column, and find composite field. If this
448  * column is referenced from join range table entry, then we resolve which join
449  * column it refers and recursively use this column with the same query.
450  */
451 static FieldSelect *
CompositeFieldRecursive(Expr * expression,Query * query)452 CompositeFieldRecursive(Expr *expression, Query *query)
453 {
454 	FieldSelect *compositeField = NULL;
455 	List *rangetableList = query->rtable;
456 	Var *candidateColumn = NULL;
457 
458 	if (IsA(expression, FieldSelect))
459 	{
460 		compositeField = (FieldSelect *) expression;
461 		return compositeField;
462 	}
463 
464 	if (IsA(expression, Var))
465 	{
466 		candidateColumn = (Var *) expression;
467 	}
468 	else
469 	{
470 		return NULL;
471 	}
472 
473 	Index rangeTableEntryIndex = candidateColumn->varno - 1;
474 	RangeTblEntry *rangeTableEntry = list_nth(rangetableList, rangeTableEntryIndex);
475 
476 	if (rangeTableEntry->rtekind == RTE_SUBQUERY)
477 	{
478 		Query *subquery = rangeTableEntry->subquery;
479 		List *targetEntryList = subquery->targetList;
480 		AttrNumber targetEntryIndex = candidateColumn->varattno - 1;
481 		TargetEntry *subqueryTargetEntry = list_nth(targetEntryList, targetEntryIndex);
482 
483 		Expr *subqueryExpression = subqueryTargetEntry->expr;
484 		compositeField = CompositeFieldRecursive(subqueryExpression, subquery);
485 	}
486 	else if (rangeTableEntry->rtekind == RTE_JOIN)
487 	{
488 		List *joinColumnList = rangeTableEntry->joinaliasvars;
489 		AttrNumber joinColumnIndex = candidateColumn->varattno - 1;
490 		Expr *joinColumn = list_nth(joinColumnList, joinColumnIndex);
491 
492 		compositeField = CompositeFieldRecursive(joinColumn, query);
493 	}
494 
495 	return compositeField;
496 }
497 
498 
499 /*
500  * SubqueryEntryList finds the subquery nodes in the range table entry list, and
501  * builds a list of subquery range table entries from these subquery nodes. Range
502  * table entry list also includes subqueries which are pulled up. We don't want
503  * to add pulled up subqueries to list, so we walk over join tree indexes and
504  * check range table entries referenced in the join tree.
505  */
506 List *
SubqueryEntryList(Query * queryTree)507 SubqueryEntryList(Query *queryTree)
508 {
509 	List *rangeTableList = queryTree->rtable;
510 	List *subqueryEntryList = NIL;
511 	List *joinTreeTableIndexList = NIL;
512 	ListCell *joinTreeTableIndexCell = NULL;
513 
514 	/*
515 	 * Extract all range table indexes from the join tree. Note that here we
516 	 * only walk over range table entries at this level and do not recurse into
517 	 * subqueries.
518 	 */
519 	ExtractRangeTableIndexWalker((Node *) queryTree->jointree, &joinTreeTableIndexList);
520 	foreach(joinTreeTableIndexCell, joinTreeTableIndexList)
521 	{
522 		/*
523 		 * Join tree's range table index starts from 1 in the query tree. But,
524 		 * list indexes start from 0.
525 		 */
526 		int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell);
527 		int rangeTableListIndex = joinTreeTableIndex - 1;
528 		RangeTblEntry *rangeTableEntry =
529 			(RangeTblEntry *) list_nth(rangeTableList, rangeTableListIndex);
530 
531 		if (rangeTableEntry->rtekind == RTE_SUBQUERY)
532 		{
533 			subqueryEntryList = lappend(subqueryEntryList, rangeTableEntry);
534 		}
535 	}
536 
537 	return subqueryEntryList;
538 }
539 
540 
541 /*
542  * MultiNodeTree takes in a parsed query tree and uses that tree to construct a
543  * logical plan. This plan is based on multi-relational algebra. This function
544  * creates the logical plan in several steps.
545  *
546  * First, the function checks if there is a subquery. If there is a subquery
547  * it recursively creates nested multi trees. If this query has a subquery, the
548  * function does not create any join trees and jumps to last step.
549  *
550  * If there is no subquery, the function calculates the join order using tables
551  * in the query and join clauses between the tables. Second, the function
552  * starts building the logical plan from the bottom-up, and begins with the table
553  * and collect nodes. Third, the function builds the join tree using the join
554  * order information and table nodes.
555  *
556  * In the last step, the function adds the select, project, aggregate, sort,
557  * group, and limit nodes if they appear in the original query tree.
558  */
559 MultiNode *
MultiNodeTree(Query * queryTree)560 MultiNodeTree(Query *queryTree)
561 {
562 	List *rangeTableList = queryTree->rtable;
563 	List *targetEntryList = queryTree->targetList;
564 	List *joinClauseList = NIL;
565 	List *joinOrderList = NIL;
566 	List *tableEntryList = NIL;
567 	List *tableNodeList = NIL;
568 	List *collectTableList = NIL;
569 	MultiNode *joinTreeNode = NULL;
570 	MultiNode *currentTopNode = NULL;
571 
572 	/* verify we can perform distributed planning on this query */
573 	DeferredErrorMessage *unsupportedQueryError = DeferErrorIfQueryNotSupported(
574 		queryTree);
575 	if (unsupportedQueryError != NULL)
576 	{
577 		RaiseDeferredError(unsupportedQueryError, ERROR);
578 	}
579 
580 	/* extract where clause qualifiers and verify we can plan for them */
581 	List *whereClauseList = WhereClauseList(queryTree->jointree);
582 	unsupportedQueryError = DeferErrorIfUnsupportedClause(whereClauseList);
583 	if (unsupportedQueryError)
584 	{
585 		RaiseDeferredErrorInternal(unsupportedQueryError, ERROR);
586 	}
587 
588 	/*
589 	 * If we have a subquery, build a multi table node for the subquery and
590 	 * add a collect node on top of the multi table node.
591 	 */
592 	List *subqueryEntryList = SubqueryEntryList(queryTree);
593 	if (subqueryEntryList != NIL)
594 	{
595 		MultiCollect *subqueryCollectNode = CitusMakeNode(MultiCollect);
596 		ListCell *columnCell = NULL;
597 
598 		/* we only support single subquery in the entry list */
599 		Assert(list_length(subqueryEntryList) == 1);
600 
601 		RangeTblEntry *subqueryRangeTableEntry = (RangeTblEntry *) linitial(
602 			subqueryEntryList);
603 		Query *subqueryTree = subqueryRangeTableEntry->subquery;
604 
605 		/* ensure if subquery satisfies preconditions */
606 		Assert(DeferErrorIfUnsupportedSubqueryRepartition(subqueryTree) == NULL);
607 
608 		MultiTable *subqueryNode = CitusMakeNode(MultiTable);
609 		subqueryNode->relationId = SUBQUERY_RELATION_ID;
610 		subqueryNode->rangeTableId = SUBQUERY_RANGE_TABLE_ID;
611 		subqueryNode->partitionColumn = NULL;
612 		subqueryNode->alias = NULL;
613 		subqueryNode->referenceNames = NULL;
614 
615 		/*
616 		 * We disregard pulled subqueries. This changes order of range table list.
617 		 * We do not allow subquery joins, so we will have only one range table
618 		 * entry in range table list after dropping pulled subquery. For this
619 		 * reason, here we are updating columns in the most outer query for where
620 		 * clause list and target list accordingly.
621 		 */
622 		Assert(list_length(subqueryEntryList) == 1);
623 
624 		List *whereClauseColumnList = pull_var_clause_default((Node *) whereClauseList);
625 		List *targetListColumnList = pull_var_clause_default((Node *) targetEntryList);
626 
627 		List *columnList = list_concat(whereClauseColumnList, targetListColumnList);
628 		foreach(columnCell, columnList)
629 		{
630 			Var *column = (Var *) lfirst(columnCell);
631 			column->varno = 1;
632 		}
633 
634 		/* recursively create child nested multitree */
635 		MultiNode *subqueryExtendedNode = MultiNodeTree(subqueryTree);
636 
637 		SetChild((MultiUnaryNode *) subqueryCollectNode, (MultiNode *) subqueryNode);
638 		SetChild((MultiUnaryNode *) subqueryNode, subqueryExtendedNode);
639 
640 		currentTopNode = (MultiNode *) subqueryCollectNode;
641 	}
642 	else
643 	{
644 		/*
645 		 * We calculate the join order using the list of tables in the query and
646 		 * the join clauses between them. Note that this function owns the table
647 		 * entry list's memory, and JoinOrderList() shallow copies the list's
648 		 * elements.
649 		 */
650 		joinClauseList = JoinClauseList(whereClauseList);
651 		tableEntryList = UsedTableEntryList(queryTree);
652 
653 		/* build the list of multi table nodes */
654 		tableNodeList = MultiTableNodeList(tableEntryList, rangeTableList);
655 
656 		/* add collect nodes on top of the multi table nodes */
657 		collectTableList = AddMultiCollectNodes(tableNodeList);
658 
659 		/* find best join order for commutative inner joins */
660 		joinOrderList = JoinOrderList(tableEntryList, joinClauseList);
661 
662 		/* build join tree using the join order and collected tables */
663 		joinTreeNode = MultiJoinTree(joinOrderList, collectTableList, joinClauseList);
664 
665 		currentTopNode = joinTreeNode;
666 	}
667 
668 	Assert(currentTopNode != NULL);
669 
670 	/* build select node if the query has selection criteria */
671 	MultiSelect *selectNode = MultiSelectNode(whereClauseList);
672 	if (selectNode != NULL)
673 	{
674 		SetChild((MultiUnaryNode *) selectNode, currentTopNode);
675 		currentTopNode = (MultiNode *) selectNode;
676 	}
677 
678 	/* build project node for the columns to project */
679 	MultiProject *projectNode = MultiProjectNode(targetEntryList);
680 	SetChild((MultiUnaryNode *) projectNode, currentTopNode);
681 	currentTopNode = (MultiNode *) projectNode;
682 
683 	/*
684 	 * We build the extended operator node to capture aggregate functions, group
685 	 * clauses, sort clauses, limit/offset clauses, and expressions. We need to
686 	 * distinguish between aggregates and expressions; and we address this later
687 	 * in the logical optimizer.
688 	 */
689 	MultiExtendedOp *extendedOpNode = MultiExtendedOpNode(queryTree, queryTree);
690 	SetChild((MultiUnaryNode *) extendedOpNode, currentTopNode);
691 	currentTopNode = (MultiNode *) extendedOpNode;
692 
693 	return currentTopNode;
694 }
695 
696 
697 /*
698  * ContainsReadIntermediateResultFunction determines whether an expresion tree contains
699  * a call to the read_intermediate_result function.
700  */
701 bool
ContainsReadIntermediateResultFunction(Node * node)702 ContainsReadIntermediateResultFunction(Node *node)
703 {
704 	return FindNodeMatchingCheckFunction(node, IsReadIntermediateResultFunction);
705 }
706 
707 
708 /*
709  * ContainsReadIntermediateResultArrayFunction determines whether an expresion
710  * tree contains a call to the read_intermediate_results(result_ids, format)
711  * function.
712  */
713 bool
ContainsReadIntermediateResultArrayFunction(Node * node)714 ContainsReadIntermediateResultArrayFunction(Node *node)
715 {
716 	return FindNodeMatchingCheckFunction(node, IsReadIntermediateResultArrayFunction);
717 }
718 
719 
720 /*
721  * IsReadIntermediateResultFunction determines whether a given node is a function call
722  * to the read_intermediate_result function.
723  */
724 static bool
IsReadIntermediateResultFunction(Node * node)725 IsReadIntermediateResultFunction(Node *node)
726 {
727 	return IsFunctionWithOid(node, CitusReadIntermediateResultFuncId());
728 }
729 
730 
731 /*
732  * IsReadIntermediateResultArrayFunction determines whether a given node is a
733  * function call to the read_intermediate_results(result_ids, format) function.
734  */
735 static bool
IsReadIntermediateResultArrayFunction(Node * node)736 IsReadIntermediateResultArrayFunction(Node *node)
737 {
738 	return IsFunctionWithOid(node, CitusReadIntermediateResultArrayFuncId());
739 }
740 
741 
742 /*
743  * IsCitusExtraDataContainerRelation determines whether a range table entry contains a
744  * call to the citus_extradata_container function.
745  */
746 bool
IsCitusExtraDataContainerRelation(RangeTblEntry * rte)747 IsCitusExtraDataContainerRelation(RangeTblEntry *rte)
748 {
749 	if (rte->rtekind != RTE_FUNCTION || list_length(rte->functions) != 1)
750 	{
751 		/* avoid more expensive checks below for non-functions */
752 		return false;
753 	}
754 
755 	if (!CitusHasBeenLoaded() || !CheckCitusVersion(DEBUG5))
756 	{
757 		return false;
758 	}
759 
760 	return FindNodeMatchingCheckFunction((Node *) rte->functions,
761 										 IsCitusExtraDataContainerFunc);
762 }
763 
764 
765 /*
766  * IsCitusExtraDataContainerFunc determines whether a given node is a function call
767  * to the citus_extradata_container function.
768  */
769 static bool
IsCitusExtraDataContainerFunc(Node * node)770 IsCitusExtraDataContainerFunc(Node *node)
771 {
772 	return IsFunctionWithOid(node, CitusExtraDataContainerFuncId());
773 }
774 
775 
776 /*
777  * IsFunctionWithOid determines whether a given node is a function call
778  * to the read_intermediate_result function.
779  */
780 static bool
IsFunctionWithOid(Node * node,Oid funcOid)781 IsFunctionWithOid(Node *node, Oid funcOid)
782 {
783 	if (IsA(node, FuncExpr))
784 	{
785 		FuncExpr *funcExpr = (FuncExpr *) node;
786 
787 		if (funcExpr->funcid == funcOid)
788 		{
789 			return true;
790 		}
791 	}
792 
793 	return false;
794 }
795 
796 
797 /*
798  * IsGroupingFunc returns whether node is a GroupingFunc.
799  */
800 static bool
IsGroupingFunc(Node * node)801 IsGroupingFunc(Node *node)
802 {
803 	return IsA(node, GroupingFunc);
804 }
805 
806 
807 /*
808  * FindIntermediateResultIdIfExists extracts the id of the intermediate result
809  * if the given RTE contains a read_intermediate_results function, NULL otherwise
810  */
811 char *
FindIntermediateResultIdIfExists(RangeTblEntry * rte)812 FindIntermediateResultIdIfExists(RangeTblEntry *rte)
813 {
814 	char *resultId = NULL;
815 
816 	Assert(rte->rtekind == RTE_FUNCTION);
817 
818 	List *functionList = rte->functions;
819 	RangeTblFunction *rangeTblfunction = (RangeTblFunction *) linitial(functionList);
820 	FuncExpr *funcExpr = (FuncExpr *) rangeTblfunction->funcexpr;
821 
822 	if (IsReadIntermediateResultFunction((Node *) funcExpr))
823 	{
824 		Const *resultIdConst = linitial(funcExpr->args);
825 
826 		if (!resultIdConst->constisnull)
827 		{
828 			resultId = TextDatumGetCString(resultIdConst->constvalue);
829 		}
830 	}
831 
832 	return resultId;
833 }
834 
835 
836 /*
837  * ErrorIfQueryNotSupported checks that we can perform distributed planning for
838  * the given query. The checks in this function will be removed as we support
839  * more functionality in our distributed planning.
840  */
841 DeferredErrorMessage *
DeferErrorIfQueryNotSupported(Query * queryTree)842 DeferErrorIfQueryNotSupported(Query *queryTree)
843 {
844 	char *errorMessage = NULL;
845 	bool preconditionsSatisfied = true;
846 	const char *errorHint = NULL;
847 	const char *joinHint = "Consider joining tables on partition column and have "
848 						   "equal filter on joining columns.";
849 	const char *filterHint = "Consider using an equality filter on the distributed "
850 							 "table's partition column.";
851 
852 	if (queryTree->setOperations)
853 	{
854 		preconditionsSatisfied = false;
855 		errorMessage = "could not run distributed query with UNION, INTERSECT, or "
856 					   "EXCEPT";
857 		errorHint = filterHint;
858 	}
859 
860 	if (queryTree->hasRecursive)
861 	{
862 		preconditionsSatisfied = false;
863 		errorMessage = "could not run distributed query with RECURSIVE";
864 		errorHint = filterHint;
865 	}
866 
867 	if (queryTree->cteList)
868 	{
869 		preconditionsSatisfied = false;
870 		errorMessage = "could not run distributed query with common table expressions";
871 		errorHint = filterHint;
872 	}
873 
874 	if (queryTree->hasForUpdate)
875 	{
876 		preconditionsSatisfied = false;
877 		errorMessage = "could not run distributed query with FOR UPDATE/SHARE commands";
878 		errorHint = filterHint;
879 	}
880 
881 	if (queryTree->groupingSets)
882 	{
883 		preconditionsSatisfied = false;
884 		errorMessage = "could not run distributed query with GROUPING SETS, CUBE, "
885 					   "or ROLLUP";
886 		errorHint = filterHint;
887 	}
888 
889 	if (FindNodeMatchingCheckFunction((Node *) queryTree, IsGroupingFunc))
890 	{
891 		preconditionsSatisfied = false;
892 		errorMessage = "could not run distributed query with GROUPING";
893 		errorHint = filterHint;
894 	}
895 
896 	bool hasTablesample = HasTablesample(queryTree);
897 	if (hasTablesample)
898 	{
899 		preconditionsSatisfied = false;
900 		errorMessage = "could not run distributed query which use TABLESAMPLE";
901 		errorHint = filterHint;
902 	}
903 
904 	bool hasUnsupportedJoin = HasUnsupportedJoinWalker((Node *) queryTree->jointree,
905 													   NULL);
906 	if (hasUnsupportedJoin)
907 	{
908 		preconditionsSatisfied = false;
909 		errorMessage = "could not run distributed query with join types other than "
910 					   "INNER or OUTER JOINS";
911 		errorHint = joinHint;
912 	}
913 
914 	bool hasComplexRangeTableType = HasComplexRangeTableType(queryTree);
915 	if (hasComplexRangeTableType)
916 	{
917 		preconditionsSatisfied = false;
918 		errorMessage = "could not run distributed query with complex table expressions";
919 		errorHint = filterHint;
920 	}
921 
922 	if (FindNodeMatchingCheckFunction((Node *) queryTree->limitCount, IsNodeSubquery))
923 	{
924 		preconditionsSatisfied = false;
925 		errorMessage = "subquery in LIMIT is not supported in multi-shard queries";
926 	}
927 
928 	if (FindNodeMatchingCheckFunction((Node *) queryTree->limitOffset, IsNodeSubquery))
929 	{
930 		preconditionsSatisfied = false;
931 		errorMessage = "subquery in OFFSET is not supported in multi-shard queries";
932 	}
933 
934 	RTEListProperties *queryRteListProperties = GetRTEListPropertiesForQuery(queryTree);
935 	if (queryRteListProperties->hasCitusLocalTable ||
936 		queryRteListProperties->hasPostgresLocalTable)
937 	{
938 		preconditionsSatisfied = false;
939 		errorMessage = "direct joins between distributed and local tables are "
940 					   "not supported";
941 		errorHint = LOCAL_TABLE_SUBQUERY_CTE_HINT;
942 	}
943 
944 	/* finally check and error out if not satisfied */
945 	if (!preconditionsSatisfied)
946 	{
947 		bool showHint = ErrorHintRequired(errorHint, queryTree);
948 		return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
949 							 errorMessage, NULL,
950 							 showHint ? errorHint : NULL);
951 	}
952 
953 	return NULL;
954 }
955 
956 
957 /* HasTablesample returns tree if the query contains tablesample */
958 static bool
HasTablesample(Query * queryTree)959 HasTablesample(Query *queryTree)
960 {
961 	List *rangeTableList = queryTree->rtable;
962 	ListCell *rangeTableEntryCell = NULL;
963 	bool hasTablesample = false;
964 
965 	foreach(rangeTableEntryCell, rangeTableList)
966 	{
967 		RangeTblEntry *rangeTableEntry = lfirst(rangeTableEntryCell);
968 		if (rangeTableEntry->tablesample)
969 		{
970 			hasTablesample = true;
971 			break;
972 		}
973 	}
974 
975 	return hasTablesample;
976 }
977 
978 
979 /*
980  * HasUnsupportedJoinWalker returns tree if the query contains an unsupported
981  * join type. We currently support inner, left, right, full and anti joins.
982  * Semi joins are not supported. A full description of these join types is
983  * included in nodes/nodes.h.
984  */
985 static bool
HasUnsupportedJoinWalker(Node * node,void * context)986 HasUnsupportedJoinWalker(Node *node, void *context)
987 {
988 	bool hasUnsupportedJoin = false;
989 
990 	if (node == NULL)
991 	{
992 		return false;
993 	}
994 
995 	if (IsA(node, JoinExpr))
996 	{
997 		JoinExpr *joinExpr = (JoinExpr *) node;
998 		JoinType joinType = joinExpr->jointype;
999 		bool outerJoin = IS_OUTER_JOIN(joinType);
1000 		if (!outerJoin && joinType != JOIN_INNER && joinType != JOIN_SEMI)
1001 		{
1002 			hasUnsupportedJoin = true;
1003 		}
1004 	}
1005 
1006 	if (!hasUnsupportedJoin)
1007 	{
1008 		hasUnsupportedJoin = expression_tree_walker(node, HasUnsupportedJoinWalker,
1009 													NULL);
1010 	}
1011 
1012 	return hasUnsupportedJoin;
1013 }
1014 
1015 
1016 /*
1017  * ErrorHintRequired returns true if error hint shold be displayed with the
1018  * query error message. Error hint is valid only for queries involving reference
1019  * and hash partitioned tables. If more than one hash distributed table is
1020  * present we display the hint only if the tables are colocated. If the query
1021  * only has reference table(s), then it is handled by router planner.
1022  */
1023 static bool
ErrorHintRequired(const char * errorHint,Query * queryTree)1024 ErrorHintRequired(const char *errorHint, Query *queryTree)
1025 {
1026 	List *distributedRelationIdList = DistributedRelationIdList(queryTree);
1027 	ListCell *relationIdCell = NULL;
1028 	List *colocationIdList = NIL;
1029 
1030 	if (errorHint == NULL)
1031 	{
1032 		return false;
1033 	}
1034 
1035 	foreach(relationIdCell, distributedRelationIdList)
1036 	{
1037 		Oid relationId = lfirst_oid(relationIdCell);
1038 		if (IsCitusTableType(relationId, REFERENCE_TABLE))
1039 		{
1040 			continue;
1041 		}
1042 		else if (IsCitusTableType(relationId, HASH_DISTRIBUTED))
1043 		{
1044 			int colocationId = TableColocationId(relationId);
1045 			colocationIdList = list_append_unique_int(colocationIdList, colocationId);
1046 		}
1047 		else
1048 		{
1049 			return false;
1050 		}
1051 	}
1052 
1053 	/* do not display the hint if there are more than one colocation group */
1054 	if (list_length(colocationIdList) > 1)
1055 	{
1056 		return false;
1057 	}
1058 
1059 	return true;
1060 }
1061 
1062 
1063 /*
1064  * DeferErrorIfUnsupportedSubqueryRepartition checks that we can perform distributed planning for
1065  * the given subquery. If not, a deferred error is returned. The function recursively
1066  * does this check to all lower levels of the subquery.
1067  */
1068 DeferredErrorMessage *
DeferErrorIfUnsupportedSubqueryRepartition(Query * subqueryTree)1069 DeferErrorIfUnsupportedSubqueryRepartition(Query *subqueryTree)
1070 {
1071 	char *errorDetail = NULL;
1072 	bool preconditionsSatisfied = true;
1073 	List *joinTreeTableIndexList = NIL;
1074 
1075 	if (!subqueryTree->hasAggs)
1076 	{
1077 		preconditionsSatisfied = false;
1078 		errorDetail = "Subqueries without aggregates are not supported yet";
1079 	}
1080 
1081 	if (subqueryTree->groupClause == NIL)
1082 	{
1083 		preconditionsSatisfied = false;
1084 		errorDetail = "Subqueries without group by clause are not supported yet";
1085 	}
1086 
1087 	if (subqueryTree->sortClause != NULL)
1088 	{
1089 		preconditionsSatisfied = false;
1090 		errorDetail = "Subqueries with order by clause are not supported yet";
1091 	}
1092 
1093 	if (subqueryTree->limitCount != NULL)
1094 	{
1095 		preconditionsSatisfied = false;
1096 		errorDetail = "Subqueries with limit are not supported yet";
1097 	}
1098 
1099 	if (subqueryTree->limitOffset != NULL)
1100 	{
1101 		preconditionsSatisfied = false;
1102 		errorDetail = "Subqueries with offset are not supported yet";
1103 	}
1104 
1105 	if (subqueryTree->hasSubLinks)
1106 	{
1107 		preconditionsSatisfied = false;
1108 		errorDetail = "Subqueries other than from-clause subqueries are unsupported";
1109 	}
1110 
1111 	/* finally check and return error if conditions are not satisfied */
1112 	if (!preconditionsSatisfied)
1113 	{
1114 		return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
1115 							 "cannot perform distributed planning on this query",
1116 							 errorDetail, NULL);
1117 	}
1118 
1119 	/*
1120 	 * Extract all range table indexes from the join tree. Note that sub-queries
1121 	 * that get pulled up by PostgreSQL don't appear in this join tree.
1122 	 */
1123 	ExtractRangeTableIndexWalker((Node *) subqueryTree->jointree,
1124 								 &joinTreeTableIndexList);
1125 	Assert(list_length(joinTreeTableIndexList) == 1);
1126 
1127 	/* continue with the inner subquery */
1128 	int rangeTableIndex = linitial_int(joinTreeTableIndexList);
1129 	RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableIndex, subqueryTree->rtable);
1130 	if (rangeTableEntry->rtekind == RTE_RELATION)
1131 	{
1132 		return NULL;
1133 	}
1134 
1135 	Assert(rangeTableEntry->rtekind == RTE_SUBQUERY);
1136 	Query *innerSubquery = rangeTableEntry->subquery;
1137 
1138 	/* recursively continue to the inner subqueries */
1139 	return DeferErrorIfUnsupportedSubqueryRepartition(innerSubquery);
1140 }
1141 
1142 
1143 /*
1144  * HasComplexRangeTableType checks if the given query tree contains any complex
1145  * range table types. For this, the function walks over all range tables in the
1146  * join tree, and checks if they correspond to simple relations or subqueries.
1147  * If they don't, the function assumes the query has complex range tables.
1148  */
1149 static bool
HasComplexRangeTableType(Query * queryTree)1150 HasComplexRangeTableType(Query *queryTree)
1151 {
1152 	List *rangeTableList = queryTree->rtable;
1153 	List *joinTreeTableIndexList = NIL;
1154 	ListCell *joinTreeTableIndexCell = NULL;
1155 	bool hasComplexRangeTableType = false;
1156 
1157 	/*
1158 	 * Extract all range table indexes from the join tree. Note that sub-queries
1159 	 * that get pulled up by PostgreSQL don't appear in this join tree.
1160 	 */
1161 	ExtractRangeTableIndexWalker((Node *) queryTree->jointree, &joinTreeTableIndexList);
1162 	foreach(joinTreeTableIndexCell, joinTreeTableIndexList)
1163 	{
1164 		/*
1165 		 * Join tree's range table index starts from 1 in the query tree. But,
1166 		 * list indexes start from 0.
1167 		 */
1168 		int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell);
1169 		int rangeTableListIndex = joinTreeTableIndex - 1;
1170 
1171 		RangeTblEntry *rangeTableEntry =
1172 			(RangeTblEntry *) list_nth(rangeTableList, rangeTableListIndex);
1173 
1174 		/*
1175 		 * Check if the range table in the join tree is a simple relation or a
1176 		 * subquery or a function. Note that RTE_FUNCTIONs are handled via (sub)query
1177 		 * pushdown.
1178 		 */
1179 		if (rangeTableEntry->rtekind != RTE_RELATION &&
1180 			rangeTableEntry->rtekind != RTE_SUBQUERY &&
1181 			rangeTableEntry->rtekind != RTE_FUNCTION &&
1182 			rangeTableEntry->rtekind != RTE_VALUES)
1183 		{
1184 			hasComplexRangeTableType = true;
1185 		}
1186 
1187 		/*
1188 		 * Check if the subquery range table entry includes children inheritance.
1189 		 *
1190 		 * Note that PostgreSQL flattens out simple union all queries into an
1191 		 * append relation, sets "inh" field of RangeTblEntry to true and deletes
1192 		 * set operations. Here we check this for subqueries.
1193 		 */
1194 		if (rangeTableEntry->rtekind == RTE_SUBQUERY && rangeTableEntry->inh)
1195 		{
1196 			hasComplexRangeTableType = true;
1197 		}
1198 	}
1199 
1200 	return hasComplexRangeTableType;
1201 }
1202 
1203 
1204 /*
1205  * WhereClauseList walks over the FROM expression in the query tree, and builds
1206  * a list of all clauses from the expression tree. The function checks for both
1207  * implicitly and explicitly defined clauses, but only selects INNER join
1208  * explicit clauses, and skips any outer-join clauses. Explicit clauses are
1209  * expressed as "SELECT ... FROM R1 INNER JOIN R2 ON R1.A = R2.A". Implicit
1210  * joins differ in that they live in the WHERE clause, and are expressed as
1211  * "SELECT ... FROM ... WHERE R1.a = R2.a".
1212  */
1213 List *
WhereClauseList(FromExpr * fromExpr)1214 WhereClauseList(FromExpr *fromExpr)
1215 {
1216 	FromExpr *fromExprCopy = copyObject(fromExpr);
1217 	QualifierWalkerContext *walkerContext = palloc0(sizeof(QualifierWalkerContext));
1218 
1219 	ExtractFromExpressionWalker((Node *) fromExprCopy, walkerContext);
1220 	List *whereClauseList = walkerContext->baseQualifierList;
1221 
1222 	return whereClauseList;
1223 }
1224 
1225 
1226 /*
1227  * QualifierList walks over the FROM expression in the query tree, and builds
1228  * a list of all qualifiers from the expression tree. The function checks for
1229  * both implicitly and explicitly defined qualifiers. Note that this function
1230  * is very similar to WhereClauseList(), but QualifierList() also includes
1231  * outer-join clauses.
1232  */
1233 List *
QualifierList(FromExpr * fromExpr)1234 QualifierList(FromExpr *fromExpr)
1235 {
1236 	FromExpr *fromExprCopy = copyObject(fromExpr);
1237 	QualifierWalkerContext *walkerContext = palloc0(sizeof(QualifierWalkerContext));
1238 	List *qualifierList = NIL;
1239 
1240 	ExtractFromExpressionWalker((Node *) fromExprCopy, walkerContext);
1241 	qualifierList = list_concat(qualifierList, walkerContext->baseQualifierList);
1242 	qualifierList = list_concat(qualifierList, walkerContext->outerJoinQualifierList);
1243 
1244 	return qualifierList;
1245 }
1246 
1247 
1248 /*
1249  * DeferErrorIfUnsupportedClause walks over the given list of clauses, and
1250  * checks that we can recognize all the clauses. This function ensures that
1251  * we do not drop an unsupported clause type on the floor, and thus prevents
1252  * erroneous results.
1253  *
1254  * Returns a deferred error, caller is responsible for raising the error.
1255  */
1256 DeferredErrorMessage *
DeferErrorIfUnsupportedClause(List * clauseList)1257 DeferErrorIfUnsupportedClause(List *clauseList)
1258 {
1259 	ListCell *clauseCell = NULL;
1260 	foreach(clauseCell, clauseList)
1261 	{
1262 		Node *clause = (Node *) lfirst(clauseCell);
1263 
1264 		if (!(IsSelectClause(clause) || IsJoinClause(clause) || or_clause(clause)))
1265 		{
1266 			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
1267 								 "unsupported clause type", NULL, NULL);
1268 		}
1269 	}
1270 	return NULL;
1271 }
1272 
1273 
1274 /*
1275  * JoinClauseList finds the join clauses from the given where clause expression
1276  * list, and returns them. The function does not iterate into nested OR clauses
1277  * and relies on find_duplicate_ors() in the optimizer to pull up factorizable
1278  * OR clauses.
1279  */
1280 List *
JoinClauseList(List * whereClauseList)1281 JoinClauseList(List *whereClauseList)
1282 {
1283 	List *joinClauseList = NIL;
1284 	ListCell *whereClauseCell = NULL;
1285 
1286 	foreach(whereClauseCell, whereClauseList)
1287 	{
1288 		Node *whereClause = (Node *) lfirst(whereClauseCell);
1289 		if (IsJoinClause(whereClause))
1290 		{
1291 			joinClauseList = lappend(joinClauseList, whereClause);
1292 		}
1293 	}
1294 
1295 	return joinClauseList;
1296 }
1297 
1298 
1299 /*
1300  * ExtractFromExpressionWalker walks over a FROM expression, and finds all
1301  * implicit and explicit qualifiers in the expression. The function looks at
1302  * join and from expression nodes to find qualifiers, and returns these
1303  * qualifiers.
1304  *
1305  * Note that we don't want outer join clauses in regular outer join planning,
1306  * but we need outer join clauses in subquery pushdown prerequisite checks.
1307  * Therefore, outer join qualifiers are returned in a different list than other
1308  * qualifiers inside the given walker context. For this reason, we return two
1309  * qualifier lists.
1310  *
1311  * Note that we check if the qualifier node in join and from expression nodes
1312  * is a list node. If it is not a list node which is the case for subqueries,
1313  * then we run eval_const_expressions(), canonicalize_qual() and make_ands_implicit()
1314  * on the qualifier node and get a list of flattened implicitly AND'ed qualifier
1315  * list. Actually in the planer phase of PostgreSQL these functions also run on
1316  * subqueries but differently from the outermost query, they are run on a copy
1317  * of parse tree and changes do not get persisted as modifications to the original
1318  * query tree.
1319  *
1320  * Also this function adds SubLinks to the baseQualifierList when they appear on
1321  * the query's WHERE clause. The callers of the function should consider processing
1322  * Sublinks as well.
1323  */
1324 static bool
ExtractFromExpressionWalker(Node * node,QualifierWalkerContext * walkerContext)1325 ExtractFromExpressionWalker(Node *node, QualifierWalkerContext *walkerContext)
1326 {
1327 	if (node == NULL)
1328 	{
1329 		return false;
1330 	}
1331 
1332 	/*
1333 	 * Get qualifier lists of join and from expression nodes. Note that in the
1334 	 * case of subqueries, PostgreSQL can skip simplifying, flattening and
1335 	 * making ANDs implicit. If qualifiers node is not a list, then we run these
1336 	 * preprocess routines on qualifiers node.
1337 	 */
1338 	if (IsA(node, JoinExpr))
1339 	{
1340 		List *joinQualifierList = NIL;
1341 		JoinExpr *joinExpression = (JoinExpr *) node;
1342 		Node *joinQualifiersNode = joinExpression->quals;
1343 		JoinType joinType = joinExpression->jointype;
1344 
1345 		if (joinQualifiersNode != NULL)
1346 		{
1347 			if (IsA(joinQualifiersNode, List))
1348 			{
1349 				joinQualifierList = (List *) joinQualifiersNode;
1350 			}
1351 			else
1352 			{
1353 				/* this part of code only run for subqueries */
1354 				Node *joinClause = eval_const_expressions(NULL, joinQualifiersNode);
1355 				joinClause = (Node *) canonicalize_qual((Expr *) joinClause, false);
1356 				joinQualifierList = make_ands_implicit((Expr *) joinClause);
1357 			}
1358 		}
1359 
1360 		/* return outer join clauses in a separate list */
1361 		if (joinType == JOIN_INNER || joinType == JOIN_SEMI)
1362 		{
1363 			walkerContext->baseQualifierList =
1364 				list_concat(walkerContext->baseQualifierList, joinQualifierList);
1365 		}
1366 		else if (IS_OUTER_JOIN(joinType))
1367 		{
1368 			walkerContext->outerJoinQualifierList =
1369 				list_concat(walkerContext->outerJoinQualifierList, joinQualifierList);
1370 		}
1371 	}
1372 	else if (IsA(node, FromExpr))
1373 	{
1374 		List *fromQualifierList = NIL;
1375 		FromExpr *fromExpression = (FromExpr *) node;
1376 		Node *fromQualifiersNode = fromExpression->quals;
1377 
1378 		if (fromQualifiersNode != NULL)
1379 		{
1380 			if (IsA(fromQualifiersNode, List))
1381 			{
1382 				fromQualifierList = (List *) fromQualifiersNode;
1383 			}
1384 			else
1385 			{
1386 				/* this part of code only run for subqueries */
1387 				Node *fromClause = eval_const_expressions(NULL, fromQualifiersNode);
1388 				fromClause = (Node *) canonicalize_qual((Expr *) fromClause, false);
1389 				fromQualifierList = make_ands_implicit((Expr *) fromClause);
1390 			}
1391 
1392 			walkerContext->baseQualifierList =
1393 				list_concat(walkerContext->baseQualifierList, fromQualifierList);
1394 		}
1395 	}
1396 
1397 	bool walkerResult = expression_tree_walker(node, ExtractFromExpressionWalker,
1398 											   (void *) walkerContext);
1399 
1400 	return walkerResult;
1401 }
1402 
1403 
1404 /*
1405  * IsJoinClause determines if the given node is a join clause according to our
1406  * criteria. Our criteria defines a join clause as an equi join operator between
1407  * two columns that belong to two different tables.
1408  */
1409 bool
IsJoinClause(Node * clause)1410 IsJoinClause(Node *clause)
1411 {
1412 	Var *var = NULL;
1413 
1414 	/*
1415 	 * take all column references from the clause, if we find 2 column references from a
1416 	 * different relation we assume this is a join clause
1417 	 */
1418 	List *varList = pull_var_clause_default(clause);
1419 	if (list_length(varList) <= 0)
1420 	{
1421 		/* no column references in query, not describing a join */
1422 		return false;
1423 	}
1424 	Var *initialVar = castNode(Var, linitial(varList));
1425 
1426 	foreach_ptr(var, varList)
1427 	{
1428 		if (var->varno != initialVar->varno)
1429 		{
1430 			/*
1431 			 * this column reference comes from a different relation, hence describing a
1432 			 * join
1433 			 */
1434 			return true;
1435 		}
1436 	}
1437 
1438 	/* all column references were to the same relation, no join */
1439 	return false;
1440 }
1441 
1442 
1443 /*
1444  * TableEntryList finds the regular relation nodes in the range table entry
1445  * list, and builds a list of table entries from these regular relation nodes.
1446  */
1447 List *
TableEntryList(List * rangeTableList)1448 TableEntryList(List *rangeTableList)
1449 {
1450 	List *tableEntryList = NIL;
1451 	ListCell *rangeTableCell = NULL;
1452 	uint32 tableId = 1; /* range table indices start at 1 */
1453 
1454 	foreach(rangeTableCell, rangeTableList)
1455 	{
1456 		RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
1457 
1458 		if (rangeTableEntry->rtekind == RTE_RELATION)
1459 		{
1460 			TableEntry *tableEntry = (TableEntry *) palloc0(sizeof(TableEntry));
1461 			tableEntry->relationId = rangeTableEntry->relid;
1462 			tableEntry->rangeTableId = tableId;
1463 
1464 			tableEntryList = lappend(tableEntryList, tableEntry);
1465 		}
1466 
1467 		/*
1468 		 * Increment tableId regardless so that table entry's tableId remains
1469 		 * congruent with column's range table reference (varno).
1470 		 */
1471 		tableId++;
1472 	}
1473 
1474 	return tableEntryList;
1475 }
1476 
1477 
1478 /*
1479  * UsedTableEntryList returns list of relation range table entries
1480  * that are referenced within the query. Unused entries due to query
1481  * flattening or re-rewriting are ignored.
1482  */
1483 List *
UsedTableEntryList(Query * query)1484 UsedTableEntryList(Query *query)
1485 {
1486 	List *tableEntryList = NIL;
1487 	List *rangeTableList = query->rtable;
1488 	List *joinTreeTableIndexList = NIL;
1489 	ListCell *joinTreeTableIndexCell = NULL;
1490 
1491 	ExtractRangeTableIndexWalker((Node *) query->jointree, &joinTreeTableIndexList);
1492 	foreach(joinTreeTableIndexCell, joinTreeTableIndexList)
1493 	{
1494 		int joinTreeTableIndex = lfirst_int(joinTreeTableIndexCell);
1495 		RangeTblEntry *rangeTableEntry = rt_fetch(joinTreeTableIndex, rangeTableList);
1496 		if (rangeTableEntry->rtekind == RTE_RELATION)
1497 		{
1498 			TableEntry *tableEntry = (TableEntry *) palloc0(sizeof(TableEntry));
1499 			tableEntry->relationId = rangeTableEntry->relid;
1500 			tableEntry->rangeTableId = joinTreeTableIndex;
1501 
1502 			tableEntryList = lappend(tableEntryList, tableEntry);
1503 		}
1504 	}
1505 
1506 	return tableEntryList;
1507 }
1508 
1509 
1510 /*
1511  * MultiTableNodeList builds a list of MultiTable nodes from the given table
1512  * entry list. A multi table node represents one entry from the range table
1513  * list. These entries may belong to the same physical relation in the case of
1514  * self-joins.
1515  */
1516 static List *
MultiTableNodeList(List * tableEntryList,List * rangeTableList)1517 MultiTableNodeList(List *tableEntryList, List *rangeTableList)
1518 {
1519 	List *tableNodeList = NIL;
1520 	ListCell *tableEntryCell = NULL;
1521 
1522 	foreach(tableEntryCell, tableEntryList)
1523 	{
1524 		TableEntry *tableEntry = (TableEntry *) lfirst(tableEntryCell);
1525 		Oid relationId = tableEntry->relationId;
1526 		uint32 rangeTableId = tableEntry->rangeTableId;
1527 		Var *partitionColumn = PartitionColumn(relationId, rangeTableId);
1528 		RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList);
1529 
1530 		MultiTable *tableNode = CitusMakeNode(MultiTable);
1531 		tableNode->subquery = NULL;
1532 		tableNode->relationId = relationId;
1533 		tableNode->rangeTableId = rangeTableId;
1534 		tableNode->partitionColumn = partitionColumn;
1535 		tableNode->alias = rangeTableEntry->alias;
1536 		tableNode->referenceNames = rangeTableEntry->eref;
1537 		tableNode->includePartitions = GetOriginalInh(rangeTableEntry);
1538 
1539 		tableNodeList = lappend(tableNodeList, tableNode);
1540 	}
1541 
1542 	return tableNodeList;
1543 }
1544 
1545 
1546 /* Adds a MultiCollect node on top of each MultiTable node in the given list. */
1547 static List *
AddMultiCollectNodes(List * tableNodeList)1548 AddMultiCollectNodes(List *tableNodeList)
1549 {
1550 	List *collectTableList = NIL;
1551 	ListCell *tableNodeCell = NULL;
1552 
1553 	foreach(tableNodeCell, tableNodeList)
1554 	{
1555 		MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
1556 
1557 		MultiCollect *collectNode = CitusMakeNode(MultiCollect);
1558 		SetChild((MultiUnaryNode *) collectNode, (MultiNode *) tableNode);
1559 
1560 		collectTableList = lappend(collectTableList, collectNode);
1561 	}
1562 
1563 	return collectTableList;
1564 }
1565 
1566 
1567 /*
1568  * MultiJoinTree takes in the join order information and the list of tables, and
1569  * builds a join tree by applying the corresponding join rules. The function
1570  * builds a left deep tree, as expressed by the join order list.
1571  *
1572  * The function starts by setting the first table as the top node in the join
1573  * tree. Then, the function iterates over the list of tables, and builds a new
1574  * join node between the top of the join tree and the next table in the list.
1575  * At each iteration, the function sets the top of the join tree to the newly
1576  * built list. This results in a left deep join tree, and the function returns
1577  * this tree after every table in the list has been joined.
1578  */
1579 static MultiNode *
MultiJoinTree(List * joinOrderList,List * collectTableList,List * joinWhereClauseList)1580 MultiJoinTree(List *joinOrderList, List *collectTableList, List *joinWhereClauseList)
1581 {
1582 	MultiNode *currentTopNode = NULL;
1583 	ListCell *joinOrderCell = NULL;
1584 	bool firstJoinNode = true;
1585 
1586 	foreach(joinOrderCell, joinOrderList)
1587 	{
1588 		JoinOrderNode *joinOrderNode = (JoinOrderNode *) lfirst(joinOrderCell);
1589 		uint32 joinTableId = joinOrderNode->tableEntry->rangeTableId;
1590 		MultiCollect *collectNode = CollectNodeForTable(collectTableList, joinTableId);
1591 
1592 		if (firstJoinNode)
1593 		{
1594 			currentTopNode = (MultiNode *) collectNode;
1595 			firstJoinNode = false;
1596 		}
1597 		else
1598 		{
1599 			JoinRuleType joinRuleType = joinOrderNode->joinRuleType;
1600 			JoinType joinType = joinOrderNode->joinType;
1601 			List *partitionColumnList = joinOrderNode->partitionColumnList;
1602 			List *joinClauseList = joinOrderNode->joinClauseList;
1603 
1604 			/*
1605 			 * Build a join node between the top of our join tree and the next
1606 			 * table in the join order.
1607 			 */
1608 			MultiNode *newJoinNode = ApplyJoinRule(currentTopNode,
1609 												   (MultiNode *) collectNode,
1610 												   joinRuleType, partitionColumnList,
1611 												   joinType,
1612 												   joinClauseList);
1613 
1614 			/* the new join node becomes the top of our join tree */
1615 			currentTopNode = newJoinNode;
1616 		}
1617 	}
1618 
1619 	/* current top node points to the entire left deep join tree */
1620 	return currentTopNode;
1621 }
1622 
1623 
1624 /*
1625  * CollectNodeForTable finds the MultiCollect node whose MultiTable node has the
1626  * given range table identifier. Note that this function expects each collect
1627  * node in the given list to have one table node as its child.
1628  */
1629 static MultiCollect *
CollectNodeForTable(List * collectTableList,uint32 rangeTableId)1630 CollectNodeForTable(List *collectTableList, uint32 rangeTableId)
1631 {
1632 	MultiCollect *collectNodeForTable = NULL;
1633 	ListCell *collectTableCell = NULL;
1634 
1635 	foreach(collectTableCell, collectTableList)
1636 	{
1637 		MultiCollect *collectNode = (MultiCollect *) lfirst(collectTableCell);
1638 
1639 		List *tableIdList = OutputTableIdList((MultiNode *) collectNode);
1640 		uint32 tableId = (uint32) linitial_int(tableIdList);
1641 		Assert(list_length(tableIdList) == 1);
1642 
1643 		if (tableId == rangeTableId)
1644 		{
1645 			collectNodeForTable = collectNode;
1646 			break;
1647 		}
1648 	}
1649 
1650 	Assert(collectNodeForTable != NULL);
1651 	return collectNodeForTable;
1652 }
1653 
1654 
1655 /*
1656  * MultiSelectNode extracts the select clauses from the given where clause list,
1657  * and builds a MultiSelect node from these clauses. If the expression tree does
1658  * not have any select clauses, the function return null.
1659  */
1660 static MultiSelect *
MultiSelectNode(List * whereClauseList)1661 MultiSelectNode(List *whereClauseList)
1662 {
1663 	List *selectClauseList = NIL;
1664 	MultiSelect *selectNode = NULL;
1665 
1666 	ListCell *whereClauseCell = NULL;
1667 	foreach(whereClauseCell, whereClauseList)
1668 	{
1669 		Node *whereClause = (Node *) lfirst(whereClauseCell);
1670 		if (IsSelectClause(whereClause))
1671 		{
1672 			selectClauseList = lappend(selectClauseList, whereClause);
1673 		}
1674 	}
1675 
1676 	if (list_length(selectClauseList) > 0)
1677 	{
1678 		selectNode = CitusMakeNode(MultiSelect);
1679 		selectNode->selectClauseList = selectClauseList;
1680 	}
1681 
1682 	return selectNode;
1683 }
1684 
1685 
1686 /*
1687  * IsSelectClause determines if the given node is a select clause according to
1688  * our criteria. Our criteria defines a select clause as an expression that has
1689  * zero or more columns belonging to only one table. The function assumes that
1690  * no sublinks exists in the clause.
1691  */
1692 static bool
IsSelectClause(Node * clause)1693 IsSelectClause(Node *clause)
1694 {
1695 	ListCell *columnCell = NULL;
1696 	bool isSelectClause = true;
1697 
1698 	/* extract columns from the clause */
1699 	List *columnList = pull_var_clause_default(clause);
1700 	if (list_length(columnList) == 0)
1701 	{
1702 		return true;
1703 	}
1704 
1705 	/* get first column's tableId */
1706 	Var *firstColumn = (Var *) linitial(columnList);
1707 	Index firstColumnTableId = firstColumn->varno;
1708 
1709 	/* check if all columns are from the same table */
1710 	foreach(columnCell, columnList)
1711 	{
1712 		Var *column = (Var *) lfirst(columnCell);
1713 		if (column->varno != firstColumnTableId)
1714 		{
1715 			isSelectClause = false;
1716 		}
1717 	}
1718 
1719 	return isSelectClause;
1720 }
1721 
1722 
1723 /*
1724  * MultiProjectNode builds the project node using the target entry information
1725  * from the query tree. The project node only encapsulates projected columns,
1726  * and does not include aggregates, group clauses, or project expressions.
1727  */
1728 MultiProject *
MultiProjectNode(List * targetEntryList)1729 MultiProjectNode(List *targetEntryList)
1730 {
1731 	List *uniqueColumnList = NIL;
1732 	ListCell *columnCell = NULL;
1733 
1734 	/* extract the list of columns and remove any duplicates */
1735 	List *columnList = pull_var_clause_default((Node *) targetEntryList);
1736 	foreach(columnCell, columnList)
1737 	{
1738 		Var *column = (Var *) lfirst(columnCell);
1739 
1740 		uniqueColumnList = list_append_unique(uniqueColumnList, column);
1741 	}
1742 
1743 	/* create project node with list of columns to project */
1744 	MultiProject *projectNode = CitusMakeNode(MultiProject);
1745 	projectNode->columnList = uniqueColumnList;
1746 
1747 	return projectNode;
1748 }
1749 
1750 
1751 /* Builds the extended operator node using fields from the given query tree. */
1752 MultiExtendedOp *
MultiExtendedOpNode(Query * queryTree,Query * originalQuery)1753 MultiExtendedOpNode(Query *queryTree, Query *originalQuery)
1754 {
1755 	MultiExtendedOp *extendedOpNode = CitusMakeNode(MultiExtendedOp);
1756 	extendedOpNode->targetList = queryTree->targetList;
1757 	extendedOpNode->groupClauseList = queryTree->groupClause;
1758 	extendedOpNode->sortClauseList = queryTree->sortClause;
1759 	extendedOpNode->limitCount = queryTree->limitCount;
1760 	extendedOpNode->limitOffset = queryTree->limitOffset;
1761 #if PG_VERSION_NUM >= PG_VERSION_13
1762 	extendedOpNode->limitOption = queryTree->limitOption;
1763 #endif
1764 	extendedOpNode->havingQual = queryTree->havingQual;
1765 	extendedOpNode->distinctClause = queryTree->distinctClause;
1766 	extendedOpNode->hasDistinctOn = queryTree->hasDistinctOn;
1767 	extendedOpNode->hasWindowFuncs = queryTree->hasWindowFuncs;
1768 	extendedOpNode->windowClause = queryTree->windowClause;
1769 	extendedOpNode->onlyPushableWindowFunctions =
1770 		!queryTree->hasWindowFuncs ||
1771 		SafeToPushdownWindowFunction(originalQuery, NULL);
1772 
1773 	return extendedOpNode;
1774 }
1775 
1776 
1777 /* Helper function to return the parent node of the given node. */
1778 MultiNode *
ParentNode(MultiNode * multiNode)1779 ParentNode(MultiNode *multiNode)
1780 {
1781 	MultiNode *parentNode = multiNode->parentNode;
1782 	return parentNode;
1783 }
1784 
1785 
1786 /* Helper function to return the child of the given unary node. */
1787 MultiNode *
ChildNode(MultiUnaryNode * multiNode)1788 ChildNode(MultiUnaryNode *multiNode)
1789 {
1790 	MultiNode *childNode = multiNode->childNode;
1791 	return childNode;
1792 }
1793 
1794 
1795 /* Helper function to return the grand child of the given unary node. */
1796 MultiNode *
GrandChildNode(MultiUnaryNode * multiNode)1797 GrandChildNode(MultiUnaryNode *multiNode)
1798 {
1799 	MultiNode *childNode = ChildNode(multiNode);
1800 	MultiNode *grandChildNode = ChildNode((MultiUnaryNode *) childNode);
1801 
1802 	return grandChildNode;
1803 }
1804 
1805 
1806 /* Sets the given child node as a child of the given unary parent node. */
1807 void
SetChild(MultiUnaryNode * parent,MultiNode * child)1808 SetChild(MultiUnaryNode *parent, MultiNode *child)
1809 {
1810 	parent->childNode = child;
1811 	child->parentNode = (MultiNode *) parent;
1812 }
1813 
1814 
1815 /* Sets the given child node as a left child of the given parent node. */
1816 void
SetLeftChild(MultiBinaryNode * parent,MultiNode * leftChild)1817 SetLeftChild(MultiBinaryNode *parent, MultiNode *leftChild)
1818 {
1819 	parent->leftChildNode = leftChild;
1820 	leftChild->parentNode = (MultiNode *) parent;
1821 }
1822 
1823 
1824 /* Sets the given child node as a right child of the given parent node. */
1825 void
SetRightChild(MultiBinaryNode * parent,MultiNode * rightChild)1826 SetRightChild(MultiBinaryNode *parent, MultiNode *rightChild)
1827 {
1828 	parent->rightChildNode = rightChild;
1829 	rightChild->parentNode = (MultiNode *) parent;
1830 }
1831 
1832 
1833 /* Returns true if the given node is a unary operator. */
1834 bool
UnaryOperator(MultiNode * node)1835 UnaryOperator(MultiNode *node)
1836 {
1837 	bool unaryOperator = false;
1838 
1839 	if (CitusIsA(node, MultiTreeRoot) || CitusIsA(node, MultiTable) ||
1840 		CitusIsA(node, MultiCollect) || CitusIsA(node, MultiSelect) ||
1841 		CitusIsA(node, MultiProject) || CitusIsA(node, MultiPartition) ||
1842 		CitusIsA(node, MultiExtendedOp))
1843 	{
1844 		unaryOperator = true;
1845 	}
1846 
1847 	return unaryOperator;
1848 }
1849 
1850 
1851 /* Returns true if the given node is a binary operator. */
1852 bool
BinaryOperator(MultiNode * node)1853 BinaryOperator(MultiNode *node)
1854 {
1855 	bool binaryOperator = false;
1856 
1857 	if (CitusIsA(node, MultiJoin) || CitusIsA(node, MultiCartesianProduct))
1858 	{
1859 		binaryOperator = true;
1860 	}
1861 
1862 	return binaryOperator;
1863 }
1864 
1865 
1866 /*
1867  * OutputTableIdList finds all table identifiers that are output by the given
1868  * multi node, and returns these identifiers in a new list.
1869  */
1870 List *
OutputTableIdList(MultiNode * multiNode)1871 OutputTableIdList(MultiNode *multiNode)
1872 {
1873 	List *tableIdList = NIL;
1874 	List *tableNodeList = FindNodesOfType(multiNode, T_MultiTable);
1875 	ListCell *tableNodeCell = NULL;
1876 
1877 	foreach(tableNodeCell, tableNodeList)
1878 	{
1879 		MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
1880 		int tableId = (int) tableNode->rangeTableId;
1881 
1882 		if (tableId != SUBQUERY_RANGE_TABLE_ID)
1883 		{
1884 			tableIdList = lappend_int(tableIdList, tableId);
1885 		}
1886 	}
1887 
1888 	return tableIdList;
1889 }
1890 
1891 
1892 /*
1893  * FindNodesOfType takes in a given logical plan tree, and recursively traverses
1894  * the tree in preorder. The function finds all nodes of requested type during
1895  * the traversal, and returns them in a list.
1896  */
1897 List *
FindNodesOfType(MultiNode * node,int type)1898 FindNodesOfType(MultiNode *node, int type)
1899 {
1900 	List *nodeList = NIL;
1901 
1902 	/* terminal condition for recursion */
1903 	if (node == NULL)
1904 	{
1905 		return NIL;
1906 	}
1907 
1908 	/* current node has expected node type */
1909 	int nodeType = CitusNodeTag(node);
1910 	if (nodeType == type)
1911 	{
1912 		nodeList = lappend(nodeList, node);
1913 	}
1914 
1915 	if (UnaryOperator(node))
1916 	{
1917 		MultiNode *childNode = ((MultiUnaryNode *) node)->childNode;
1918 		List *childNodeList = FindNodesOfType(childNode, type);
1919 
1920 		nodeList = list_concat(nodeList, childNodeList);
1921 	}
1922 	else if (BinaryOperator(node))
1923 	{
1924 		MultiNode *leftChildNode = ((MultiBinaryNode *) node)->leftChildNode;
1925 		MultiNode *rightChildNode = ((MultiBinaryNode *) node)->rightChildNode;
1926 
1927 		List *leftChildNodeList = FindNodesOfType(leftChildNode, type);
1928 		List *rightChildNodeList = FindNodesOfType(rightChildNode, type);
1929 
1930 		nodeList = list_concat(nodeList, leftChildNodeList);
1931 		nodeList = list_concat(nodeList, rightChildNodeList);
1932 	}
1933 
1934 	return nodeList;
1935 }
1936 
1937 
1938 /*
1939  * pull_var_clause_default calls pull_var_clause with the most commonly used
1940  * arguments for distributed planning.
1941  */
1942 List *
pull_var_clause_default(Node * node)1943 pull_var_clause_default(Node *node)
1944 {
1945 	/*
1946 	 * PVC_REJECT_PLACEHOLDERS is implicit if PVC_INCLUDE_PLACEHOLDERS
1947 	 * isn't specified.
1948 	 */
1949 	List *columnList = pull_var_clause(node, PVC_RECURSE_AGGREGATES |
1950 									   PVC_RECURSE_WINDOWFUNCS);
1951 
1952 	return columnList;
1953 }
1954 
1955 
1956 /*
1957  * ApplyJoinRule finds the join rule application function that corresponds to
1958  * the given join rule, and calls this function to create a new join node that
1959  * joins the left and right nodes together.
1960  */
1961 static MultiNode *
ApplyJoinRule(MultiNode * leftNode,MultiNode * rightNode,JoinRuleType ruleType,List * partitionColumnList,JoinType joinType,List * joinClauseList)1962 ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType,
1963 			  List *partitionColumnList, JoinType joinType, List *joinClauseList)
1964 {
1965 	List *leftTableIdList = OutputTableIdList(leftNode);
1966 	List *rightTableIdList = OutputTableIdList(rightNode);
1967 	int rightTableIdCount PG_USED_FOR_ASSERTS_ONLY = 0;
1968 
1969 	rightTableIdCount = list_length(rightTableIdList);
1970 	Assert(rightTableIdCount == 1);
1971 
1972 	/* find applicable join clauses between the left and right data sources */
1973 	uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
1974 	List *applicableJoinClauses = ApplicableJoinClauses(leftTableIdList, rightTableId,
1975 														joinClauseList);
1976 
1977 	/* call the join rule application function to create the new join node */
1978 	RuleApplyFunction ruleApplyFunction = JoinRuleApplyFunction(ruleType);
1979 	MultiNode *multiNode = (*ruleApplyFunction)(leftNode, rightNode, partitionColumnList,
1980 												joinType, applicableJoinClauses);
1981 
1982 	if (joinType != JOIN_INNER && CitusIsA(multiNode, MultiJoin))
1983 	{
1984 		MultiJoin *joinNode = (MultiJoin *) multiNode;
1985 
1986 		/* preserve non-join clauses for OUTER joins */
1987 		joinNode->joinClauseList = list_copy(joinClauseList);
1988 	}
1989 
1990 	return multiNode;
1991 }
1992 
1993 
1994 /*
1995  * JoinRuleApplyFunction returns a function pointer for the rule application
1996  * function; this rule application function corresponds to the given rule type.
1997  * This function also initializes the rule application function array in a
1998  * static code block, if the array has not been initialized.
1999  */
2000 static RuleApplyFunction
JoinRuleApplyFunction(JoinRuleType ruleType)2001 JoinRuleApplyFunction(JoinRuleType ruleType)
2002 {
2003 	static bool ruleApplyFunctionInitialized = false;
2004 
2005 	if (!ruleApplyFunctionInitialized)
2006 	{
2007 		RuleApplyFunctionArray[REFERENCE_JOIN] = &ApplyReferenceJoin;
2008 		RuleApplyFunctionArray[LOCAL_PARTITION_JOIN] = &ApplyLocalJoin;
2009 		RuleApplyFunctionArray[SINGLE_HASH_PARTITION_JOIN] =
2010 			&ApplySingleHashPartitionJoin;
2011 		RuleApplyFunctionArray[SINGLE_RANGE_PARTITION_JOIN] =
2012 			&ApplySingleRangePartitionJoin;
2013 		RuleApplyFunctionArray[DUAL_PARTITION_JOIN] = &ApplyDualPartitionJoin;
2014 		RuleApplyFunctionArray[CARTESIAN_PRODUCT_REFERENCE_JOIN] =
2015 			&ApplyCartesianProductReferenceJoin;
2016 		RuleApplyFunctionArray[CARTESIAN_PRODUCT] = &ApplyCartesianProduct;
2017 
2018 		ruleApplyFunctionInitialized = true;
2019 	}
2020 
2021 	RuleApplyFunction ruleApplyFunction = RuleApplyFunctionArray[ruleType];
2022 	Assert(ruleApplyFunction != NULL);
2023 
2024 	return ruleApplyFunction;
2025 }
2026 
2027 
2028 /*
2029  * ApplyBroadcastJoin creates a new MultiJoin node that joins the left and the
2030  * right node. The new node uses the broadcast join rule to perform the join.
2031  */
2032 static MultiNode *
ApplyReferenceJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2033 ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
2034 				   List *partitionColumnList, JoinType joinType,
2035 				   List *applicableJoinClauses)
2036 {
2037 	MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2038 	joinNode->joinRuleType = REFERENCE_JOIN;
2039 	joinNode->joinType = joinType;
2040 	joinNode->joinClauseList = applicableJoinClauses;
2041 
2042 	SetLeftChild((MultiBinaryNode *) joinNode, leftNode);
2043 	SetRightChild((MultiBinaryNode *) joinNode, rightNode);
2044 
2045 	return (MultiNode *) joinNode;
2046 }
2047 
2048 
2049 /*
2050  * ApplyCartesianProductReferenceJoin creates a new MultiJoin node that joins
2051  * the left and the right node. The new node uses the broadcast join rule to
2052  * perform the join.
2053  */
2054 static MultiNode *
ApplyCartesianProductReferenceJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2055 ApplyCartesianProductReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
2056 								   List *partitionColumnList, JoinType joinType,
2057 								   List *applicableJoinClauses)
2058 {
2059 	MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2060 	joinNode->joinRuleType = CARTESIAN_PRODUCT_REFERENCE_JOIN;
2061 	joinNode->joinType = joinType;
2062 	joinNode->joinClauseList = applicableJoinClauses;
2063 
2064 	SetLeftChild((MultiBinaryNode *) joinNode, leftNode);
2065 	SetRightChild((MultiBinaryNode *) joinNode, rightNode);
2066 
2067 	return (MultiNode *) joinNode;
2068 }
2069 
2070 
2071 /*
2072  * ApplyLocalJoin creates a new MultiJoin node that joins the left and the right
2073  * node. The new node uses the local join rule to perform the join.
2074  */
2075 static MultiNode *
ApplyLocalJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2076 ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode,
2077 			   List *partitionColumnList, JoinType joinType,
2078 			   List *applicableJoinClauses)
2079 {
2080 	MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2081 	joinNode->joinRuleType = LOCAL_PARTITION_JOIN;
2082 	joinNode->joinType = joinType;
2083 	joinNode->joinClauseList = applicableJoinClauses;
2084 
2085 	SetLeftChild((MultiBinaryNode *) joinNode, leftNode);
2086 	SetRightChild((MultiBinaryNode *) joinNode, rightNode);
2087 
2088 	return (MultiNode *) joinNode;
2089 }
2090 
2091 
2092 /*
2093  * ApplySingleRangePartitionJoin is a wrapper around ApplySinglePartitionJoin()
2094  * which sets the joinRuleType properly.
2095  */
2096 static MultiNode *
ApplySingleRangePartitionJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2097 ApplySingleRangePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
2098 							  List *partitionColumnList, JoinType joinType,
2099 							  List *applicableJoinClauses)
2100 {
2101 	MultiJoin *joinNode =
2102 		ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnList, joinType,
2103 								 applicableJoinClauses);
2104 
2105 	joinNode->joinRuleType = SINGLE_RANGE_PARTITION_JOIN;
2106 
2107 	return (MultiNode *) joinNode;
2108 }
2109 
2110 
2111 /*
2112  * ApplySingleHashPartitionJoin is a wrapper around ApplySinglePartitionJoin()
2113  * which sets the joinRuleType properly.
2114  */
2115 static MultiNode *
ApplySingleHashPartitionJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2116 ApplySingleHashPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
2117 							 List *partitionColumnList, JoinType joinType,
2118 							 List *applicableJoinClauses)
2119 {
2120 	MultiJoin *joinNode =
2121 		ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnList, joinType,
2122 								 applicableJoinClauses);
2123 
2124 	joinNode->joinRuleType = SINGLE_HASH_PARTITION_JOIN;
2125 
2126 	return (MultiNode *) joinNode;
2127 }
2128 
2129 
2130 /*
2131  * ApplySinglePartitionJoin creates a new MultiJoin node that joins the left and
2132  * right node. The function also adds a MultiPartition node on top of the node
2133  * (left or right) that is not partitioned on the join column.
2134  */
2135 static MultiJoin *
ApplySinglePartitionJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2136 ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
2137 						 List *partitionColumnList, JoinType joinType,
2138 						 List *applicableJoinClauses)
2139 {
2140 	Var *partitionColumn = linitial(partitionColumnList);
2141 	uint32 partitionTableId = partitionColumn->varno;
2142 
2143 	/* create all operator structures up front */
2144 	MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2145 	MultiCollect *collectNode = CitusMakeNode(MultiCollect);
2146 	MultiPartition *partitionNode = CitusMakeNode(MultiPartition);
2147 
2148 	/*
2149 	 * We first find the appropriate join clause. Then, we compare the partition
2150 	 * column against the join clause's columns. If one of the columns matches,
2151 	 * we introduce a (re-)partition operator for the other column.
2152 	 */
2153 	OpExpr *joinClause = SinglePartitionJoinClause(partitionColumnList,
2154 												   applicableJoinClauses);
2155 	Assert(joinClause != NULL);
2156 
2157 	/* both are verified in SinglePartitionJoinClause to not be NULL, assert is to guard */
2158 	Var *leftColumn = LeftColumnOrNULL(joinClause);
2159 	Var *rightColumn = RightColumnOrNULL(joinClause);
2160 
2161 	Assert(leftColumn != NULL);
2162 	Assert(rightColumn != NULL);
2163 
2164 	if (equal(partitionColumn, leftColumn))
2165 	{
2166 		partitionNode->partitionColumn = rightColumn;
2167 		partitionNode->splitPointTableId = partitionTableId;
2168 	}
2169 	else if (equal(partitionColumn, rightColumn))
2170 	{
2171 		partitionNode->partitionColumn = leftColumn;
2172 		partitionNode->splitPointTableId = partitionTableId;
2173 	}
2174 
2175 	/* determine the node the partition operator goes on top of */
2176 	List *rightTableIdList = OutputTableIdList(rightNode);
2177 	uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
2178 	Assert(list_length(rightTableIdList) == 1);
2179 
2180 	/*
2181 	 * If the right child node is partitioned on the partition key column, we
2182 	 * add the partition operator on the left child node; and vice versa. Then,
2183 	 * we add a collect operator on top of the partition operator, and always
2184 	 * make sure that we have at most one relation on the right-hand side.
2185 	 */
2186 	if (partitionTableId == rightTableId)
2187 	{
2188 		SetChild((MultiUnaryNode *) partitionNode, leftNode);
2189 		SetChild((MultiUnaryNode *) collectNode, (MultiNode *) partitionNode);
2190 
2191 		SetLeftChild((MultiBinaryNode *) joinNode, (MultiNode *) collectNode);
2192 		SetRightChild((MultiBinaryNode *) joinNode, rightNode);
2193 	}
2194 	else
2195 	{
2196 		SetChild((MultiUnaryNode *) partitionNode, rightNode);
2197 		SetChild((MultiUnaryNode *) collectNode, (MultiNode *) partitionNode);
2198 
2199 		SetLeftChild((MultiBinaryNode *) joinNode, leftNode);
2200 		SetRightChild((MultiBinaryNode *) joinNode, (MultiNode *) collectNode);
2201 	}
2202 
2203 	/* finally set join operator fields */
2204 	joinNode->joinType = joinType;
2205 	joinNode->joinClauseList = applicableJoinClauses;
2206 
2207 	return joinNode;
2208 }
2209 
2210 
2211 /*
2212  * ApplyDualPartitionJoin creates a new MultiJoin node that joins the left and
2213  * right node. The function also adds two MultiPartition operators on top of
2214  * both nodes to repartition these nodes' data on the join clause columns.
2215  */
2216 static MultiNode *
ApplyDualPartitionJoin(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2217 ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
2218 					   List *partitionColumnList, JoinType joinType,
2219 					   List *applicableJoinClauses)
2220 {
2221 	/* find the appropriate join clause */
2222 	OpExpr *joinClause = DualPartitionJoinClause(applicableJoinClauses);
2223 	Assert(joinClause != NULL);
2224 
2225 	/* both are verified in DualPartitionJoinClause to not be NULL, assert is to guard */
2226 	Var *leftColumn = LeftColumnOrNULL(joinClause);
2227 	Var *rightColumn = RightColumnOrNULL(joinClause);
2228 	Assert(leftColumn != NULL);
2229 	Assert(rightColumn != NULL);
2230 
2231 	List *rightTableIdList = OutputTableIdList(rightNode);
2232 	uint32 rightTableId = (uint32) linitial_int(rightTableIdList);
2233 	Assert(list_length(rightTableIdList) == 1);
2234 
2235 	MultiPartition *leftPartitionNode = CitusMakeNode(MultiPartition);
2236 	MultiPartition *rightPartitionNode = CitusMakeNode(MultiPartition);
2237 
2238 	/* find the partition node each join clause column belongs to */
2239 	if (leftColumn->varno == rightTableId)
2240 	{
2241 		leftPartitionNode->partitionColumn = rightColumn;
2242 		rightPartitionNode->partitionColumn = leftColumn;
2243 	}
2244 	else
2245 	{
2246 		leftPartitionNode->partitionColumn = leftColumn;
2247 		rightPartitionNode->partitionColumn = rightColumn;
2248 	}
2249 
2250 	/* add partition operators on top of left and right nodes */
2251 	SetChild((MultiUnaryNode *) leftPartitionNode, leftNode);
2252 	SetChild((MultiUnaryNode *) rightPartitionNode, rightNode);
2253 
2254 	/* add collect operators on top of the two partition operators */
2255 	MultiCollect *leftCollectNode = CitusMakeNode(MultiCollect);
2256 	MultiCollect *rightCollectNode = CitusMakeNode(MultiCollect);
2257 
2258 	SetChild((MultiUnaryNode *) leftCollectNode, (MultiNode *) leftPartitionNode);
2259 	SetChild((MultiUnaryNode *) rightCollectNode, (MultiNode *) rightPartitionNode);
2260 
2261 	/* add join operator on top of the two collect operators */
2262 	MultiJoin *joinNode = CitusMakeNode(MultiJoin);
2263 	joinNode->joinRuleType = DUAL_PARTITION_JOIN;
2264 	joinNode->joinType = joinType;
2265 	joinNode->joinClauseList = applicableJoinClauses;
2266 
2267 	SetLeftChild((MultiBinaryNode *) joinNode, (MultiNode *) leftCollectNode);
2268 	SetRightChild((MultiBinaryNode *) joinNode, (MultiNode *) rightCollectNode);
2269 
2270 	return (MultiNode *) joinNode;
2271 }
2272 
2273 
2274 /* Creates a cartesian product node that joins the left and the right node. */
2275 static MultiNode *
ApplyCartesianProduct(MultiNode * leftNode,MultiNode * rightNode,List * partitionColumnList,JoinType joinType,List * applicableJoinClauses)2276 ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode,
2277 					  List *partitionColumnList, JoinType joinType,
2278 					  List *applicableJoinClauses)
2279 {
2280 	MultiCartesianProduct *cartesianNode = CitusMakeNode(MultiCartesianProduct);
2281 
2282 	SetLeftChild((MultiBinaryNode *) cartesianNode, leftNode);
2283 	SetRightChild((MultiBinaryNode *) cartesianNode, rightNode);
2284 
2285 	return (MultiNode *) cartesianNode;
2286 }
2287 
2288 
2289 /*
2290  * OperatorImplementsEquality returns true if the given opno represents an
2291  * equality operator. The function retrieves btree interpretation list for this
2292  * opno and check if BTEqualStrategyNumber strategy is present.
2293  */
2294 bool
OperatorImplementsEquality(Oid opno)2295 OperatorImplementsEquality(Oid opno)
2296 {
2297 	bool equalityOperator = false;
2298 	List *btreeIntepretationList = get_op_btree_interpretation(opno);
2299 	ListCell *btreeInterpretationCell = NULL;
2300 	foreach(btreeInterpretationCell, btreeIntepretationList)
2301 	{
2302 		OpBtreeInterpretation *btreeIntepretation = (OpBtreeInterpretation *)
2303 													lfirst(btreeInterpretationCell);
2304 		if (btreeIntepretation->strategy == BTEqualStrategyNumber)
2305 		{
2306 			equalityOperator = true;
2307 			break;
2308 		}
2309 	}
2310 
2311 	return equalityOperator;
2312 }
2313