/*-------------------------------------------------------------------------
 *
 * multi_physical_planner.c
 *	  Routines for creating physical plans from given multi-relational algebra
 *	  trees.
 *
 * Copyright (c) Citus Data, Inc.
 *
 * $Id$
 *
 *-------------------------------------------------------------------------
 */

#include "postgres.h"

#include "distributed/pg_version_constants.h"

#include <math.h>
#include <stdint.h>

#include "miscadmin.h"

#include "access/genam.h"
#include "access/hash.h"
#include "access/heapam.h"
#include "access/nbtree.h"
#include "access/skey.h"
#include "access/xlog.h"
#include "catalog/pg_aggregate.h"
#include "catalog/pg_am.h"
#include "catalog/pg_operator.h"
#include "catalog/pg_type.h"
#include "commands/defrem.h"
#include "commands/sequence.h"
#include "distributed/backend_data.h"
#include "distributed/listutils.h"
#include "distributed/citus_nodefuncs.h"
#include "distributed/citus_nodes.h"
#include "distributed/citus_ruleutils.h"
#include "distributed/colocation_utils.h"
#include "distributed/deparse_shard_query.h"
#include "distributed/coordinator_protocol.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_router_planner.h"
#include "distributed/multi_join_order.h"
#include "distributed/multi_logical_optimizer.h"
#include "distributed/multi_logical_planner.h"
#include "distributed/multi_partitioning_utils.h"
#include "distributed/multi_physical_planner.h"
#include "distributed/log_utils.h"
#include "distributed/pg_dist_partition.h"
#include "distributed/pg_dist_shard.h"
#include "distributed/query_pushdown_planning.h"
#include "distributed/query_utils.h"
#include "distributed/shardinterval_utils.h"
#include "distributed/shard_pruning.h"
#include "distributed/string_utils.h"

#include "distributed/worker_manager.h"
#include "distributed/worker_protocol.h"
#include "distributed/version_compat.h"
#include "nodes/makefuncs.h"
#include "nodes/nodeFuncs.h"
#include "optimizer/clauses.h"
#include "nodes/pathnodes.h"
#include "optimizer/optimizer.h"
#include "optimizer/restrictinfo.h"
#include "optimizer/tlist.h"
#include "parser/parse_relation.h"
#include "parser/parsetree.h"
#include "rewrite/rewriteManip.h"
#include "utils/builtins.h"
#include "utils/catcache.h"
#include "utils/datum.h"
#include "utils/fmgroids.h"
#include "utils/guc.h"
#include "utils/lsyscache.h"
#include "utils/memutils.h"
#include "utils/rel.h"
#include "utils/typcache.h"

/* RepartitionJoinBucketCountPerNode determines the bucket count per node during repartition joins */
int RepartitionJoinBucketCountPerNode = 8;
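
/*
 * An illustration of the scale this implies: assuming the total bucket count
 * is derived per node (as the variable's name suggests), a hypothetical
 * 4-worker cluster with the default of 8 would repartition join data into
 * roughly 32 buckets overall.
 */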

/* Policy to use when assigning tasks to worker nodes */
int TaskAssignmentPolicy = TASK_ASSIGNMENT_GREEDY;
bool EnableUniqueJobIds = true;


/*
 * OperatorCache is used for caching operator identifiers for given typeId,
 * accessMethodId and strategyNumber. It is initialized to the empty list,
 * as the cache starts out with no items.
 */
static List *OperatorCache = NIL;


/* context passed down in AddAnyValueAggregates mutator */
typedef struct AddAnyValueAggregatesContext
{
	/* SortGroupClauses corresponding to the GROUP BY clause */
	List *groupClauseList;

	/* TargetEntry's to which the GROUP BY clauses refer */
	List *groupByTargetEntryList;

	/*
	 * haveNonVarGrouping is true if there are expressions in the
	 * GROUP BY target entries. We use this as an optimization to
	 * skip expensive checks when possible.
	 */
	bool haveNonVarGrouping;
} AddAnyValueAggregatesContext;


/* Local functions forward declarations for job creation */
static Job * BuildJobTree(MultiTreeRoot *multiTree);
static MultiNode * LeftMostNode(MultiTreeRoot *multiTree);
static Oid RangePartitionJoinBaseRelationId(MultiJoin *joinNode);
static MultiTable * FindTableNode(MultiNode *multiNode, int rangeTableId);
static Query * BuildJobQuery(MultiNode *multiNode, List *dependentJobList);
static List * BaseRangeTableList(MultiNode *multiNode);
static List * QueryTargetList(MultiNode *multiNode);
static List * TargetEntryList(List *expressionList);
static Node * AddAnyValueAggregates(Node *node, AddAnyValueAggregatesContext *context);
static List * QueryGroupClauseList(MultiNode *multiNode);
static List * QuerySelectClauseList(MultiNode *multiNode);
static List * QueryFromList(List *rangeTableList);
static Node * QueryJoinTree(MultiNode *multiNode, List *dependentJobList,
							List **rangeTableList);
static void SetJoinRelatedColumnsCompat(RangeTblEntry *rangeTableEntry,
										Oid leftRelId,
										Oid rightRelId,
										List *leftColumnVars,
										List *rightColumnVars);
static RangeTblEntry * JoinRangeTableEntry(JoinExpr *joinExpr, List *dependentJobList,
										   List *rangeTableList);
static int ExtractRangeTableId(Node *node);
static void ExtractColumns(RangeTblEntry *callingRTE, int rangeTableId,
						   List **columnNames, List **columnVars);
static RangeTblEntry * ConstructCallingRTE(RangeTblEntry *rangeTableEntry,
										   List *dependentJobList);
static Query * BuildSubqueryJobQuery(MultiNode *multiNode);
static void UpdateAllColumnAttributes(Node *columnContainer, List *rangeTableList,
									  List *dependentJobList);
static void UpdateColumnAttributes(Var *column, List *rangeTableList,
								   List *dependentJobList);
static Index NewTableId(Index originalTableId, List *rangeTableList);
static AttrNumber NewColumnId(Index originalTableId, AttrNumber originalColumnId,
							  RangeTblEntry *newRangeTableEntry, List *dependentJobList);
static Job * JobForRangeTable(List *jobList, RangeTblEntry *rangeTableEntry);
static Job * JobForTableIdList(List *jobList, List *searchedTableIdList);
static List * ChildNodeList(MultiNode *multiNode);
static Job * BuildJob(Query *jobQuery, List *dependentJobList);
static MapMergeJob * BuildMapMergeJob(Query *jobQuery, List *dependentJobList,
									  Var *partitionKey, PartitionType partitionType,
									  Oid baseRelationId,
									  BoundaryNodeJobType boundaryNodeJobType);
static uint32 HashPartitionCount(void);
static ArrayType * SplitPointObject(ShardInterval **shardIntervalArray,
									uint32 shardIntervalCount);

/* Local functions forward declarations for task list creation and helper functions */
static Job * BuildJobTreeTaskList(Job *jobTree,
								  PlannerRestrictionContext *plannerRestrictionContext);
static bool IsInnerTableOfOuterJoin(RelationRestriction *relationRestriction);
static void ErrorIfUnsupportedShardDistribution(Query *query);
static Task * QueryPushdownTaskCreate(Query *originalQuery, int shardIndex,
									  RelationRestrictionContext *restrictionContext,
									  uint32 taskId,
									  TaskType taskType,
									  bool modifyRequiresCoordinatorEvaluation,
									  DeferredErrorMessage **planningError);
static bool ShardIntervalsEqual(FmgrInfo *comparisonFunction,
								Oid collation,
								ShardInterval *firstInterval,
								ShardInterval *secondInterval);
static List * SqlTaskList(Job *job);
static bool DependsOnHashPartitionJob(Job *job);
static uint32 AnchorRangeTableId(List *rangeTableList);
static List * BaseRangeTableIdList(List *rangeTableList);
static List * AnchorRangeTableIdList(List *rangeTableList, List *baseRangeTableIdList);
static void AdjustColumnOldAttributes(List *expressionList);
static List * RangeTableFragmentsList(List *rangeTableList, List *whereClauseList,
									  List *dependentJobList);
static OperatorCacheEntry * LookupOperatorByType(Oid typeId, Oid accessMethodId,
												 int16 strategyNumber);
static Oid GetOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber);
static List * FragmentCombinationList(List *rangeTableFragmentsList, Query *jobQuery,
									  List *dependentJobList);
static JoinSequenceNode * JoinSequenceArray(List *rangeTableFragmentsList,
											Query *jobQuery, List *dependentJobList);
static bool PartitionedOnColumn(Var *column, List *rangeTableList,
								List *dependentJobList);
static void CheckJoinBetweenColumns(OpExpr *joinClause);
static List * FindRangeTableFragmentsList(List *rangeTableFragmentsList, int taskId);
static bool JoinPrunable(RangeTableFragment *leftFragment,
						 RangeTableFragment *rightFragment);
static ShardInterval * FragmentInterval(RangeTableFragment *fragment);
static StringInfo FragmentIntervalString(ShardInterval *fragmentInterval);
static List * DataFetchTaskList(uint64 jobId, uint32 taskIdIndex, List *fragmentList);
static StringInfo DatumArrayString(Datum *datumArray, uint32 datumCount, Oid datumTypeId);
static List * BuildRelationShardList(List *rangeTableList, List *fragmentList);
static void UpdateRangeTableAlias(List *rangeTableList, List *fragmentList);
static Alias * FragmentAlias(RangeTblEntry *rangeTableEntry,
							 RangeTableFragment *fragment);
static uint64 AnchorShardId(List *fragmentList, uint32 anchorRangeTableId);
static List * PruneSqlTaskDependencies(List *sqlTaskList);
static List * AssignTaskList(List *sqlTaskList);
static bool HasMergeTaskDependencies(List *sqlTaskList);
static List * GreedyAssignTaskList(List *taskList);
static Task * GreedyAssignTask(WorkerNode *workerNode, List *taskList,
							   List *activeShardPlacementLists);
static List * ReorderAndAssignTaskList(List *taskList,
									   ReorderFunction reorderFunction);
static int CompareTasksByShardId(const void *leftElement, const void *rightElement);
static List * ActiveShardPlacementLists(List *taskList);
static List * ActivePlacementList(List *placementList);
static List * LeftRotateList(List *list, uint32 rotateCount);
static List * FindDependentMergeTaskList(Task *sqlTask);
static List * AssignDualHashTaskList(List *taskList);
static void AssignDataFetchDependencies(List *taskList);
static uint32 TaskListHighestTaskId(List *taskList);
static List * MapTaskList(MapMergeJob *mapMergeJob, List *filterTaskList);
static StringInfo CreateMapQueryString(MapMergeJob *mapMergeJob, Task *filterTask,
									   uint32 partitionColumnIndex);
static List * MergeTaskList(MapMergeJob *mapMergeJob, List *mapTaskList,
							uint32 taskIdIndex);
static StringInfo ColumnNameArrayString(uint32 columnCount, uint64 generatingJobId);
static StringInfo ColumnTypeArrayString(List *targetEntryList);
static bool CoPlacedShardIntervals(ShardInterval *firstInterval,
								   ShardInterval *secondInterval);

static List * FetchEqualityAttrNumsForRTEOpExpr(OpExpr *opExpr);
static List * FetchEqualityAttrNumsForRTEBoolExpr(BoolExpr *boolExpr);
static List * FetchEqualityAttrNumsForList(List *nodeList);
static int PartitionColumnIndex(Var *targetVar, List *targetList);
#if PG_VERSION_NUM >= PG_VERSION_13
static List * GetColumnOriginalIndexes(Oid relationId);
#endif


/*
 * CreatePhysicalDistributedPlan is the entry point for physical plan generation. The
 * function builds the physical plan; this plan includes the list of tasks to be
 * executed on worker nodes, and the final query to run on the master node.
 */
DistributedPlan *
CreatePhysicalDistributedPlan(MultiTreeRoot *multiTree,
							  PlannerRestrictionContext *plannerRestrictionContext)
{
	/* build the worker job tree and check that we only have one job in the tree */
	Job *workerJob = BuildJobTree(multiTree);

	/* create the tree of executable tasks for the worker job */
	workerJob = BuildJobTreeTaskList(workerJob, plannerRestrictionContext);

	/* build the final merge query to execute on the master */
	List *masterDependentJobList = list_make1(workerJob);
	Query *combineQuery = BuildJobQuery((MultiNode *) multiTree, masterDependentJobList);

	DistributedPlan *distributedPlan = CitusMakeNode(DistributedPlan);
	distributedPlan->workerJob = workerJob;
	distributedPlan->combineQuery = combineQuery;
	distributedPlan->modLevel = ROW_MODIFY_READONLY;
	distributedPlan->expectResults = true;

	return distributedPlan;
}
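
/*
 * As a sketch (with a hypothetical distributed table named orders_table), for
 *
 *   SELECT count(*) FROM orders_table;
 *
 * the worker job contains one task per shard along the lines of
 * "SELECT count(*) FROM orders_table_<shardid>", and the combine query sums
 * the per-task counts on the coordinator.
 */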


/*
 * ModifyLocalTableJob returns true if the given job consists of a single
 * task that modifies a local table.
 */
bool
ModifyLocalTableJob(Job *job)
{
	if (job == NULL)
	{
		return false;
	}
	List *taskList = job->taskList;
	if (list_length(taskList) != 1)
	{
		return false;
	}
	Task *singleTask = (Task *) linitial(taskList);
	return singleTask->isLocalTableModification;
}


/*
 * BuildJobTree builds the physical job tree from the given logical plan tree.
 * The function walks over the logical plan from the bottom up, finds boundaries
 * for jobs, and creates the query structure for each job. The function also
 * sets dependencies between jobs, and then returns the top level worker job.
 */
static Job *
BuildJobTree(MultiTreeRoot *multiTree)
{
	/* start building the tree from the deepest left node */
	MultiNode *leftMostNode = LeftMostNode(multiTree);
	MultiNode *currentNode = leftMostNode;
	MultiNode *parentNode = ParentNode(currentNode);
	List *loopDependentJobList = NIL;
	Job *topLevelJob = NULL;

	while (parentNode != NULL)
	{
		CitusNodeTag currentNodeType = CitusNodeTag(currentNode);
		CitusNodeTag parentNodeType = CitusNodeTag(parentNode);
		BoundaryNodeJobType boundaryNodeJobType = JOB_INVALID_FIRST;

		/* we first check if this node forms the boundary for a remote job */
		if (currentNodeType == T_MultiJoin)
		{
			MultiJoin *joinNode = (MultiJoin *) currentNode;
			if (joinNode->joinRuleType == SINGLE_HASH_PARTITION_JOIN ||
				joinNode->joinRuleType == SINGLE_RANGE_PARTITION_JOIN ||
				joinNode->joinRuleType == DUAL_PARTITION_JOIN)
			{
				boundaryNodeJobType = JOIN_MAP_MERGE_JOB;
			}
		}
		else if (currentNodeType == T_MultiPartition &&
				 parentNodeType == T_MultiExtendedOp)
		{
			boundaryNodeJobType = SUBQUERY_MAP_MERGE_JOB;
		}
		else if (currentNodeType == T_MultiCollect &&
				 parentNodeType != T_MultiPartition)
		{
			boundaryNodeJobType = TOP_LEVEL_WORKER_JOB;
		}

		/*
		 * If this node is at the boundary for a repartition or top level worker
		 * job, we build the corresponding job(s) and set their dependencies.
		 */
		if (boundaryNodeJobType == JOIN_MAP_MERGE_JOB)
		{
			MultiJoin *joinNode = (MultiJoin *) currentNode;
			MultiNode *leftChildNode = joinNode->binaryNode.leftChildNode;
			MultiNode *rightChildNode = joinNode->binaryNode.rightChildNode;

			PartitionType partitionType = PARTITION_INVALID_FIRST;
			Oid baseRelationId = InvalidOid;

			if (joinNode->joinRuleType == SINGLE_RANGE_PARTITION_JOIN)
			{
				partitionType = RANGE_PARTITION_TYPE;
				baseRelationId = RangePartitionJoinBaseRelationId(joinNode);
			}
			else if (joinNode->joinRuleType == SINGLE_HASH_PARTITION_JOIN)
			{
				partitionType = SINGLE_HASH_PARTITION_TYPE;
				baseRelationId = RangePartitionJoinBaseRelationId(joinNode);
			}
			else if (joinNode->joinRuleType == DUAL_PARTITION_JOIN)
			{
				partitionType = DUAL_HASH_PARTITION_TYPE;
			}

			if (CitusIsA(leftChildNode, MultiPartition))
			{
				MultiPartition *partitionNode = (MultiPartition *) leftChildNode;
				MultiNode *queryNode = GrandChildNode((MultiUnaryNode *) partitionNode);
				Var *partitionKey = partitionNode->partitionColumn;

				/* build query and partition job */
				List *dependentJobList = list_copy(loopDependentJobList);
				Query *jobQuery = BuildJobQuery(queryNode, dependentJobList);

				MapMergeJob *mapMergeJob = BuildMapMergeJob(jobQuery, dependentJobList,
															partitionKey, partitionType,
															baseRelationId,
															JOIN_MAP_MERGE_JOB);

				/* reset dependent job list */
				loopDependentJobList = NIL;
				loopDependentJobList = list_make1(mapMergeJob);
			}

			if (CitusIsA(rightChildNode, MultiPartition))
			{
				MultiPartition *partitionNode = (MultiPartition *) rightChildNode;
				MultiNode *queryNode = GrandChildNode((MultiUnaryNode *) partitionNode);
				Var *partitionKey = partitionNode->partitionColumn;

				/*
				 * The right query and right partition job do not depend on any
				 * jobs since our logical plan tree is left deep.
				 */
				Query *jobQuery = BuildJobQuery(queryNode, NIL);
				MapMergeJob *mapMergeJob = BuildMapMergeJob(jobQuery, NIL,
															partitionKey, partitionType,
															baseRelationId,
															JOIN_MAP_MERGE_JOB);

				/* append to the dependent job list for ongoing dependencies */
				loopDependentJobList = lappend(loopDependentJobList, mapMergeJob);
			}
		}
		else if (boundaryNodeJobType == TOP_LEVEL_WORKER_JOB)
		{
			MultiNode *childNode = ChildNode((MultiUnaryNode *) currentNode);
			List *dependentJobList = list_copy(loopDependentJobList);
			bool subqueryPushdown = false;

			List *subqueryMultiTableList = SubqueryMultiTableList(childNode);
			int subqueryCount = list_length(subqueryMultiTableList);

			if (subqueryCount > 0)
			{
				subqueryPushdown = true;
			}

			/*
			 * Build the top level query. If subquery pushdown is set, we use a
			 * slightly different version of BuildJobQuery(). The two are
			 * similar, but subquery pushdown does not need some parts of
			 * BuildJobQuery(), such as updating column attributes.
			 */
			if (subqueryPushdown)
			{
				Query *topLevelQuery = BuildSubqueryJobQuery(childNode);

				topLevelJob = BuildJob(topLevelQuery, dependentJobList);
				topLevelJob->subqueryPushdown = true;
			}
			else
			{
				Query *topLevelQuery = BuildJobQuery(childNode, dependentJobList);

				topLevelJob = BuildJob(topLevelQuery, dependentJobList);
			}
		}

		/* walk up the tree */
		currentNode = parentNode;
		parentNode = ParentNode(currentNode);
	}

	return topLevelJob;
}
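
/*
 * For illustration, with hypothetical distributed tables t1 and t2 joined on
 * columns other than their distribution columns, a dual partition join builds
 * a job tree along these lines:
 *
 *   top-level worker job: joins the repartitioned outputs
 *    |- MapMergeJob: repartitions SELECT ... FROM t1 on the join column
 *    '- MapMergeJob: repartitions SELECT ... FROM t2 on the join column
 */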


/*
 * LeftMostNode finds the deepest left node in the left-deep logical plan tree.
 * We build the physical plan by traversing the logical plan from the bottom up,
 * and this function helps us find the bottom of the logical tree.
 */
static MultiNode *
LeftMostNode(MultiTreeRoot *multiTree)
{
	MultiNode *currentNode = (MultiNode *) multiTree;
	MultiNode *leftChildNode = ChildNode((MultiUnaryNode *) multiTree);

	while (leftChildNode != NULL)
	{
		currentNode = leftChildNode;

		if (UnaryOperator(currentNode))
		{
			leftChildNode = ChildNode((MultiUnaryNode *) currentNode);
		}
		else if (BinaryOperator(currentNode))
		{
			MultiBinaryNode *binaryNode = (MultiBinaryNode *) currentNode;
			leftChildNode = binaryNode->leftChildNode;
		}
	}

	return currentNode;
}
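
/*
 * For reference, in a left-deep tree the right child of every binary operator
 * is a leaf (a table), e.g. with hypothetical tables t1, t2 and t3:
 *
 *       Join
 *       /  \
 *    Join   t3
 *    /  \
 *   t1   t2
 *
 * which is why following left children from the root reaches the bottom.
 */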


/*
 * RangePartitionJoinBaseRelationId finds the partition node under the given
 * join node, and returns the base relation id of that node. Note that this
 * function assumes the given join node is a single range (or single hash)
 * partition join.
 */
static Oid
RangePartitionJoinBaseRelationId(MultiJoin *joinNode)
{
	MultiPartition *partitionNode = NULL;

	MultiNode *leftChildNode = joinNode->binaryNode.leftChildNode;
	MultiNode *rightChildNode = joinNode->binaryNode.rightChildNode;

	if (CitusIsA(leftChildNode, MultiPartition))
	{
		partitionNode = (MultiPartition *) leftChildNode;
	}
	else if (CitusIsA(rightChildNode, MultiPartition))
	{
		partitionNode = (MultiPartition *) rightChildNode;
	}

	Index baseTableId = partitionNode->splitPointTableId;
	MultiTable *baseTable = FindTableNode((MultiNode *) joinNode, baseTableId);
	Oid baseRelationId = baseTable->relationId;

	return baseRelationId;
}


/*
 * FindTableNode walks over the given logical plan tree, and returns the table
 * node that corresponds to the given range table id.
 */
static MultiTable *
FindTableNode(MultiNode *multiNode, int rangeTableId)
{
	MultiTable *foundTableNode = NULL;
	List *tableNodeList = FindNodesOfType(multiNode, T_MultiTable);
	ListCell *tableNodeCell = NULL;

	foreach(tableNodeCell, tableNodeList)
	{
		MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
		if (tableNode->rangeTableId == rangeTableId)
		{
			foundTableNode = tableNode;
			break;
		}
	}

	Assert(foundTableNode != NULL);
	return foundTableNode;
}


/*
 * BuildJobQuery traverses the given logical plan tree, determines the job that
 * corresponds to this part of the tree, and builds the query structure for that
 * particular job. The function assumes that the jobs this particular job depends
 * on have already been built, as their output is needed to build the query.
 */
static Query *
BuildJobQuery(MultiNode *multiNode, List *dependentJobList)
{
	bool updateColumnAttributes = false;
	List *targetList = NIL;
	List *sortClauseList = NIL;
	Node *limitCount = NULL;
	Node *limitOffset = NULL;
#if PG_VERSION_NUM >= PG_VERSION_13
	LimitOption limitOption = LIMIT_OPTION_DEFAULT;
#endif
	Node *havingQual = NULL;
	bool hasDistinctOn = false;
	List *distinctClause = NIL;
	bool isRepartitionJoin = false;
	bool hasWindowFuncs = false;
	List *windowClause = NIL;

	/* we start building jobs from below the collect node */
	Assert(!CitusIsA(multiNode, MultiCollect));

	/*
	 * First, check whether we are building a master or a worker query. If we
	 * are building a worker query, we update the column attributes for target
	 * entries, select and join columns. This is necessary because if the
	 * underlying query includes repartition joins, we create multiple queries
	 * from a single join; in that case, range table lists and column lists are
	 * subject to change.
	 *
	 * Note that we don't do this for master queries, as column attributes for
	 * master target entries are already set during the master/worker split.
	 */
	MultiNode *parentNode = ParentNode(multiNode);
	if (parentNode != NULL)
	{
		updateColumnAttributes = true;
	}

	/*
	 * If we are building this query on a repartitioned subquery job then we
	 * don't need to update column attributes.
	 */
	if (dependentJobList != NIL)
	{
		Job *job = (Job *) linitial(dependentJobList);
		if (CitusIsA(job, MapMergeJob))
		{
			MapMergeJob *mapMergeJob = (MapMergeJob *) job;
			isRepartitionJoin = true;
			if (mapMergeJob->reduceQuery)
			{
				updateColumnAttributes = false;
			}
		}
	}

	/*
	 * If we have an extended operator, then we copy the operator's target list.
	 * Otherwise, we use the target list based on the MultiProject node at this
	 * level in the query tree.
	 */
	List *extendedOpNodeList = FindNodesOfType(multiNode, T_MultiExtendedOp);
	if (extendedOpNodeList != NIL)
	{
		MultiExtendedOp *extendedOp = (MultiExtendedOp *) linitial(extendedOpNodeList);
		targetList = copyObject(extendedOp->targetList);
		distinctClause = extendedOp->distinctClause;
		hasDistinctOn = extendedOp->hasDistinctOn;
		hasWindowFuncs = extendedOp->hasWindowFuncs;
		windowClause = extendedOp->windowClause;
	}
	else
	{
		targetList = QueryTargetList(multiNode);
	}

	/* build the join tree and the range table list */
	List *rangeTableList = BaseRangeTableList(multiNode);
	Node *joinRoot = QueryJoinTree(multiNode, dependentJobList, &rangeTableList);

	/* update the column attributes for target entries */
	if (updateColumnAttributes)
	{
		UpdateAllColumnAttributes((Node *) targetList, rangeTableList, dependentJobList);
	}

	/* extract limit count/offset and sort clauses */
	if (extendedOpNodeList != NIL)
	{
		MultiExtendedOp *extendedOp = (MultiExtendedOp *) linitial(extendedOpNodeList);

		limitCount = extendedOp->limitCount;
		limitOffset = extendedOp->limitOffset;
#if PG_VERSION_NUM >= PG_VERSION_13
		limitOption = extendedOp->limitOption;
#endif
		sortClauseList = extendedOp->sortClauseList;
		havingQual = extendedOp->havingQual;
	}

	/* build group clauses */
	List *groupClauseList = QueryGroupClauseList(multiNode);


	/* build the where clause list using select predicates */
	List *selectClauseList = QuerySelectClauseList(multiNode);

	/* set correct column attributes for select and having clauses */
	if (updateColumnAttributes)
	{
		UpdateAllColumnAttributes((Node *) selectClauseList, rangeTableList,
								  dependentJobList);
		UpdateAllColumnAttributes(havingQual, rangeTableList, dependentJobList);
	}

	/*
	 * Group by on primary key allows all columns to appear in the target
	 * list, but after re-partitioning we will be querying an intermediate
	 * table that does not have the primary key. We therefore wrap all the
	 * columns that do not appear in the GROUP BY in an any_value aggregate.
	 */
	if (groupClauseList != NIL && isRepartitionJoin)
	{
		targetList = (List *) WrapUngroupedVarsInAnyValueAggregate(
			(Node *) targetList, groupClauseList, targetList, true);

		havingQual = WrapUngroupedVarsInAnyValueAggregate(
			(Node *) havingQual, groupClauseList, targetList, false);
	}

	/*
	 * Build the From/Where construct. We keep the where-clause list implicitly
	 * AND'd, since both partition and join pruning depends on the clauses being
	 * expressed as a list.
	 */
	FromExpr *joinTree = makeNode(FromExpr);
	joinTree->quals = (Node *) list_copy(selectClauseList);
	joinTree->fromlist = list_make1(joinRoot);

	/* build the query structure for this job */
	Query *jobQuery = makeNode(Query);
	jobQuery->commandType = CMD_SELECT;
	jobQuery->querySource = QSRC_ORIGINAL;
	jobQuery->canSetTag = true;
	jobQuery->rtable = rangeTableList;
	jobQuery->targetList = targetList;
	jobQuery->jointree = joinTree;
	jobQuery->sortClause = sortClauseList;
	jobQuery->groupClause = groupClauseList;
	jobQuery->limitOffset = limitOffset;
	jobQuery->limitCount = limitCount;
#if PG_VERSION_NUM >= PG_VERSION_13
	jobQuery->limitOption = limitOption;
#endif
	jobQuery->havingQual = havingQual;
	jobQuery->hasAggs = contain_aggs_of_level((Node *) targetList, 0) ||
						contain_aggs_of_level((Node *) havingQual, 0);
	jobQuery->distinctClause = distinctClause;
	jobQuery->hasDistinctOn = hasDistinctOn;
	jobQuery->windowClause = windowClause;
	jobQuery->hasWindowFuncs = hasWindowFuncs;
	jobQuery->hasSubLinks = checkExprHasSubLink((Node *) jobQuery);

	Assert(jobQuery->hasWindowFuncs == contain_window_function((Node *) jobQuery));

	return jobQuery;
}


/*
 * BaseRangeTableList returns the list of range table entries for base tables in
 * the query. These base tables stand in contrast to derived tables generated by
 * repartition jobs. Note that this function only considers base tables relevant
 * to the current query, and does not visit nodes under the collect node.
 */
static List *
BaseRangeTableList(MultiNode *multiNode)
{
	List *baseRangeTableList = NIL;
	List *pendingNodeList = list_make1(multiNode);

	while (pendingNodeList != NIL)
	{
		MultiNode *currMultiNode = (MultiNode *) linitial(pendingNodeList);
		CitusNodeTag nodeType = CitusNodeTag(currMultiNode);
		pendingNodeList = list_delete_first(pendingNodeList);

		if (nodeType == T_MultiTable)
		{
			/*
			 * We represent subqueries as MultiTables, and so for base table
			 * entries we skip the subquery ones.
			 */
			MultiTable *multiTable = (MultiTable *) currMultiNode;
			if (multiTable->relationId != SUBQUERY_RELATION_ID &&
				multiTable->relationId != SUBQUERY_PUSHDOWN_RELATION_ID)
			{
				RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
				rangeTableEntry->inFromCl = true;
				rangeTableEntry->eref = multiTable->referenceNames;
				rangeTableEntry->alias = multiTable->alias;
				rangeTableEntry->relid = multiTable->relationId;
				rangeTableEntry->inh = multiTable->includePartitions;

				SetRangeTblExtraData(rangeTableEntry, CITUS_RTE_RELATION, NULL, NULL,
									 list_make1_int(multiTable->rangeTableId),
									 NIL, NIL, NIL, NIL);

				baseRangeTableList = lappend(baseRangeTableList, rangeTableEntry);
			}
		}

		/* do not visit nodes that belong to remote queries */
		if (nodeType != T_MultiCollect)
		{
			List *childNodeList = ChildNodeList(currMultiNode);
			pendingNodeList = list_concat(pendingNodeList, childNodeList);
		}
	}

	return baseRangeTableList;
}


/*
 * DerivedRangeTableEntry builds a range table entry for the derived table. This
 * derived table either represents the output of a repartition job, or the data
 * on worker nodes in the case of the master node query.
 */
RangeTblEntry *
DerivedRangeTableEntry(MultiNode *multiNode, List *columnList, List *tableIdList,
					   List *funcColumnNames, List *funcColumnTypes,
					   List *funcColumnTypeMods, List *funcCollations)
{
	RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
	rangeTableEntry->inFromCl = true;
	rangeTableEntry->eref = makeNode(Alias);
	rangeTableEntry->eref->colnames = columnList;

	SetRangeTblExtraData(rangeTableEntry, CITUS_RTE_REMOTE_QUERY, NULL, NULL, tableIdList,
						 funcColumnNames, funcColumnTypes, funcColumnTypeMods,
						 funcCollations);

	return rangeTableEntry;
}


/*
 * DerivedColumnNameList builds a column name list for derived (intermediate)
 * tables. These column names are then used when building the create statement
 * query string for derived tables.
 */
List *
DerivedColumnNameList(uint32 columnCount, uint64 generatingJobId)
{
	List *columnNameList = NIL;

	for (uint32 columnIndex = 0; columnIndex < columnCount; columnIndex++)
	{
		StringInfo columnName = makeStringInfo();

		appendStringInfo(columnName, "intermediate_column_");
		appendStringInfo(columnName, UINT64_FORMAT "_", generatingJobId);
		appendStringInfo(columnName, "%u", columnIndex);

		Value *columnValue = makeString(columnName->data);
		columnNameList = lappend(columnNameList, columnValue);
	}

	return columnNameList;
}
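
/*
 * For example, DerivedColumnNameList(2, 10) returns the names
 * "intermediate_column_10_0" and "intermediate_column_10_1".
 */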


/*
 * QueryTargetList returns the target entry list for the projected columns
 * needed to evaluate the operators above the given multiNode. To do this,
 * the function retrieves a list of all MultiProject nodes below the given
 * node and picks the columns from the top-most MultiProject node, as this
 * will be the minimal list of columns needed. Note that this function relies
 * on a pre-order traversal of the operator tree by the function FindNodesOfType.
 */
static List *
QueryTargetList(MultiNode *multiNode)
{
	List *projectNodeList = FindNodesOfType(multiNode, T_MultiProject);
	if (list_length(projectNodeList) == 0)
	{
		/*
		 * The physical planner assumes that all worker queries have target
		 * list entries, based on the fact that at least the columns in the
		 * JOINs have to be on the target list. However, there is an exception
		 * to that if there is a cartesian product join and no additional
		 * target list entries belong to one side of the JOIN. Once we support
		 * cartesian product joins, we should remove this error.
		 */
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("cannot perform distributed planning on this query"),
						errdetail("Cartesian products are currently unsupported")));
	}

	MultiProject *topProjectNode = (MultiProject *) linitial(projectNodeList);
	List *columnList = topProjectNode->columnList;
	List *queryTargetList = TargetEntryList(columnList);

	Assert(queryTargetList != NIL);
	return queryTargetList;
}


/*
 * TargetEntryList creates a target entry for each expression in the given list,
 * and returns the newly created target entries in a list.
 */
static List *
TargetEntryList(List *expressionList)
{
	List *targetEntryList = NIL;
	ListCell *expressionCell = NULL;

	foreach(expressionCell, expressionList)
	{
		Expr *expression = (Expr *) lfirst(expressionCell);

		TargetEntry *targetEntry = makeTargetEntry(expression,
												   list_length(targetEntryList) + 1,
												   NULL, false);
		targetEntryList = lappend(targetEntryList, targetEntry);
	}

	return targetEntryList;
}


/*
 * WrapUngroupedVarsInAnyValueAggregate finds Var nodes in the expression
 * that do not refer to any GROUP BY column and wraps them in an any_value
 * aggregate. These columns are allowed when the GROUP BY is on a primary
 * key of a relation, but not if we wrap the relation in a subquery.
 * However, since we still know the value is unique, any_value gives the
 * right result.
 */
Node *
WrapUngroupedVarsInAnyValueAggregate(Node *expression, List *groupClauseList,
									 List *targetList, bool checkExpressionEquality)
{
	if (expression == NULL)
	{
		return NULL;
	}

	AddAnyValueAggregatesContext context;
	context.groupClauseList = groupClauseList;
	context.groupByTargetEntryList = GroupTargetEntryList(groupClauseList, targetList);
	context.haveNonVarGrouping = false;

	if (checkExpressionEquality)
	{
		/*
		 * If the GROUP BY contains non-Var expressions, we need to do an expensive
		 * subexpression equality check.
		 */
		TargetEntry *targetEntry = NULL;
		foreach_ptr(targetEntry, context.groupByTargetEntryList)
		{
			if (!IsA(targetEntry->expr, Var))
			{
				context.haveNonVarGrouping = true;
				break;
			}
		}
	}

	/* put the result in the same memory context */
	MemoryContext nodeContext = GetMemoryChunkContext(expression);
	MemoryContext oldContext = MemoryContextSwitchTo(nodeContext);

	Node *result = expression_tree_mutator(expression, AddAnyValueAggregates,
										   &context);

	MemoryContextSwitchTo(oldContext);

	return result;
}
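
/*
 * As an example with a hypothetical table: in
 *
 *   SELECT id, name FROM users GROUP BY id
 *
 * where id is the primary key, re-running the query against a repartitioned
 * intermediate table loses the primary key, so the target list is rewritten
 * to SELECT id, any_value(name) ... GROUP BY id, which is valid without the
 * primary key and, since id is unique, returns the same result.
 */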


/*
 * AddAnyValueAggregates wraps all vars that do not appear in the GROUP BY
 * clause or are inside an aggregate function in an any_value aggregate
 * function. This is needed because postgres allows columns that are not
 * in the GROUP BY to appear on the target list as long as the primary key
 * of the table is in the GROUP BY, but we sometimes wrap the join tree
 * in a subquery in which case the primary key information is lost.
 *
 * This function copies parts of the node tree, but may contain references
 * to the original node tree.
 *
 * The implementation is derived from / inspired by
 * check_ungrouped_columns_walker.
 */
static Node *
AddAnyValueAggregates(Node *node, AddAnyValueAggregatesContext *context)
{
	if (node == NULL)
	{
		return node;
	}

	if (IsA(node, Aggref) || IsA(node, GroupingFunc))
	{
		/* any column is allowed to appear in an aggregate or grouping */
		return node;
	}
	else if (IsA(node, Var))
	{
		Var *var = (Var *) node;

		/*
		 * Check whether this Var appears in the GROUP BY.
		 */
		TargetEntry *groupByTargetEntry = NULL;
		foreach_ptr(groupByTargetEntry, context->groupByTargetEntryList)
		{
			if (!IsA(groupByTargetEntry->expr, Var))
			{
				continue;
			}

			Var *groupByVar = (Var *) groupByTargetEntry->expr;

			/* we should only be doing this at the top level of the query */
			Assert(groupByVar->varlevelsup == 0);

			if (var->varno == groupByVar->varno &&
				var->varattno == groupByVar->varattno)
			{
				/* this Var is in the GROUP BY, do not wrap it */
				return node;
			}
		}

		/*
		 * We have found a Var that does not appear in the GROUP BY.
		 * Wrap it in an any_value aggregate.
		 */
		Aggref *agg = makeNode(Aggref);
		agg->aggfnoid = CitusAnyValueFunctionId();
		agg->aggtype = var->vartype;
		agg->args = list_make1(makeTargetEntry((Expr *) var, 1, NULL, false));
		agg->aggkind = AGGKIND_NORMAL;
		agg->aggtranstype = InvalidOid;
		agg->aggargtypes = list_make1_oid(var->vartype);
		agg->aggsplit = AGGSPLIT_SIMPLE;
		agg->aggcollid = exprCollation((Node *) var);
		return (Node *) agg;
	}
	else if (context->haveNonVarGrouping)
	{
		/*
		 * The GROUP BY contains at least one expression. Check whether the
		 * current expression is equal to one of the GROUP BY expressions.
		 * Otherwise, continue to descend into subexpressions.
		 */
		TargetEntry *groupByTargetEntry = NULL;
		foreach_ptr(groupByTargetEntry, context->groupByTargetEntryList)
		{
			if (equal(node, groupByTargetEntry->expr))
			{
				/* do not descend into mutator, all Vars are safe */
				return node;
			}
		}
	}

	return expression_tree_mutator(node, AddAnyValueAggregates, context);
}


/*
 * QueryGroupClauseList extracts the group clause list from the logical plan. If
 * no grouping clauses exist, the function returns an empty list.
 */
static List *
QueryGroupClauseList(MultiNode *multiNode)
{
	List *groupClauseList = NIL;
	List *pendingNodeList = list_make1(multiNode);

	while (pendingNodeList != NIL)
	{
		MultiNode *currMultiNode = (MultiNode *) linitial(pendingNodeList);
		CitusNodeTag nodeType = CitusNodeTag(currMultiNode);
		pendingNodeList = list_delete_first(pendingNodeList);

		/* extract the group clause list from the extended operator */
		if (nodeType == T_MultiExtendedOp)
		{
			MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) currMultiNode;
			groupClauseList = extendedOpNode->groupClauseList;
		}

		/* add children only if this node isn't a multi collect or multi table node */
		if (nodeType != T_MultiCollect && nodeType != T_MultiTable)
		{
			List *childNodeList = ChildNodeList(currMultiNode);
			pendingNodeList = list_concat(pendingNodeList, childNodeList);
		}
	}

	return groupClauseList;
}


/*
 * QuerySelectClauseList traverses the given logical plan tree, and extracts all
 * select clauses from the select nodes. Note that this function does not walk
 * below a collect node; the clauses below the collect node apply to a remote
 * query, and they would have been captured by the remote job we depend upon.
 */
static List *
QuerySelectClauseList(MultiNode *multiNode)
{
	List *selectClauseList = NIL;
	List *pendingNodeList = list_make1(multiNode);

	while (pendingNodeList != NIL)
	{
		MultiNode *currMultiNode = (MultiNode *) linitial(pendingNodeList);
		CitusNodeTag nodeType = CitusNodeTag(currMultiNode);
		pendingNodeList = list_delete_first(pendingNodeList);

		/* extract select clauses from the multi select node */
		if (nodeType == T_MultiSelect)
		{
			MultiSelect *selectNode = (MultiSelect *) currMultiNode;
			List *clauseList = copyObject(selectNode->selectClauseList);
			selectClauseList = list_concat(selectClauseList, clauseList);
		}

		/* add children only if this node isn't a multi collect */
		if (nodeType != T_MultiCollect)
		{
			List *childNodeList = ChildNodeList(currMultiNode);
			pendingNodeList = list_concat(pendingNodeList, childNodeList);
		}
	}

	return selectClauseList;
}


/*
 * QueryJoinTree creates a tree of JoinExpr and RangeTblRef nodes for the job
 * query from a given multiNode. If the tree contains MultiCollect or MultiJoin
 * nodes, we add corresponding entries to the range table list. We need to
 * construct the entries at the same time as the tree to know the appropriate
 * rtindex.
 */
static Node *
QueryJoinTree(MultiNode *multiNode, List *dependentJobList, List **rangeTableList)
{
	CitusNodeTag nodeType = CitusNodeTag(multiNode);

	switch (nodeType)
	{
		case T_MultiJoin:
		{
			MultiJoin *joinNode = (MultiJoin *) multiNode;
			MultiBinaryNode *binaryNode = (MultiBinaryNode *) multiNode;
			ListCell *columnCell = NULL;
			JoinExpr *joinExpr = makeNode(JoinExpr);
			joinExpr->jointype = joinNode->joinType;
			joinExpr->isNatural = false;
			joinExpr->larg = QueryJoinTree(binaryNode->leftChildNode, dependentJobList,
										   rangeTableList);
			joinExpr->rarg = QueryJoinTree(binaryNode->rightChildNode, dependentJobList,
										   rangeTableList);
			joinExpr->usingClause = NIL;
			joinExpr->alias = NULL;
			joinExpr->rtindex = list_length(*rangeTableList) + 1;

			/*
			 * PostgreSQL's optimizer may mark left joins as anti-joins when there
			 * is a right-hand-join-key-is-null restriction, but there is no logic
			 * in ruleutils to deparse anti-joins, so we cannot construct a task
			 * query containing anti-joins. We therefore translate anti-joins back
			 * into left-joins. At some point, we may also want to use different
			 * join pruning logic for anti-joins.
			 *
			 * This approach would not work for anti-joins introduced via NOT EXISTS
			 * sublinks, but currently such queries are prevented by error checks in
			 * the logical planner.
			 */
			if (joinExpr->jointype == JOIN_ANTI)
			{
				joinExpr->jointype = JOIN_LEFT;
			}

			/* fix the column attributes in ON (...) clauses */
			List *columnList = pull_var_clause_default((Node *) joinNode->joinClauseList);
			foreach(columnCell, columnList)
			{
				Var *column = (Var *) lfirst(columnCell);
				UpdateColumnAttributes(column, *rangeTableList, dependentJobList);

				/* adjust our column old attributes for partition pruning to work */
				column->varnosyn = column->varno;
				column->varattnosyn = column->varattno;
			}

			/* make AND clauses explicit after fixing them */
			joinExpr->quals = (Node *) make_ands_explicit(joinNode->joinClauseList);

			RangeTblEntry *rangeTableEntry = JoinRangeTableEntry(joinExpr,
																 dependentJobList,
																 *rangeTableList);
			*rangeTableList = lappend(*rangeTableList, rangeTableEntry);

			return (Node *) joinExpr;
		}

		case T_MultiTable:
		{
			MultiTable *rangeTableNode = (MultiTable *) multiNode;
			MultiUnaryNode *unaryNode = (MultiUnaryNode *) multiNode;

			if (unaryNode->childNode != NULL)
			{
				/* MultiTable is actually a subquery, return the query tree below */
				Node *childNode = QueryJoinTree(unaryNode->childNode, dependentJobList,
												rangeTableList);

				return childNode;
			}
			else
			{
				RangeTblRef *rangeTableRef = makeNode(RangeTblRef);
				uint32 rangeTableId = rangeTableNode->rangeTableId;
				rangeTableRef->rtindex = NewTableId(rangeTableId, *rangeTableList);

				return (Node *) rangeTableRef;
			}
		}

		case T_MultiCollect:
		{
			List *tableIdList = OutputTableIdList(multiNode);
			Job *dependentJob = JobForTableIdList(dependentJobList, tableIdList);
			List *dependentTargetList = dependentJob->jobQuery->targetList;

			/* compute column names for the derived table */
			uint32 columnCount = (uint32) list_length(dependentTargetList);
			List *columnNameList = DerivedColumnNameList(columnCount,
														 dependentJob->jobId);

			List *funcColumnNames = NIL;
			List *funcColumnTypes = NIL;
			List *funcColumnTypeMods = NIL;
			List *funcCollations = NIL;

			TargetEntry *targetEntry = NULL;
			foreach_ptr(targetEntry, dependentTargetList)
			{
				Node *expr = (Node *) targetEntry->expr;

				char *name = targetEntry->resname;
				if (name == NULL)
				{
					name = pstrdup("unnamed");
				}

				funcColumnNames = lappend(funcColumnNames, makeString(name));

				funcColumnTypes = lappend_oid(funcColumnTypes, exprType(expr));
				funcColumnTypeMods = lappend_int(funcColumnTypeMods, exprTypmod(expr));
				funcCollations = lappend_oid(funcCollations, exprCollation(expr));
			}

			RangeTblEntry *rangeTableEntry = DerivedRangeTableEntry(multiNode,
																	columnNameList,
																	tableIdList,
																	funcColumnNames,
																	funcColumnTypes,
																	funcColumnTypeMods,
																	funcCollations);

			RangeTblRef *rangeTableRef = makeNode(RangeTblRef);

			rangeTableRef->rtindex = list_length(*rangeTableList) + 1;
			*rangeTableList = lappend(*rangeTableList, rangeTableEntry);

			return (Node *) rangeTableRef;
		}

		case T_MultiCartesianProduct:
		{
			MultiBinaryNode *binaryNode = (MultiBinaryNode *) multiNode;

			JoinExpr *joinExpr = makeNode(JoinExpr);
			joinExpr->jointype = JOIN_INNER;
			joinExpr->isNatural = false;
			joinExpr->larg = QueryJoinTree(binaryNode->leftChildNode, dependentJobList,
										   rangeTableList);
			joinExpr->rarg = QueryJoinTree(binaryNode->rightChildNode, dependentJobList,
										   rangeTableList);
			joinExpr->usingClause = NIL;
			joinExpr->alias = NULL;
			joinExpr->quals = NULL;
			joinExpr->rtindex = list_length(*rangeTableList) + 1;

			RangeTblEntry *rangeTableEntry = JoinRangeTableEntry(joinExpr,
																 dependentJobList,
																 *rangeTableList);
			*rangeTableList = lappend(*rangeTableList, rangeTableEntry);

			return (Node *) joinExpr;
		}

		case T_MultiTreeRoot:
		case T_MultiSelect:
		case T_MultiProject:
		case T_MultiExtendedOp:
		case T_MultiPartition:
		{
			MultiUnaryNode *unaryNode = (MultiUnaryNode *) multiNode;

			Assert(UnaryOperator(multiNode));

			Node *childNode = QueryJoinTree(unaryNode->childNode, dependentJobList,
											rangeTableList);

			return childNode;
		}

		default:
		{
			ereport(ERROR, (errmsg("unrecognized multi-node type: %d", nodeType)));
		}
	}
}
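
/*
 * To illustrate the mapping: a MultiJoin over two MultiTable leaves becomes a
 * JoinExpr whose larg/rarg are RangeTblRef nodes plus an RTE_JOIN entry
 * appended to the range table list, while a MultiCollect becomes a RangeTblRef
 * pointing at a derived (remote query) range table entry that stands in for
 * the dependent job's output.
 */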


/*
 * JoinRangeTableEntry builds a range table entry for a fully initialized JoinExpr node.
 * The column names and vars are determined using expandRTE, analogous to
 * transformFromClauseItem.
 */
static RangeTblEntry *
JoinRangeTableEntry(JoinExpr *joinExpr, List *dependentJobList, List *rangeTableList)
{
	RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
	List *leftColumnNames = NIL;
	List *leftColumnVars = NIL;
	List *joinedColumnNames = NIL;
	List *joinedColumnVars = NIL;
	int leftRangeTableId = ExtractRangeTableId(joinExpr->larg);
	RangeTblEntry *leftRTE = rt_fetch(leftRangeTableId, rangeTableList);
	List *rightColumnNames = NIL;
	List *rightColumnVars = NIL;
	int rightRangeTableId = ExtractRangeTableId(joinExpr->rarg);
	RangeTblEntry *rightRTE = rt_fetch(rightRangeTableId, rangeTableList);

	rangeTableEntry->rtekind = RTE_JOIN;
	rangeTableEntry->relid = InvalidOid;
	rangeTableEntry->inFromCl = true;
	rangeTableEntry->alias = joinExpr->alias;
	rangeTableEntry->jointype = joinExpr->jointype;
	rangeTableEntry->subquery = NULL;
	rangeTableEntry->eref = makeAlias("unnamed_join", NIL);

	RangeTblEntry *leftCallingRTE = ConstructCallingRTE(leftRTE, dependentJobList);
	RangeTblEntry *rightCallingRte = ConstructCallingRTE(rightRTE, dependentJobList);
	ExtractColumns(leftCallingRTE, leftRangeTableId,
				   &leftColumnNames, &leftColumnVars);
	ExtractColumns(rightCallingRte, rightRangeTableId,
				   &rightColumnNames, &rightColumnVars);
	Oid leftRelId = leftCallingRTE->relid;
	Oid rightRelId = rightCallingRte->relid;
	joinedColumnNames = list_concat(joinedColumnNames, leftColumnNames);
	joinedColumnNames = list_concat(joinedColumnNames, rightColumnNames);
	joinedColumnVars = list_concat(joinedColumnVars, leftColumnVars);
	joinedColumnVars = list_concat(joinedColumnVars, rightColumnVars);

	rangeTableEntry->eref->colnames = joinedColumnNames;
	rangeTableEntry->joinaliasvars = joinedColumnVars;

	SetJoinRelatedColumnsCompat(rangeTableEntry, leftRelId, rightRelId, leftColumnVars,
								rightColumnVars);

	return rangeTableEntry;
}


/*
 * SetJoinRelatedColumnsCompat sets join related fields on the given range table entry.
 * Currently it sets joinleftcols/joinrightcols, which were introduced in PostgreSQL 13.
 * For more info see postgres commit: 9ce77d75c5ab094637cc4a446296dc3be6e3c221
 */
static void
SetJoinRelatedColumnsCompat(RangeTblEntry *rangeTableEntry, Oid leftRelId, Oid rightRelId,
							List *leftColumnVars, List *rightColumnVars)
{
	#if PG_VERSION_NUM >= PG_VERSION_13

	/* We don't have any merged columns so set it to 0 */
	rangeTableEntry->joinmergedcols = 0;

	if (OidIsValid(leftRelId))
	{
		rangeTableEntry->joinleftcols = GetColumnOriginalIndexes(leftRelId);
	}
	else
	{
		int leftColsSize = list_length(leftColumnVars);
		rangeTableEntry->joinleftcols = GeneratePositiveIntSequenceList(leftColsSize);
	}

	if (OidIsValid(rightRelId))
	{
		rangeTableEntry->joinrightcols = GetColumnOriginalIndexes(rightRelId);
	}
	else
	{
		int rightColsSize = list_length(rightColumnVars);
		rangeTableEntry->joinrightcols = GeneratePositiveIntSequenceList(rightColsSize);
	}

	#endif
}


#if PG_VERSION_NUM >= PG_VERSION_13

/*
 * GetColumnOriginalIndexes gets the original indexes of columns by taking column drops into account.
 */
static List *
GetColumnOriginalIndexes(Oid relationId)
{
	List *originalIndexes = NIL;
	Relation relation = table_open(relationId, AccessShareLock);
	TupleDesc tupleDescriptor = RelationGetDescr(relation);
	for (int columnIndex = 0; columnIndex < tupleDescriptor->natts; columnIndex++)
	{
		Form_pg_attribute currentColumn = TupleDescAttr(tupleDescriptor, columnIndex);
		if (currentColumn->attisdropped)
		{
			continue;
		}
		originalIndexes = lappend_int(originalIndexes, columnIndex + 1);
	}
	table_close(relation, NoLock);
	return originalIndexes;
}
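
/*
 * For example, for a relation whose columns were originally (a, b, c) and
 * whose column b has since been dropped, the function returns the list (1, 3).
 */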
1393 
1394 
1395 #endif
1396 
1397 /*
1398  * ExtractRangeTableId gets the range table id from a node that could
1399  * either be a JoinExpr or RangeTblRef.
1400  */
1401 static int
ExtractRangeTableId(Node * node)1402 ExtractRangeTableId(Node *node)
1403 {
1404 	int rangeTableId = 0;
1405 
1406 	if (IsA(node, JoinExpr))
1407 	{
1408 		JoinExpr *joinExpr = (JoinExpr *) node;
1409 		rangeTableId = joinExpr->rtindex;
1410 	}
1411 	else if (IsA(node, RangeTblRef))
1412 	{
1413 		RangeTblRef *rangeTableRef = (RangeTblRef *) node;
1414 		rangeTableId = rangeTableRef->rtindex;
1415 	}
1416 
1417 	Assert(rangeTableId > 0);
1418 
1419 	return rangeTableId;
1420 }
1421 
1422 
1423 /*
1424  * ExtractColumns gets a list of column names and vars for a given range
1425  * table entry using expandRTE.
1426  */
1427 static void
ExtractColumns(RangeTblEntry * callingRTE,int rangeTableId,List ** columnNames,List ** columnVars)1428 ExtractColumns(RangeTblEntry *callingRTE, int rangeTableId,
1429 			   List **columnNames, List **columnVars)
1430 {
1431 	int subLevelsUp = 0;
1432 	int location = -1;
1433 	bool includeDroppedColumns = false;
1434 	expandRTE(callingRTE, rangeTableId, subLevelsUp, location, includeDroppedColumns,
1435 			  columnNames, columnVars);
1436 }


/*
 * ConstructCallingRTE constructs a calling RTE from the given range table entry and
 * dependentJobList in case of repartition joins. Since the range table entries in a job
 * query are mocked RTE_FUNCTION entries, this construction is needed to form an RTE
 * that expandRTE can handle.
 */
static RangeTblEntry *
ConstructCallingRTE(RangeTblEntry *rangeTableEntry, List *dependentJobList)
{
	RangeTblEntry *callingRTE = NULL;

	CitusRTEKind rangeTableKind = GetRangeTblKind(rangeTableEntry);
	if (rangeTableKind == CITUS_RTE_JOIN)
	{
		/*
		 * For joins, we can call expandRTE directly.
		 */
		callingRTE = rangeTableEntry;
	}
	else if (rangeTableKind == CITUS_RTE_RELATION)
	{
		/*
		 * For distributed tables, we construct a regular table RTE to call
		 * expandRTE, which will extract columns from the distributed table
		 * schema.
		 */
		callingRTE = makeNode(RangeTblEntry);
		callingRTE->rtekind = RTE_RELATION;
		callingRTE->eref = rangeTableEntry->eref;
		callingRTE->relid = rangeTableEntry->relid;
		callingRTE->inh = rangeTableEntry->inh;
	}
	else if (rangeTableKind == CITUS_RTE_REMOTE_QUERY)
	{
		Job *dependentJob = JobForRangeTable(dependentJobList, rangeTableEntry);
		Query *jobQuery = dependentJob->jobQuery;

		/*
		 * For re-partition jobs, we construct a subquery RTE to call expandRTE,
		 * which will extract the columns from the target list of the job query.
		 */
		callingRTE = makeNode(RangeTblEntry);
		callingRTE->rtekind = RTE_SUBQUERY;
		callingRTE->eref = rangeTableEntry->eref;
		callingRTE->subquery = jobQuery;
	}
	else
	{
		ereport(ERROR, (errmsg("unsupported Citus RTE kind: %d", rangeTableKind)));
	}
	return callingRTE;
}


/*
 * QueryFromList creates the from list construct that is used for building the
 * query's join tree. The function creates the from list by making a range table
 * reference for each entry in the given range table list.
 */
static List *
QueryFromList(List *rangeTableList)
{
	List *fromList = NIL;
	int rangeTableCount = list_length(rangeTableList);

	for (Index rangeTableIndex = 1; rangeTableIndex <= rangeTableCount; rangeTableIndex++)
	{
		RangeTblRef *rangeTableReference = makeNode(RangeTblRef);
		rangeTableReference->rtindex = rangeTableIndex;

		fromList = lappend(fromList, rangeTableReference);
	}

	return fromList;
}


/*
 * BuildSubqueryJobQuery traverses the given logical plan tree and finds the
 * MultiTable node which represents the subquery. It builds the query structure
 * by adding this subquery as-is to the range table list of the query.
 *
 * For example, if the user runs a query like this:
 *
 * SELECT avg(id) FROM (
 *     SELECT ... FROM ()
 * )
 *
 * then this function builds the worker query while keeping the subquery as-is:
 *
 * SELECT sum(id), count(id) FROM (
 *     SELECT ... FROM ()
 * )
 */
static Query *
BuildSubqueryJobQuery(MultiNode *multiNode)
{
	List *targetList = NIL;
	List *sortClauseList = NIL;
	Node *havingQual = NULL;
	Node *limitCount = NULL;
	Node *limitOffset = NULL;
	bool hasAggregates = false;
	List *distinctClause = NIL;
	bool hasDistinctOn = false;
	bool hasWindowFuncs = false;
	List *windowClause = NIL;

	/* we start building jobs from below the collect node */
	Assert(!CitusIsA(multiNode, MultiCollect));

	List *subqueryMultiTableList = SubqueryMultiTableList(multiNode);
	Assert(list_length(subqueryMultiTableList) == 1);

	MultiTable *multiTable = (MultiTable *) linitial(subqueryMultiTableList);
	Query *subquery = multiTable->subquery;

	/* build subquery range table list */
	RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
	rangeTableEntry->rtekind = RTE_SUBQUERY;
	rangeTableEntry->inFromCl = true;
	rangeTableEntry->eref = multiTable->referenceNames;
	rangeTableEntry->alias = multiTable->alias;
	rangeTableEntry->subquery = subquery;

	List *rangeTableList = list_make1(rangeTableEntry);

	/*
	 * If we have an extended operator, then we copy the operator's target list.
	 * Otherwise, we use the target list based on the MultiProject node at this
	 * level in the query tree.
	 */
	List *extendedOpNodeList = FindNodesOfType(multiNode, T_MultiExtendedOp);
	if (extendedOpNodeList != NIL)
	{
		MultiExtendedOp *extendedOp = (MultiExtendedOp *) linitial(extendedOpNodeList);
		targetList = copyObject(extendedOp->targetList);
	}
	else
	{
		targetList = QueryTargetList(multiNode);
	}

	/* extract limit count/offset, sort and having clauses */
	if (extendedOpNodeList != NIL)
	{
		MultiExtendedOp *extendedOp = (MultiExtendedOp *) linitial(extendedOpNodeList);

		limitCount = extendedOp->limitCount;
		limitOffset = extendedOp->limitOffset;
		sortClauseList = extendedOp->sortClauseList;
		havingQual = extendedOp->havingQual;
		distinctClause = extendedOp->distinctClause;
		hasDistinctOn = extendedOp->hasDistinctOn;
		hasWindowFuncs = extendedOp->hasWindowFuncs;
		windowClause = extendedOp->windowClause;
	}

	/* build group clauses */
	List *groupClauseList = QueryGroupClauseList(multiNode);

	/* build the where clause list using select predicates */
	List *whereClauseList = QuerySelectClauseList(multiNode);

	if (contain_aggs_of_level((Node *) targetList, 0) ||
		contain_aggs_of_level((Node *) havingQual, 0))
	{
		hasAggregates = true;
	}

	/* distinct is not sent to the worker query if there are top level aggregates */
	if (hasAggregates)
	{
		hasDistinctOn = false;
		distinctClause = NIL;
	}

	/*
	 * Build the From/Where construct. We keep the where-clause list implicitly
	 * AND'd, since both partition and join pruning depend on the clauses being
	 * expressed as a list.
	 */
	FromExpr *joinTree = makeNode(FromExpr);
	joinTree->quals = (Node *) whereClauseList;
	joinTree->fromlist = QueryFromList(rangeTableList);

	/* build the query structure for this job */
	Query *jobQuery = makeNode(Query);
	jobQuery->commandType = CMD_SELECT;
	jobQuery->querySource = QSRC_ORIGINAL;
	jobQuery->canSetTag = true;
	jobQuery->rtable = rangeTableList;
	jobQuery->targetList = targetList;
	jobQuery->jointree = joinTree;
	jobQuery->sortClause = sortClauseList;
	jobQuery->groupClause = groupClauseList;
	jobQuery->limitOffset = limitOffset;
	jobQuery->limitCount = limitCount;
	jobQuery->havingQual = havingQual;
	jobQuery->hasAggs = hasAggregates;
	jobQuery->hasDistinctOn = hasDistinctOn;
	jobQuery->distinctClause = distinctClause;
	jobQuery->hasWindowFuncs = hasWindowFuncs;
	jobQuery->windowClause = windowClause;
	jobQuery->hasSubLinks = checkExprHasSubLink((Node *) jobQuery);

	Assert(jobQuery->hasWindowFuncs == contain_window_function((Node *) jobQuery));

	return jobQuery;
}


/*
 * UpdateAllColumnAttributes extracts column references from the provided
 * columnContainer and calls UpdateColumnAttributes to update each column's
 * range table reference (varno) and column attribute number for the range
 * table (varattno).
 */
static void
UpdateAllColumnAttributes(Node *columnContainer, List *rangeTableList,
						  List *dependentJobList)
{
	ListCell *columnCell = NULL;
	List *columnList = pull_var_clause_default(columnContainer);
	foreach(columnCell, columnList)
	{
		Var *column = (Var *) lfirst(columnCell);
		UpdateColumnAttributes(column, rangeTableList, dependentJobList);
	}
}


/*
 * UpdateColumnAttributes updates the column's range table reference (varno) and
 * column attribute number for the range table (varattno). The function uses the
 * newly built range table list to update the given column's attributes.
 */
static void
UpdateColumnAttributes(Var *column, List *rangeTableList, List *dependentJobList)
{
	Index originalTableId = column->varnosyn;
	AttrNumber originalColumnId = column->varattnosyn;

	/* find the new table identifier */
	Index newTableId = NewTableId(originalTableId, rangeTableList);
	AttrNumber newColumnId = originalColumnId;

	/* if this is a derived table, find the new column identifier */
	RangeTblEntry *newRangeTableEntry = rt_fetch(newTableId, rangeTableList);
	if (GetRangeTblKind(newRangeTableEntry) == CITUS_RTE_REMOTE_QUERY)
	{
		newColumnId = NewColumnId(originalTableId, originalColumnId,
								  newRangeTableEntry, dependentJobList);
	}

	column->varno = newTableId;
	column->varattno = newColumnId;
}
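
/*
 * Illustrative walkthrough (hypothetical values, not part of the planner):
 * suppose a column originally referenced table 2, attribute 5 (varnosyn = 2,
 * varattnosyn = 5), and after repartitioning that table became the first
 * entry in the new range table list, exposing the column as its third target
 * list entry. UpdateColumnAttributes would then set varno = 1 and, because
 * the new RTE is a CITUS_RTE_REMOTE_QUERY, varattno = 3 via NewColumnId.
 */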


/*
 * NewTableId determines the new tableId for the query that is currently being
 * built. In this query, the original tableId represents the order of the table
 * in the initial parse tree. When queries involve repartitioning, we re-order
 * tables; and the new tableId corresponds to this new table order.
 */
static Index
NewTableId(Index originalTableId, List *rangeTableList)
{
	Index rangeTableIndex = 1;
	ListCell *rangeTableCell = NULL;

	foreach(rangeTableCell, rangeTableList)
	{
		RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
		List *originalTableIdList = NIL;

		ExtractRangeTblExtraData(rangeTableEntry, NULL, NULL, NULL, &originalTableIdList);

		bool listMember = list_member_int(originalTableIdList, originalTableId);
		if (listMember)
		{
			return rangeTableIndex;
		}

		rangeTableIndex++;
	}

	ereport(ERROR, (errmsg("Unrecognized range table id %d", (int) originalTableId)));

	return 0;
}


/*
 * NewColumnId determines the new columnId for the query that is currently being
 * built. In this query, the original columnId corresponds to the column in base
 * tables. When the current query is a partition job and generates intermediate
 * tables, the columns have a different order and the new columnId corresponds
 * to this order. Please note that this function assumes columnIds for dependent
 * jobs have already been updated.
 */
static AttrNumber
NewColumnId(Index originalTableId, AttrNumber originalColumnId,
			RangeTblEntry *newRangeTableEntry, List *dependentJobList)
{
	AttrNumber newColumnId = 1;
	AttrNumber columnIndex = 1;

	Job *dependentJob = JobForRangeTable(dependentJobList, newRangeTableEntry);
	List *targetEntryList = dependentJob->jobQuery->targetList;

	ListCell *targetEntryCell = NULL;
	foreach(targetEntryCell, targetEntryList)
	{
		TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
		Expr *expression = targetEntry->expr;

		Var *column = (Var *) expression;
		Assert(IsA(expression, Var));

		/*
		 * Check against the *old* values for this column, as the new values
		 * would have been updated already.
		 */
		if (column->varnosyn == originalTableId &&
			column->varattnosyn == originalColumnId)
		{
			newColumnId = columnIndex;
			break;
		}

		columnIndex++;
	}

	return newColumnId;
}


/*
 * JobForRangeTable returns the job that corresponds to the given range table
 * entry. The function walks over jobs in the given job list, and compares each
 * job's table list against the given range table entry's table list. When two
 * table lists match, the function returns the matching job. Note that we call
 * this function in practice when we need to determine which one of the jobs we
 * depend upon corresponds to the given range table entry.
 */
static Job *
JobForRangeTable(List *jobList, RangeTblEntry *rangeTableEntry)
{
	List *searchedTableIdList = NIL;
	CitusRTEKind rangeTableKind;

	ExtractRangeTblExtraData(rangeTableEntry, &rangeTableKind, NULL, NULL,
							 &searchedTableIdList);

	Assert(rangeTableKind == CITUS_RTE_REMOTE_QUERY);

	Job *searchedJob = JobForTableIdList(jobList, searchedTableIdList);

	return searchedJob;
}


/*
 * JobForTableIdList returns the job that corresponds to the given
 * tableIdList. The function walks over jobs in the given job list, and
 * compares each job's table list against the given table list. When the
 * two table lists match, the function returns the matching job.
 */
static Job *
JobForTableIdList(List *jobList, List *searchedTableIdList)
{
	Job *searchedJob = NULL;
	ListCell *jobCell = NULL;

	foreach(jobCell, jobList)
	{
		Job *job = (Job *) lfirst(jobCell);
		List *jobRangeTableList = job->jobQuery->rtable;
		List *jobTableIdList = NIL;
		ListCell *jobRangeTableCell = NULL;

		foreach(jobRangeTableCell, jobRangeTableList)
		{
			RangeTblEntry *jobRangeTable = (RangeTblEntry *) lfirst(jobRangeTableCell);
			List *tableIdList = NIL;

			ExtractRangeTblExtraData(jobRangeTable, NULL, NULL, NULL, &tableIdList);

			/* copy the list since list_concat is destructive */
			tableIdList = list_copy(tableIdList);
			jobTableIdList = list_concat(jobTableIdList, tableIdList);
		}

		/*
		 * Check if the searched range table's tableIds and the current job's
		 * tableIds are the same.
		 */
		List *lhsDiff = list_difference_int(jobTableIdList, searchedTableIdList);
		List *rhsDiff = list_difference_int(searchedTableIdList, jobTableIdList);
		if (lhsDiff == NIL && rhsDiff == NIL)
		{
			searchedJob = job;
			break;
		}
	}

	Assert(searchedJob != NULL);
	return searchedJob;
}
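
/*
 * For illustration (hypothetical table ids): a job whose range tables carry
 * the table id lists [2] and [4] flattens to {2, 4}. Searching with the
 * list [4, 2] matches that job, since the two-way list_difference_int
 * checks treat the lists as sets and ignore ordering.
 */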


/* Returns the list of children for the given multi node. */
static List *
ChildNodeList(MultiNode *multiNode)
{
	List *childNodeList = NIL;
	bool isUnaryNode = UnaryOperator(multiNode);
	bool isBinaryNode = BinaryOperator(multiNode);

	/* relation table nodes don't have any children */
	if (CitusIsA(multiNode, MultiTable))
	{
		MultiTable *multiTable = (MultiTable *) multiNode;
		if (multiTable->relationId != SUBQUERY_RELATION_ID)
		{
			return NIL;
		}
	}

	if (isUnaryNode)
	{
		MultiUnaryNode *unaryNode = (MultiUnaryNode *) multiNode;
		childNodeList = list_make1(unaryNode->childNode);
	}
	else if (isBinaryNode)
	{
		MultiBinaryNode *binaryNode = (MultiBinaryNode *) multiNode;
		childNodeList = list_make2(binaryNode->leftChildNode,
								   binaryNode->rightChildNode);
	}

	return childNodeList;
}


/*
 * UniqueJobId allocates and returns a unique jobId for the job to be executed.
 *
 * The resulting job ID is built up as:
 * <16-bit group ID><24-bit process ID><1-bit secondary flag><23-bit local counter>
 *
 * When citus.enable_unique_job_ids is off then only the local counter is
 * included to get repeatable results.
 */
uint64
UniqueJobId(void)
{
	static uint32 jobIdCounter = 0;

	uint64 jobId = 0;
	uint64 processId = 0;
	uint64 localGroupId = 0;

	jobIdCounter++;

	if (EnableUniqueJobIds)
	{
		/*
		 * Add the local group id information to the jobId to
		 * prevent concurrent jobs on different groups from
		 * conflicting.
		 */
		localGroupId = GetLocalGroupId() & 0xFFFF;
		jobId = jobId | (localGroupId << 48);

		/*
		 * Add the current process ID to distinguish jobs by this
		 * backend from jobs started by other backends. Process
		 * IDs can have at most 24-bits on platforms supported by
		 * Citus.
		 */
		processId = MyProcPid & 0xFFFFFF;
		jobId = jobId | (processId << 24);

		/*
		 * Add an extra bit for secondaries to distinguish their
		 * jobs from primaries.
		 */
		if (RecoveryInProgress())
		{
			jobId = jobId | (1 << 23);
		}
	}

	/*
	 * Use the remaining 23 bits to distinguish jobs by the
	 * same backend.
	 */
	uint64 jobIdNumber = jobIdCounter & 0x7FFFFF;
	jobId = jobId | jobIdNumber;

	return jobId;
}
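
/*
 * Worked example (hypothetical values): with group ID 3, process ID 12345,
 * a primary node (secondary flag 0), and a counter value of 7, the job ID
 * becomes (3 << 48) | (12345 << 24) | (0 << 23) | 7 = 0x0003003039000007.
 */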


/* Builds a job from the given job query and dependent job list. */
static Job *
BuildJob(Query *jobQuery, List *dependentJobList)
{
	Job *job = CitusMakeNode(Job);
	job->jobId = UniqueJobId();
	job->jobQuery = jobQuery;
	job->dependentJobList = dependentJobList;
	job->requiresCoordinatorEvaluation = false;

	return job;
}


/*
 * BuildMapMergeJob builds a MapMerge job from the given query and dependent job
 * list. The function then copies and updates the logical plan's partition
 * column, and uses the join rule type to determine the physical repartitioning
 * method to apply.
 */
static MapMergeJob *
BuildMapMergeJob(Query *jobQuery, List *dependentJobList, Var *partitionKey,
				 PartitionType partitionType, Oid baseRelationId,
				 BoundaryNodeJobType boundaryNodeJobType)
{
	List *rangeTableList = jobQuery->rtable;
	Var *partitionColumn = copyObject(partitionKey);

	/* update the logical partition key's table and column identifiers */
	if (boundaryNodeJobType != SUBQUERY_MAP_MERGE_JOB)
	{
		UpdateColumnAttributes(partitionColumn, rangeTableList, dependentJobList);
	}

	MapMergeJob *mapMergeJob = CitusMakeNode(MapMergeJob);
	mapMergeJob->job.jobId = UniqueJobId();
	mapMergeJob->job.jobQuery = jobQuery;
	mapMergeJob->job.dependentJobList = dependentJobList;
	mapMergeJob->partitionColumn = partitionColumn;
	mapMergeJob->sortedShardIntervalArrayLength = 0;

	/*
	 * We assume dual partition join defaults to hash partitioning, and single
	 * partition join defaults to range partitioning. In practice, the join type
	 * should have no impact on the physical repartitioning (hash/range) method.
	 * If the join type is not set, this means the job represents a subquery, and
	 * uses hash partitioning.
	 */
	if (partitionType == DUAL_HASH_PARTITION_TYPE)
	{
		uint32 partitionCount = HashPartitionCount();

		mapMergeJob->partitionType = DUAL_HASH_PARTITION_TYPE;
		mapMergeJob->partitionCount = partitionCount;
	}
	else if (partitionType == SINGLE_HASH_PARTITION_TYPE || partitionType ==
			 RANGE_PARTITION_TYPE)
	{
		CitusTableCacheEntry *cache = GetCitusTableCacheEntry(baseRelationId);
		int shardCount = cache->shardIntervalArrayLength;
		ShardInterval **cachedSortedShardIntervalArray =
			cache->sortedShardIntervalArray;
		bool hasUninitializedShardInterval =
			cache->hasUninitializedShardInterval;

		ShardInterval **sortedShardIntervalArray =
			palloc0(sizeof(ShardInterval *) * shardCount);

		for (int shardIndex = 0; shardIndex < shardCount; shardIndex++)
		{
			sortedShardIntervalArray[shardIndex] =
				CopyShardInterval(cachedSortedShardIntervalArray[shardIndex]);
		}

		if (hasUninitializedShardInterval)
		{
			ereport(ERROR, (errmsg("cannot range repartition shard with "
								   "missing min/max values")));
		}

		mapMergeJob->partitionType = partitionType;
		mapMergeJob->partitionCount = (uint32) shardCount;
		mapMergeJob->sortedShardIntervalArray = sortedShardIntervalArray;
		mapMergeJob->sortedShardIntervalArrayLength = shardCount;
	}

	return mapMergeJob;
}


/*
 * HashPartitionCount returns the number of partition files we create for a hash
 * partition task. The function follows Hadoop's method for picking the number
 * of reduce tasks: 0.95 or 1.75 * node count * max reduces per node. We choose
 * the lower constant 0.95 so that all tasks can start immediately, but round it
 * to 1.0 so that we have a smooth number of partition tasks.
 */
static uint32
HashPartitionCount(void)
{
	uint32 groupCount = list_length(ActiveReadableNodeList());
	double maxReduceTasksPerNode = RepartitionJoinBucketCountPerNode;

	uint32 partitionCount = (uint32) rint(groupCount * maxReduceTasksPerNode);
	return partitionCount;
}
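
/*
 * For example (assuming default settings): with 4 readable worker nodes and
 * RepartitionJoinBucketCountPerNode at its default of 8, the function above
 * returns rint(4 * 8.0) = 32 partition files per hash partition task.
 */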


/*
 * SplitPointObject walks over shard intervals in the given array (sorted on
 * interval minimum values), extracts each shard interval's minimum value, and
 * inserts these values into a new array, preserving their sorted order. This
 * sorted array is then used by the MapMerge job.
 */
static ArrayType *
SplitPointObject(ShardInterval **shardIntervalArray, uint32 shardIntervalCount)
{
	Oid typeId = InvalidOid;
	bool typeByValue = false;
	char typeAlignment = 0;
	int16 typeLength = 0;

	/* allocate an array for shard min values */
	uint32 minDatumCount = shardIntervalCount;
	Datum *minDatumArray = palloc0(minDatumCount * sizeof(Datum));

	for (uint32 intervalIndex = 0; intervalIndex < shardIntervalCount; intervalIndex++)
	{
		ShardInterval *shardInterval = shardIntervalArray[intervalIndex];
		minDatumArray[intervalIndex] = shardInterval->minValue;
		Assert(shardInterval->minValueExists);

		/* resolve the datum type on the first pass */
		if (intervalIndex == 0)
		{
			typeId = shardInterval->valueTypeId;
		}
	}

	/* construct the split point object from the sorted array */
	get_typlenbyvalalign(typeId, &typeLength, &typeByValue, &typeAlignment);
	ArrayType *splitPointObject = construct_array(minDatumArray, minDatumCount, typeId,
												  typeLength, typeByValue, typeAlignment);

	return splitPointObject;
}
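
/*
 * Illustration (hypothetical intervals): for three range shards covering
 * [1, 10], [11, 20], and [21, 30], SplitPointObject collects the minimum
 * values into the array {1, 11, 21}, which the MapMerge job later uses as
 * split points when repartitioning tuples.
 */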


/* ------------------------------------------------------------
 * Functions that relate to building and assigning tasks follow
 * ------------------------------------------------------------
 */

/*
 * BuildJobTreeTaskList takes in the given job tree and walks over jobs in this
 * tree bottom up. The function then creates tasks for each job in the tree,
 * sets dependencies between tasks and their downstream dependencies and assigns
 * tasks to worker nodes.
 */
static Job *
BuildJobTreeTaskList(Job *jobTree, PlannerRestrictionContext *plannerRestrictionContext)
{
	List *flattenedJobList = NIL;

	/*
	 * We traverse the job tree in preorder, and append each visited job to our
	 * flattened list. This way, each job in our list appears before the jobs it
	 * depends on.
	 */
	List *jobStack = list_make1(jobTree);
	while (jobStack != NIL)
	{
		Job *job = (Job *) llast(jobStack);
		flattenedJobList = lappend(flattenedJobList, job);

		/* pop top element and push its children to the stack */
		jobStack = list_delete_ptr(jobStack, job);
		jobStack = list_union_ptr(jobStack, job->dependentJobList);
	}

	/*
	 * We walk the job list in reverse order to visit jobs bottom up. This way,
	 * we can create dependencies between tasks bottom up, and assign them to
	 * worker nodes accordingly.
	 */
	uint32 flattenedJobCount = (uint32) list_length(flattenedJobList);
	for (int32 jobIndex = (flattenedJobCount - 1); jobIndex >= 0; jobIndex--)
	{
		Job *job = (Job *) list_nth(flattenedJobList, jobIndex);
		List *sqlTaskList = NIL;
		ListCell *assignedSqlTaskCell = NULL;

		/* create sql tasks for the job, and prune redundant data fetch tasks */
		if (job->subqueryPushdown)
		{
			bool isMultiShardQuery = false;
			List *prunedRelationShardList =
				TargetShardIntervalsForRestrictInfo(plannerRestrictionContext->
													relationRestrictionContext,
													&isMultiShardQuery, NULL);

			DeferredErrorMessage *deferredErrorMessage = NULL;
			sqlTaskList = QueryPushdownSqlTaskList(job->jobQuery, job->jobId,
												   plannerRestrictionContext->
												   relationRestrictionContext,
												   prunedRelationShardList, READ_TASK,
												   false,
												   &deferredErrorMessage);
			if (deferredErrorMessage != NULL)
			{
				RaiseDeferredErrorInternal(deferredErrorMessage, ERROR);
			}
		}
		else
		{
			sqlTaskList = SqlTaskList(job);
		}

		sqlTaskList = PruneSqlTaskDependencies(sqlTaskList);

		/*
		 * We first assign sql and merge tasks to worker nodes. Next, we assign
		 * sql tasks' data fetch dependencies.
		 */
		List *assignedSqlTaskList = AssignTaskList(sqlTaskList);
		AssignDataFetchDependencies(assignedSqlTaskList);

		/* record whether the parameters in the job query have been resolved */
		job->parametersInJobQueryResolved =
			!HasUnresolvedExternParamsWalker((Node *) job->jobQuery, NULL);

		/*
		 * Make final adjustments for the assigned tasks.
		 *
		 * First, update SELECT tasks' parameters resolved field.
		 *
		 * Second, assign merge tasks' data fetch dependencies.
		 */
		foreach(assignedSqlTaskCell, assignedSqlTaskList)
		{
			Task *assignedSqlTask = (Task *) lfirst(assignedSqlTaskCell);

			/* we don't support parameters in the physical planner */
			if (assignedSqlTask->taskType == READ_TASK)
			{
				assignedSqlTask->parametersInQueryStringResolved =
					job->parametersInJobQueryResolved;
			}

			List *assignedMergeTaskList = FindDependentMergeTaskList(assignedSqlTask);
			AssignDataFetchDependencies(assignedMergeTaskList);
		}

		/*
		 * If we have a MapMerge job, the map tasks in this job wrap around the
		 * SQL tasks and their assignments.
		 */
		if (CitusIsA(job, MapMergeJob))
		{
			MapMergeJob *mapMergeJob = (MapMergeJob *) job;
			uint32 taskIdIndex = TaskListHighestTaskId(assignedSqlTaskList) + 1;

			List *mapTaskList = MapTaskList(mapMergeJob, assignedSqlTaskList);
			List *mergeTaskList = MergeTaskList(mapMergeJob, mapTaskList, taskIdIndex);

			mapMergeJob->mapTaskList = mapTaskList;
			mapMergeJob->mergeTaskList = mergeTaskList;
		}
		else
		{
			job->taskList = assignedSqlTaskList;
		}
	}

	return jobTree;
}


/*
 * QueryPushdownSqlTaskList creates a list of SQL tasks to execute the given subquery
 * pushdown job. For this, the function checks whether the query is router
 * plannable per target shard interval. For each router plannable worker
 * query, we create a SQL task and append the task to the task list that is
 * going to be executed.
 */
List *
QueryPushdownSqlTaskList(Query *query, uint64 jobId,
						 RelationRestrictionContext *relationRestrictionContext,
						 List *prunedRelationShardList, TaskType taskType, bool
						 modifyRequiresCoordinatorEvaluation,
						 DeferredErrorMessage **planningError)
{
	List *sqlTaskList = NIL;
	ListCell *restrictionCell = NULL;
	uint32 taskIdIndex = 1; /* 0 is reserved for invalid taskId */
	int shardCount = 0;
	bool *taskRequiredForShardIndex = NULL;
	ListCell *prunedRelationShardCell = NULL;

	/* error if shards are not co-partitioned */
	ErrorIfUnsupportedShardDistribution(query);

	if (list_length(relationRestrictionContext->relationRestrictionList) == 0)
	{
		*planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									   "cannot handle complex subqueries when the "
									   "router executor is disabled",
									   NULL, NULL);
		return NIL;
	}

	/* defaults to be used if this is a reference table-only query */
	int minShardOffset = 0;
	int maxShardOffset = 0;

	forboth(prunedRelationShardCell, prunedRelationShardList,
			restrictionCell, relationRestrictionContext->relationRestrictionList)
	{
		RelationRestriction *relationRestriction =
			(RelationRestriction *) lfirst(restrictionCell);
		Oid relationId = relationRestriction->relationId;
		List *prunedShardList = (List *) lfirst(prunedRelationShardCell);
		ListCell *shardIntervalCell = NULL;

		CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
		if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
		{
			continue;
		}

		/* we expect distributed tables to have the same shard count */
		if (shardCount > 0 && shardCount != cacheEntry->shardIntervalArrayLength)
		{
			*planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
										   "shard counts of co-located tables do not "
										   "match",
										   NULL, NULL);
			return NIL;
		}

		if (taskRequiredForShardIndex == NULL)
		{
			shardCount = cacheEntry->shardIntervalArrayLength;
			taskRequiredForShardIndex = (bool *) palloc0(shardCount);

			/* there is a distributed table, find the shard range */
			minShardOffset = shardCount;
			maxShardOffset = -1;
		}

		/*
		 * For left joins we don't care about the shards pruned for the right hand side.
		 * If the right hand side would prune to a smaller set we should still send the
		 * query to all shards of the left hand side. However if the right hand side is
		 * bigger than the left hand side we don't have to send the query to any shard
		 * that is not matching anything on the left hand side.
		 *
		 * Instead we will simply skip any RelationRestriction if it is an OUTER join and
		 * the table is part of the non-outer side of the join.
		 */
		if (IsInnerTableOfOuterJoin(relationRestriction))
		{
			continue;
		}

		foreach(shardIntervalCell, prunedShardList)
		{
			ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
			int shardIndex = shardInterval->shardIndex;

			taskRequiredForShardIndex[shardIndex] = true;

			minShardOffset = Min(minShardOffset, shardIndex);
			maxShardOffset = Max(maxShardOffset, shardIndex);
		}
	}

	/*
	 * To avoid iterating through all shard indexes we keep the minimum and maximum
	 * offsets of shards that were not pruned away. This optimisation is primarily
	 * relevant for queries on range-distributed tables that, due to range filters,
	 * prune to a small number of adjacent shards.
	 *
	 * In other cases, such as an OR condition on a hash-distributed table, we may
	 * still visit most or all shards even if some of them were pruned away. However,
	 * given that hash-distributed tables typically only have a few shards, the
	 * iteration is still very fast.
	 */
	for (int shardOffset = minShardOffset; shardOffset <= maxShardOffset; shardOffset++)
	{
		if (taskRequiredForShardIndex != NULL && !taskRequiredForShardIndex[shardOffset])
		{
			/* this shard index is pruned away for all relations */
			continue;
		}

		Task *subqueryTask = QueryPushdownTaskCreate(query, shardOffset,
													 relationRestrictionContext,
													 taskIdIndex,
													 taskType,
													 modifyRequiresCoordinatorEvaluation,
													 planningError);
		if (*planningError != NULL)
		{
			return NIL;
		}
		subqueryTask->jobId = jobId;
		sqlTaskList = lappend(sqlTaskList, subqueryTask);

		++taskIdIndex;
	}

	/* if it is a modify task with multiple tables, mark the tasks as such */
	if (taskType == MODIFY_TASK && list_length(
			relationRestrictionContext->relationRestrictionList) > 1)
	{
		ListCell *taskCell = NULL;
		foreach(taskCell, sqlTaskList)
		{
			Task *task = (Task *) lfirst(taskCell);
			task->modifyWithSubquery = true;
		}
	}

	return sqlTaskList;
}
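
/*
 * Walkthrough (hypothetical shard layout): with 8 shards per table and a
 * range filter that prunes the left relation down to shard indexes 2 and 3,
 * taskRequiredForShardIndex marks only those offsets, minShardOffset becomes
 * 2 and maxShardOffset becomes 3, so the task-creation loop above builds
 * exactly two SQL tasks instead of scanning all 8 shard indexes.
 */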


/*
 * IsInnerTableOfOuterJoin tests, based on the join information encoded in a
 * RelationRestriction, whether the table accessed for this relation is
 *   a) in an outer join
 *   b) on the inner part of said join
 *
 * The function returns true only if both conditions above hold.
 */
static bool
IsInnerTableOfOuterJoin(RelationRestriction *relationRestriction)
{
	RestrictInfo *joinInfo = NULL;
	foreach_ptr(joinInfo, relationRestriction->relOptInfo->joininfo)
	{
		if (joinInfo->outer_relids == NULL)
		{
			/* not an outer join */
			continue;
		}

		/*
		 * This join restriction info describes an outer join. We need to figure
		 * out if our table is in the non-outer part of this join. If that is
		 * the case, this is an inner table of an outer join.
		 */
		bool isInOuter = bms_is_member(relationRestriction->relOptInfo->relid,
									   joinInfo->outer_relids);
		if (!isInOuter)
		{
			/* this table is joined in the inner part of an outer join */
			return true;
		}
	}

	/* we have not found any join clause that satisfies both requirements */
	return false;
}
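
/*
 * Example (hypothetical query): in "orders LEFT JOIN customers ON ...",
 * customers sits on the inner side of the outer join, so its
 * RelationRestriction makes this function return true, and its pruned shard
 * set is ignored when computing the required shard indexes above.
 */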


/*
 * ErrorIfUnsupportedShardDistribution gets the list of relations in the given
 * query and checks that the two conditions below hold for them; otherwise it
 * errors out.
 * a. Every relation is distributed by range or hash. This means shards are
 * disjoint based on the partition column.
 * b. All relations have 1-to-1 shard partitioning between them. This means the
 * shard count for every relation is the same, and for every shard in a relation
 * there is exactly one shard in the other relations with the same min/max values.
 */
static void
ErrorIfUnsupportedShardDistribution(Query *query)
{
	Oid firstTableRelationId = InvalidOid;
	List *relationIdList = DistributedRelationIdList(query);
	List *nonReferenceRelations = NIL;
	ListCell *relationIdCell = NULL;
	uint32 relationIndex = 0;
	uint32 rangeDistributedRelationCount = 0;
	uint32 hashDistributedRelationCount = 0;
	uint32 appendDistributedRelationCount = 0;

	foreach(relationIdCell, relationIdList)
	{
		Oid relationId = lfirst_oid(relationIdCell);
		if (IsCitusTableType(relationId, RANGE_DISTRIBUTED))
		{
			rangeDistributedRelationCount++;
			nonReferenceRelations = lappend_oid(nonReferenceRelations,
												relationId);
		}
		else if (IsCitusTableType(relationId, HASH_DISTRIBUTED))
		{
			hashDistributedRelationCount++;
			nonReferenceRelations = lappend_oid(nonReferenceRelations,
												relationId);
		}
		else if (IsCitusTableType(relationId, CITUS_TABLE_WITH_NO_DIST_KEY))
		{
			/* do not need to handle non-distributed tables */
			continue;
		}
		else
		{
			CitusTableCacheEntry *distTableEntry = GetCitusTableCacheEntry(relationId);
			if (distTableEntry->hasOverlappingShardInterval)
			{
				ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
								errmsg("cannot push down this subquery"),
								errdetail("Currently append partitioned relations "
										  "with overlapping shard intervals are "
										  "not supported")));
			}

			appendDistributedRelationCount++;
		}
	}

	if ((rangeDistributedRelationCount > 0) && (hashDistributedRelationCount > 0))
	{
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("cannot push down this subquery"),
						errdetail("A query including both range and hash "
								  "partitioned relations is unsupported")));
	}
	else if ((rangeDistributedRelationCount > 0) && (appendDistributedRelationCount > 0))
	{
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("cannot push down this subquery"),
						errdetail("A query including both range and append "
								  "partitioned relations is unsupported")));
	}
	else if ((appendDistributedRelationCount > 0) && (hashDistributedRelationCount > 0))
	{
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("cannot push down this subquery"),
						errdetail("A query including both append and hash "
								  "partitioned relations is unsupported")));
	}

	foreach(relationIdCell, nonReferenceRelations)
	{
		Oid relationId = lfirst_oid(relationIdCell);
		Oid currentRelationId = relationId;

		/* remember the first relation and continue with the next one */
		if (relationIndex == 0)
		{
			firstTableRelationId = relationId;
			relationIndex++;

			continue;
		}

		/* check if this table has 1-1 shard partitioning with the first table */
		bool coPartitionedTables = CoPartitionedTables(firstTableRelationId,
													   currentRelationId);
		if (!coPartitionedTables)
		{
			ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
							errmsg("cannot push down this subquery"),
							errdetail("Shards of relations in subquery need to "
									  "have 1-to-1 shard partitioning")));
		}
	}
}


/*
 * QueryPushdownTaskCreate creates a SQL task for the given shard index by
 * replacing the relations in the query with the shards at that index and
 * assigning the task to the workers that hold all of those shard placements.
 */
static Task *
QueryPushdownTaskCreate(Query *originalQuery, int shardIndex,
						RelationRestrictionContext *restrictionContext, uint32 taskId,
						TaskType taskType, bool modifyRequiresCoordinatorEvaluation,
						DeferredErrorMessage **planningError)
{
	Query *taskQuery = copyObject(originalQuery);

	StringInfo queryString = makeStringInfo();
	ListCell *restrictionCell = NULL;
	List *taskShardList = NIL;
	List *relationShardList = NIL;
	uint64 jobId = INVALID_JOB_ID;
	uint64 anchorShardId = INVALID_SHARD_ID;
	bool modifyWithSubselect = false;
	RangeTblEntry *resultRangeTable = NULL;
	Oid resultRelationOid = InvalidOid;

	/*
	 * If it is a modify query with a sub-select, we need to set the result
	 * relation shard's id as the anchor shard id.
	 */
	if (UpdateOrDeleteQuery(originalQuery))
	{
		resultRangeTable = rt_fetch(originalQuery->resultRelation, originalQuery->rtable);
		resultRelationOid = resultRangeTable->relid;
		modifyWithSubselect = true;
	}

	/*
	 * Find the relevant shard out of each relation for this task.
	 */
	foreach(restrictionCell, restrictionContext->relationRestrictionList)
	{
		RelationRestriction *relationRestriction =
			(RelationRestriction *) lfirst(restrictionCell);
		Oid relationId = relationRestriction->relationId;
		ShardInterval *shardInterval = NULL;

		CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
		if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
		{
			/* non-distributed tables have only one shard */
			shardInterval = cacheEntry->sortedShardIntervalArray[0];

			/* only use the reference table as the anchor shard if none exists yet */
			if (anchorShardId == INVALID_SHARD_ID)
			{
				anchorShardId = shardInterval->shardId;
			}
		}
		else if (UpdateOrDeleteQuery(originalQuery))
		{
			shardInterval = cacheEntry->sortedShardIntervalArray[shardIndex];
			if (!modifyWithSubselect || relationId == resultRelationOid)
			{
				/* for UPDATE/DELETE the shard in the result relation becomes the anchor shard */
				anchorShardId = shardInterval->shardId;
			}
		}
		else
		{
			/* for SELECT we pick an arbitrary shard as the anchor shard */
			shardInterval = cacheEntry->sortedShardIntervalArray[shardIndex];
			anchorShardId = shardInterval->shardId;
		}

		ShardInterval *copiedShardInterval = CopyShardInterval(shardInterval);

		taskShardList = lappend(taskShardList, list_make1(copiedShardInterval));

		RelationShard *relationShard = CitusMakeNode(RelationShard);
		relationShard->relationId = copiedShardInterval->relationId;
		relationShard->shardId = copiedShardInterval->shardId;

		relationShardList = lappend(relationShardList, relationShard);
	}

	Assert(anchorShardId != INVALID_SHARD_ID);

	List *taskPlacementList = PlacementsForWorkersContainingAllShards(taskShardList);
	if (list_length(taskPlacementList) == 0)
	{
		*planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									   "cannot find a worker that has active placements for all "
									   "shards in the query",
									   NULL, NULL);

		return NULL;
	}

	/*
	 * Augment the relations in the query with the shard IDs.
	 */
	UpdateRelationToShardNames((Node *) taskQuery, relationShardList);

	/*
	 * Ands are made implicit during shard pruning, as predicate comparison and
	 * refutation depend on it being so. We need to make them explicit again so
	 * that the query string is generated as (...) AND (...) as opposed to
	 * (...), (...).
	 */
	if (taskQuery->jointree->quals != NULL && IsA(taskQuery->jointree->quals, List))
	{
		taskQuery->jointree->quals = (Node *) make_ands_explicit(
			(List *) taskQuery->jointree->quals);
	}

	Task *subqueryTask = CreateBasicTask(jobId, taskId, taskType, NULL);

	if ((taskType == MODIFY_TASK && !modifyRequiresCoordinatorEvaluation) ||
		taskType == READ_TASK)
	{
		pg_get_query_def(taskQuery, queryString);
		ereport(DEBUG4, (errmsg("distributed statement: %s",
								ApplyLogRedaction(queryString->data))));
		SetTaskQueryString(subqueryTask, queryString->data);
	}

	subqueryTask->dependentTaskList = NULL;
	subqueryTask->anchorShardId = anchorShardId;
	subqueryTask->taskPlacementList = taskPlacementList;
	subqueryTask->relationShardList = relationShardList;

	return subqueryTask;
}


/*
 * CoPartitionedTables checks if the given two distributed tables have 1-to-1
 * shard placement matching. It first checks the shard counts: if the tables
 * don't have the same number of shards, it returns false. Note that, if
 * neither table has any shards, it returns true. If the two tables have the
 * same number of shards, we check colocationIds for hash distributed tables
 * and the shard intervals' min/max values for append and range distributed
 * tables.
 */
bool
CoPartitionedTables(Oid firstRelationId, Oid secondRelationId)
{
	if (firstRelationId == secondRelationId)
	{
		return true;
	}

	CitusTableCacheEntry *firstTableCache = GetCitusTableCacheEntry(firstRelationId);
	CitusTableCacheEntry *secondTableCache = GetCitusTableCacheEntry(secondRelationId);

	ShardInterval **sortedFirstIntervalArray = firstTableCache->sortedShardIntervalArray;
	ShardInterval **sortedSecondIntervalArray =
		secondTableCache->sortedShardIntervalArray;
	uint32 firstListShardCount = firstTableCache->shardIntervalArrayLength;
	uint32 secondListShardCount = secondTableCache->shardIntervalArrayLength;
	FmgrInfo *comparisonFunction = firstTableCache->shardIntervalCompareFunction;

	/* reference tables are always & only copartitioned with reference tables */
	if (IsCitusTableTypeCacheEntry(firstTableCache, CITUS_TABLE_WITH_NO_DIST_KEY) &&
		IsCitusTableTypeCacheEntry(secondTableCache, CITUS_TABLE_WITH_NO_DIST_KEY))
	{
		return true;
	}
	else if (IsCitusTableTypeCacheEntry(firstTableCache, CITUS_TABLE_WITH_NO_DIST_KEY) ||
			 IsCitusTableTypeCacheEntry(secondTableCache, CITUS_TABLE_WITH_NO_DIST_KEY))
	{
		return false;
	}

	if (firstListShardCount != secondListShardCount)
	{
		return false;
	}

	/* if there aren't any shards, just return true */
	if (firstListShardCount == 0)
	{
		return true;
	}

	Assert(comparisonFunction != NULL);

	/*
	 * Check if the tables have the same colocation ID - if so, we know
	 * they're colocated.
	 */
	if (firstTableCache->colocationId != INVALID_COLOCATION_ID &&
		firstTableCache->colocationId == secondTableCache->colocationId)
	{
		return true;
	}

	/*
	 * For hash distributed tables, two tables are accepted as colocated only
	 * if they have the same colocationId. Otherwise they may have the same
	 * minimum and maximum values for each shard interval, yet the hash
	 * function may produce different values for the same input; int vs bigint
	 * can be given as an example.
	 */
	if (IsCitusTableTypeCacheEntry(firstTableCache, HASH_DISTRIBUTED) ||
		IsCitusTableTypeCacheEntry(secondTableCache, HASH_DISTRIBUTED))
	{
		return false;
	}

	/*
	 * Don't compare unequal types.
	 */
	Oid collation = firstTableCache->partitionColumn->varcollid;
	if (firstTableCache->partitionColumn->vartype !=
		secondTableCache->partitionColumn->vartype ||
		collation != secondTableCache->partitionColumn->varcollid)
	{
		return false;
	}

	/*
	 * If not known to be colocated, check if the remaining shards are anyway.
	 * Do so by comparing the shard interval arrays, which are sorted on
	 * interval minimum values. Compare every shard interval in order; if any
	 * pair of shard intervals are not equal, or they are not located on the
	 * same node, return false.
	 */
	for (uint32 intervalIndex = 0; intervalIndex < firstListShardCount; intervalIndex++)
	{
		ShardInterval *firstInterval = sortedFirstIntervalArray[intervalIndex];
		ShardInterval *secondInterval = sortedSecondIntervalArray[intervalIndex];

		bool shardIntervalsEqual = ShardIntervalsEqual(comparisonFunction,
													   collation,
													   firstInterval,
													   secondInterval);
		if (!shardIntervalsEqual || !CoPlacedShardIntervals(firstInterval,
															secondInterval))
		{
			return false;
		}
	}

	return true;
}
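
/*
 * Example (assuming typical usage): two hash-distributed tables created to
 * be colocated with each other share a colocation ID, so CoPartitionedTables
 * returns true via the colocation ID check. Two range-distributed tables
 * whose sorted shards cover [1, 10] and [11, 20] each, placed on the same
 * nodes, instead pass through the interval-by-interval comparison above.
 */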


/*
 * CoPlacedShardIntervals checks whether the given intervals are located on the
 * same nodes.
 */
static bool
CoPlacedShardIntervals(ShardInterval *firstInterval, ShardInterval *secondInterval)
{
	List *firstShardPlacementList = ShardPlacementListWithoutOrphanedPlacements(
		firstInterval->shardId);
	List *secondShardPlacementList = ShardPlacementListWithoutOrphanedPlacements(
		secondInterval->shardId);
	ListCell *firstShardPlacementCell = NULL;
	ListCell *secondShardPlacementCell = NULL;

	/* shards must have the same number of placements */
	if (list_length(firstShardPlacementList) != list_length(secondShardPlacementList))
	{
		return false;
	}

	firstShardPlacementList = SortList(firstShardPlacementList, CompareShardPlacements);
	secondShardPlacementList = SortList(secondShardPlacementList, CompareShardPlacements);

	forboth(firstShardPlacementCell, firstShardPlacementList, secondShardPlacementCell,
			secondShardPlacementList)
	{
		ShardPlacement *firstShardPlacement = (ShardPlacement *) lfirst(
			firstShardPlacementCell);
		ShardPlacement *secondShardPlacement = (ShardPlacement *) lfirst(
			secondShardPlacementCell);

		if (firstShardPlacement->nodeId != secondShardPlacement->nodeId)
		{
			return false;
		}
	}

	return true;
}


/*
 * ShardIntervalsEqual checks if the given shard intervals have equal min/max values.
 */
static bool
ShardIntervalsEqual(FmgrInfo *comparisonFunction, Oid collation,
					ShardInterval *firstInterval, ShardInterval *secondInterval)
{
	bool shardIntervalsEqual = false;

	Datum firstMin = firstInterval->minValue;
	Datum firstMax = firstInterval->maxValue;
	Datum secondMin = secondInterval->minValue;
	Datum secondMax = secondInterval->maxValue;

	if (firstInterval->minValueExists && firstInterval->maxValueExists &&
		secondInterval->minValueExists && secondInterval->maxValueExists)
	{
		Datum minDatum = FunctionCall2Coll(comparisonFunction, collation, firstMin,
										   secondMin);
		Datum maxDatum = FunctionCall2Coll(comparisonFunction, collation, firstMax,
										   secondMax);
		int firstComparison = DatumGetInt32(minDatum);
		int secondComparison = DatumGetInt32(maxDatum);

		if (firstComparison == 0 && secondComparison == 0)
		{
			shardIntervalsEqual = true;
		}
	}

	return shardIntervalsEqual;
}


/*
 * SqlTaskList creates a list of SQL tasks to execute the given job. For this,
 * the function walks over each range table in the job's range table list, gets
 * each range table's table fragments, and prunes unneeded table fragments. The
 * function then joins table fragments from different range tables, and creates
 * all fragment combinations. For each created combination, the function builds
 * a SQL task, and appends this task to a task list.
 */
static List *
SqlTaskList(Job *job)
{
	List *sqlTaskList = NIL;
	uint32 taskIdIndex = 1; /* 0 is reserved for invalid taskId */
	uint64 jobId = job->jobId;
	bool anchorRangeTableBasedAssignment = false;
	uint32 anchorRangeTableId = 0;

	Query *jobQuery = job->jobQuery;
	List *rangeTableList = jobQuery->rtable;
	List *whereClauseList = (List *) jobQuery->jointree->quals;
	List *dependentJobList = job->dependentJobList;

	/*
	 * If we don't depend on a hash partition, then we determine the largest
	 * table around which we build our queries. This reduces data fetching.
	 */
	bool dependsOnHashPartitionJob = DependsOnHashPartitionJob(job);
	if (!dependsOnHashPartitionJob)
	{
		anchorRangeTableBasedAssignment = true;
		anchorRangeTableId = AnchorRangeTableId(rangeTableList);

		Assert(anchorRangeTableId != 0);
		Assert(anchorRangeTableId <= list_length(rangeTableList));
	}

	/* adjust our column old attributes for partition pruning to work */
	AdjustColumnOldAttributes(whereClauseList);
	AdjustColumnOldAttributes(jobQuery->targetList);

	/*
	 * Ands are made implicit during shard pruning, as predicate comparison and
	 * refutation depend on it being so. We need to make them explicit again so
	 * that the query string is generated as (...) AND (...) as opposed to
	 * (...), (...).
	 */
	Node *whereClauseTree = (Node *) make_ands_explicit(
		(List *) jobQuery->jointree->quals);
	jobQuery->jointree->quals = whereClauseTree;

	/*
	 * For each range table, we first get a list of their shards or merge tasks.
	 * We also apply partition pruning based on the selection criteria. If all
	 * range table fragments are pruned away, we return an empty task list.
	 */
	List *rangeTableFragmentsList = RangeTableFragmentsList(rangeTableList,
															whereClauseList,
															dependentJobList);
	if (rangeTableFragmentsList == NIL)
	{
		return NIL;
	}

	/*
	 * We then generate fragment combinations according to how range tables join
	 * with each other (and apply join pruning). Each fragment combination then
	 * represents one SQL task's dependencies.
	 */
	List *fragmentCombinationList = FragmentCombinationList(rangeTableFragmentsList,
															jobQuery, dependentJobList);

	ListCell *fragmentCombinationCell = NULL;
	foreach(fragmentCombinationCell, fragmentCombinationList)
	{
		List *fragmentCombination = (List *) lfirst(fragmentCombinationCell);

		/* create tasks to fetch fragments required for the sql task */
		List *dataFetchTaskList = DataFetchTaskList(jobId, taskIdIndex,
													fragmentCombination);
		int32 dataFetchTaskCount = list_length(dataFetchTaskList);
		taskIdIndex += dataFetchTaskCount;

		/* update range table entries with fragment aliases (in place) */
		Query *taskQuery = copyObject(jobQuery);
		List *fragmentRangeTableList = taskQuery->rtable;
		UpdateRangeTableAlias(fragmentRangeTableList, fragmentCombination);

		/* transform the updated task query to a SQL query string */
		StringInfo sqlQueryString = makeStringInfo();
		pg_get_query_def(taskQuery, sqlQueryString);

		Task *sqlTask = CreateBasicTask(jobId, taskIdIndex, READ_TASK,
										sqlQueryString->data);
		sqlTask->dependentTaskList = dataFetchTaskList;
		sqlTask->relationShardList = BuildRelationShardList(fragmentRangeTableList,
															fragmentCombination);

		/* log the query string we generated */
		ereport(DEBUG4, (errmsg("generated sql query for task %d", sqlTask->taskId),
						 errdetail("query string: \"%s\"",
								   ApplyLogRedaction(sqlQueryString->data))));

		sqlTask->anchorShardId = INVALID_SHARD_ID;
		if (anchorRangeTableBasedAssignment)
		{
			sqlTask->anchorShardId = AnchorShardId(fragmentCombination,
												   anchorRangeTableId);
		}

		taskIdIndex++;
		sqlTaskList = lappend(sqlTaskList, sqlTask);
	}

	return sqlTaskList;
}
2953 
2954 
2955 /*
2956  * RelabelTypeToCollateExpr converts a RelabelType into a CollateExpr. With
2957  * that, we are able to push down COLLATE clauses.
2958  */
2959 CollateExpr *
2960 RelabelTypeToCollateExpr(RelabelType *relabelType)
2961 {
2962 	Assert(OidIsValid(relabelType->resultcollid));
2963 
2964 	CollateExpr *collateExpr = makeNode(CollateExpr);
2965 	collateExpr->arg = relabelType->arg;
2966 	collateExpr->collOid = relabelType->resultcollid;
2967 	collateExpr->location = relabelType->location;
2968 
2969 	return collateExpr;
2970 }
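

/*
 * Illustrative sketch (assuming a RelabelType that carries a valid result
 * collation): the conversion above rewrites
 *
 *     RelabelType(arg, resulttype, resultcollid = C)
 *
 * into
 *
 *     CollateExpr(arg, collOid = C)
 *
 * which deparses as "arg COLLATE <collation>" on the worker, making the
 * collation explicit in the pushed-down query.
 */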
2971 
2972 
2973 /*
2974  * DependsOnHashPartitionJob checks if the given job depends on a hash
2975  * partitioning job.
2976  */
2977 static bool
2978 DependsOnHashPartitionJob(Job *job)
2979 {
2980 	bool dependsOnHashPartitionJob = false;
2981 	List *dependentJobList = job->dependentJobList;
2982 
2983 	uint32 dependentJobCount = (uint32) list_length(dependentJobList);
2984 	if (dependentJobCount > 0)
2985 	{
2986 		Job *dependentJob = (Job *) linitial(dependentJobList);
2987 		if (CitusIsA(dependentJob, MapMergeJob))
2988 		{
2989 			MapMergeJob *mapMergeJob = (MapMergeJob *) dependentJob;
2990 			if (mapMergeJob->partitionType == DUAL_HASH_PARTITION_TYPE)
2991 			{
2992 				dependsOnHashPartitionJob = true;
2993 			}
2994 		}
2995 	}
2996 
2997 	return dependsOnHashPartitionJob;
2998 }
2999 
3000 
3001 /*
3002  * AnchorRangeTableId determines the table around which we build our queries,
3003  * and returns this table's range table id. We refer to this table as the anchor
3004  * table, and make sure that the anchor table's shards are moved or cached only
3005  * when absolutely necessary.
3006  */
3007 static uint32
3008 AnchorRangeTableId(List *rangeTableList)
3009 {
3010 	uint32 anchorRangeTableId = 0;
3011 	uint64 maxTableSize = 0;
3012 
3013 	/*
3014 	 * We first filter out anything but ordinary tables. Then, we pick the
3015 	 * table(s) with the greatest number of shards as our anchor table. If
3016 	 * multiple tables have that shard count, we have a draw.
3017 	 */
3018 	List *baseTableIdList = BaseRangeTableIdList(rangeTableList);
3019 	List *anchorTableIdList = AnchorRangeTableIdList(rangeTableList, baseTableIdList);
3020 	ListCell *anchorTableIdCell = NULL;
3021 
3022 	int anchorTableIdCount = list_length(anchorTableIdList);
3023 	Assert(anchorTableIdCount > 0);
3024 
3025 	if (anchorTableIdCount == 1)
3026 	{
3027 		anchorRangeTableId = (uint32) linitial_int(anchorTableIdList);
3028 		return anchorRangeTableId;
3029 	}
3030 
3031 	/*
3032 	 * If more than one table has the most number of shards, we break the draw
3033 	 * by comparing table sizes and picking the table with the largest size.
3034 	 */
3035 	foreach(anchorTableIdCell, anchorTableIdList)
3036 	{
3037 		uint32 anchorTableId = (uint32) lfirst_int(anchorTableIdCell);
3038 		RangeTblEntry *tableEntry = rt_fetch(anchorTableId, rangeTableList);
3039 		uint64 tableSize = 0;
3040 
3041 		List *shardList = LoadShardList(tableEntry->relid);
3042 		ListCell *shardCell = NULL;
3043 
3044 		foreach(shardCell, shardList)
3045 		{
3046 			uint64 *shardIdPointer = (uint64 *) lfirst(shardCell);
3047 			uint64 shardId = (*shardIdPointer);
3048 			uint64 shardSize = ShardLength(shardId);
3049 
3050 			tableSize += shardSize;
3051 		}
3052 
3053 		if (tableSize > maxTableSize)
3054 		{
3055 			maxTableSize = tableSize;
3056 			anchorRangeTableId = anchorTableId;
3057 		}
3058 	}
3059 
3060 	if (anchorRangeTableId == 0)
3061 	{
3062 		/* all tables have the same shard count and size 0, pick the first */
3063 		anchorRangeTableId = (uint32) linitial_int(anchorTableIdList);
3064 	}
3065 
3066 	return anchorRangeTableId;
3067 }
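

/*
 * Worked example (illustrative, hypothetical shard counts): for ordinary
 * tables t1, t2 and t3 with 32, 32 and 8 shards, AnchorRangeTableIdList()
 * returns { t1, t2 } as candidates. The loop above then sums ShardLength()
 * over each candidate's shards and keeps the larger one; if all candidate
 * sizes are 0, the first candidate is picked.
 */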
3068 
3069 
3070 /*
3071  * BaseRangeTableIdList walks over range tables in the given range table list,
3072  * finds range tables that correspond to base (non-repartitioned) tables, and
3073  * returns these range tables' identifiers in a new list.
3074  */
3075 static List *
3076 BaseRangeTableIdList(List *rangeTableList)
3077 {
3078 	List *baseRangeTableIdList = NIL;
3079 	uint32 rangeTableId = 1;
3080 
3081 	ListCell *rangeTableCell = NULL;
3082 	foreach(rangeTableCell, rangeTableList)
3083 	{
3084 		RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
3085 		if (GetRangeTblKind(rangeTableEntry) == CITUS_RTE_RELATION)
3086 		{
3087 			baseRangeTableIdList = lappend_int(baseRangeTableIdList, rangeTableId);
3088 		}
3089 
3090 		rangeTableId++;
3091 	}
3092 
3093 	return baseRangeTableIdList;
3094 }
3095 
3096 
3097 /*
3098  * AnchorRangeTableIdList finds the ordinary table(s) with the greatest number
3099  * of shards and returns the corresponding range table id(s) in a list.
3100  */
3101 static List *
3102 AnchorRangeTableIdList(List *rangeTableList, List *baseRangeTableIdList)
3103 {
3104 	List *anchorTableIdList = NIL;
3105 	uint32 maxShardCount = 0;
3106 	ListCell *baseRangeTableIdCell = NULL;
3107 
3108 	uint32 baseRangeTableCount = list_length(baseRangeTableIdList);
3109 	if (baseRangeTableCount == 1)
3110 	{
3111 		return baseRangeTableIdList;
3112 	}
3113 
3114 	foreach(baseRangeTableIdCell, baseRangeTableIdList)
3115 	{
3116 		uint32 baseRangeTableId = (uint32) lfirst_int(baseRangeTableIdCell);
3117 		RangeTblEntry *tableEntry = rt_fetch(baseRangeTableId, rangeTableList);
3118 		List *shardList = LoadShardList(tableEntry->relid);
3119 
3120 		uint32 shardCount = (uint32) list_length(shardList);
3121 		if (shardCount > maxShardCount)
3122 		{
3123 			anchorTableIdList = list_make1_int(baseRangeTableId);
3124 			maxShardCount = shardCount;
3125 		}
3126 		else if (shardCount == maxShardCount)
3127 		{
3128 			anchorTableIdList = lappend_int(anchorTableIdList, baseRangeTableId);
3129 		}
3130 	}
3131 
3132 	return anchorTableIdList;
3133 }
3134 
3135 
3136 /*
3137  * AdjustColumnOldAttributes adjusts the old tableId (varnosyn) and old columnId
3138  * (varattnosyn), and sets them equal to the new values. We need this adjustment
3139  * for partition pruning where we compare these columns with partition columns
3140  * loaded from system catalogs. Since columns loaded from system catalogs always
3141  * have the same old and new values, we also need to adjust column values here.
3142  */
3143 static void
3144 AdjustColumnOldAttributes(List *expressionList)
3145 {
3146 	List *columnList = pull_var_clause_default((Node *) expressionList);
3147 	ListCell *columnCell = NULL;
3148 
3149 	foreach(columnCell, columnList)
3150 	{
3151 		Var *column = (Var *) lfirst(columnCell);
3152 		column->varnosyn = column->varno;
3153 		column->varattnosyn = column->varattno;
3154 	}
3155 }
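

/*
 * Example (illustrative): a Var with varno = 2 and varattno = 3 leaves this
 * function with varnosyn = 2 and varattnosyn = 3 as well, matching the
 * representation of partition columns loaded from the system catalogs.
 */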
3156 
3157 
3158 /*
3159  * RangeTableFragmentsList walks over range tables in the given range table list
3160  * and for each table, the function creates a list of its fragments. A fragment
3161  * in this list represents either a regular shard or a merge task. Once a list
3162  * for each range table is constructed, the function applies partition pruning
3163  * using the given where clause list. Then, the function appends the fragment
3164  * list for each range table to a list of lists, and returns this list of lists.
3165  */
3166 static List *
3167 RangeTableFragmentsList(List *rangeTableList, List *whereClauseList,
3168 						List *dependentJobList)
3169 {
3170 	List *rangeTableFragmentsList = NIL;
3171 	uint32 rangeTableIndex = 0;
3172 	const uint32 fragmentSize = sizeof(RangeTableFragment);
3173 
3174 	ListCell *rangeTableCell = NULL;
3175 	foreach(rangeTableCell, rangeTableList)
3176 	{
3177 		uint32 tableId = rangeTableIndex + 1; /* tableId starts from 1 */
3178 
3179 		RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
3180 		CitusRTEKind rangeTableKind = GetRangeTblKind(rangeTableEntry);
3181 
3182 		if (rangeTableKind == CITUS_RTE_RELATION)
3183 		{
3184 			Oid relationId = rangeTableEntry->relid;
3185 			ListCell *shardIntervalCell = NULL;
3186 			List *shardFragmentList = NIL;
3187 			List *prunedShardIntervalList = PruneShards(relationId, tableId,
3188 														whereClauseList, NULL);
3189 
3190 			/*
3191 			 * If we prune all shards for one table, query results will be empty.
3192 			 * We can therefore return NIL for the task list here.
3193 			 */
3194 			if (prunedShardIntervalList == NIL)
3195 			{
3196 				return NIL;
3197 			}
3198 
3199 			foreach(shardIntervalCell, prunedShardIntervalList)
3200 			{
3201 				ShardInterval *shardInterval =
3202 					(ShardInterval *) lfirst(shardIntervalCell);
3203 
3204 				RangeTableFragment *shardFragment = palloc0(fragmentSize);
3205 				shardFragment->fragmentReference = shardInterval;
3206 				shardFragment->fragmentType = CITUS_RTE_RELATION;
3207 				shardFragment->rangeTableId = tableId;
3208 
3209 				shardFragmentList = lappend(shardFragmentList, shardFragment);
3210 			}
3211 
3212 			rangeTableFragmentsList = lappend(rangeTableFragmentsList,
3213 											  shardFragmentList);
3214 		}
3215 		else if (rangeTableKind == CITUS_RTE_REMOTE_QUERY)
3216 		{
3217 			List *mergeTaskFragmentList = NIL;
3218 			ListCell *mergeTaskCell = NULL;
3219 
3220 			Job *dependentJob = JobForRangeTable(dependentJobList, rangeTableEntry);
3221 			Assert(CitusIsA(dependentJob, MapMergeJob));
3222 
3223 			MapMergeJob *dependentMapMergeJob = (MapMergeJob *) dependentJob;
3224 			List *mergeTaskList = dependentMapMergeJob->mergeTaskList;
3225 
3226 			/* if there are no tasks for the dependent job, just return NIL */
3227 			if (mergeTaskList == NIL)
3228 			{
3229 				return NIL;
3230 			}
3231 
3232 			foreach(mergeTaskCell, mergeTaskList)
3233 			{
3234 				Task *mergeTask = (Task *) lfirst(mergeTaskCell);
3235 
3236 				RangeTableFragment *mergeTaskFragment = palloc0(fragmentSize);
3237 				mergeTaskFragment->fragmentReference = mergeTask;
3238 				mergeTaskFragment->fragmentType = CITUS_RTE_REMOTE_QUERY;
3239 				mergeTaskFragment->rangeTableId = tableId;
3240 
3241 				mergeTaskFragmentList = lappend(mergeTaskFragmentList, mergeTaskFragment);
3242 			}
3243 
3244 			rangeTableFragmentsList = lappend(rangeTableFragmentsList,
3245 											  mergeTaskFragmentList);
3246 		}
3247 
3248 		rangeTableIndex++;
3249 	}
3250 
3251 	return rangeTableFragmentsList;
3252 }
3253 
3254 
3255 /*
3256  * BuildBaseConstraint builds and returns a base constraint. This constraint
3257  * implements an expression in the form of (column <= max && column >= min),
3258  * where column is the partition key, and min and max values represent a shard's
3259  * min and max values. These shard values are filled in after the constraint is
3260  * built.
3261  */
3262 Node *
3263 BuildBaseConstraint(Var *column)
3264 {
3265 	/* Build these expressions with only one argument for now */
3266 	OpExpr *lessThanExpr = MakeOpExpression(column, BTLessEqualStrategyNumber);
3267 	OpExpr *greaterThanExpr = MakeOpExpression(column, BTGreaterEqualStrategyNumber);
3268 
3269 	/* Build the base constraint as an AND of the two qual conditions */
3270 	Node *baseConstraint = make_and_qual((Node *) lessThanExpr, (Node *) greaterThanExpr);
3271 
3272 	return baseConstraint;
3273 }
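

/*
 * Usage sketch (illustrative; partitionColumn and shardInterval are assumed
 * to be valid):
 *
 *     Node *baseConstraint = BuildBaseConstraint(partitionColumn);
 *     UpdateConstraint(baseConstraint, shardInterval);
 *
 * after which the tree is equivalent to the qual
 *
 *     (column <= shardMaxValue) AND (column >= shardMinValue)
 */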
3274 
3275 
3276 /*
3277  * MakeOpExpression builds an operator expression node. This operator expression
3278  * implements the operator clause as defined by the variable and the strategy
3279  * number.
3280  */
3281 OpExpr *
3282 MakeOpExpression(Var *variable, int16 strategyNumber)
3283 {
3284 	Oid typeId = variable->vartype;
3285 	Oid typeModId = variable->vartypmod;
3286 	Oid collationId = variable->varcollid;
3287 
3288 	Oid accessMethodId = BTREE_AM_OID;
3289 
3290 	OperatorCacheEntry *operatorCacheEntry = LookupOperatorByType(typeId, accessMethodId,
3291 																  strategyNumber);
3292 
3293 	Oid operatorId = operatorCacheEntry->operatorId;
3294 	Oid operatorClassInputType = operatorCacheEntry->operatorClassInputType;
3295 	char typeType = operatorCacheEntry->typeType;
3296 
3297 	/*
3298 	 * Relabel variable if input type of default operator class is not equal to
3299 	 * the variable type. Note that we don't relabel the variable if the default
3300 	 * operator class variable type is a pseudo-type.
3301 	 */
3302 	if (operatorClassInputType != typeId && typeType != TYPTYPE_PSEUDO)
3303 	{
3304 		variable = (Var *) makeRelabelType((Expr *) variable, operatorClassInputType,
3305 										   -1, collationId, COERCE_IMPLICIT_CAST);
3306 	}
3307 
3308 	Const *constantValue = makeNullConst(operatorClassInputType, typeModId, collationId);
3309 
3310 	/* Now make the expression with the given variable and a null constant */
3311 	OpExpr *expression = (OpExpr *) make_opclause(operatorId,
3312 												  InvalidOid, /* no result type yet */
3313 												  false, /* no return set */
3314 												  (Expr *) variable,
3315 												  (Expr *) constantValue,
3316 												  InvalidOid, collationId);
3317 
3318 	/* Set implementing function id and result type */
3319 	expression->opfuncid = get_opcode(operatorId);
3320 	expression->opresulttype = get_func_rettype(expression->opfuncid);
3321 
3322 	return expression;
3323 }
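

/*
 * For instance (illustrative), for an int4 column and
 * BTLessEqualStrategyNumber this returns a clause of the form
 * "column <= NULL::int4", built from the default btree opclass; the NULL
 * placeholder Const is replaced with an actual shard bound later, e.g. by
 * UpdateConstraint().
 */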
3324 
3325 
3326 /*
3327  * LookupOperatorByType is a wrapper around GetOperatorByType(),
3328  * get_opclass_input_type() and get_typtype() that uses a cache to avoid
3329  * repeated lookups of operators and their related fields within a single
3330  * session, keyed by type, access method and strategy number.
3331  * LookupOperatorByType errors out if it cannot find a corresponding default
3332  * operator class for the given parameters in the system catalogs.
3333  */
3334 static OperatorCacheEntry *
3335 LookupOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber)
3336 {
3337 	OperatorCacheEntry *matchingCacheEntry = NULL;
3338 	ListCell *cacheEntryCell = NULL;
3339 
3340 	/* search the cache */
3341 	foreach(cacheEntryCell, OperatorCache)
3342 	{
3343 		OperatorCacheEntry *cacheEntry = lfirst(cacheEntryCell);
3344 
3345 		if ((cacheEntry->typeId == typeId) &&
3346 			(cacheEntry->accessMethodId == accessMethodId) &&
3347 			(cacheEntry->strategyNumber == strategyNumber))
3348 		{
3349 			matchingCacheEntry = cacheEntry;
3350 			break;
3351 		}
3352 	}
3353 
3354 	/* if not found in the cache, call GetOperatorByType and put the result in cache */
3355 	if (matchingCacheEntry == NULL)
3356 	{
3357 		Oid operatorClassId = GetDefaultOpClass(typeId, accessMethodId);
3358 
3359 		if (operatorClassId == InvalidOid)
3360 		{
3361 			/* if the operator class is invalid, error out */
3362 			ereport(ERROR, (errmsg("cannot find default operator class for type: %d,"
3363 								   " access method: %d", typeId, accessMethodId)));
3364 		}
3365 
3366 		/* fill the other fields to the cache */
3367 		Oid operatorId = GetOperatorByType(typeId, accessMethodId, strategyNumber);
3368 		Oid operatorClassInputType = get_opclass_input_type(operatorClassId);
3369 		char typeType = get_typtype(operatorClassInputType);
3370 
3371 		/* make sure we've initialized CacheMemoryContext */
3372 		if (CacheMemoryContext == NULL)
3373 		{
3374 			CreateCacheMemoryContext();
3375 		}
3376 
3377 		MemoryContext oldContext = MemoryContextSwitchTo(CacheMemoryContext);
3378 
3379 		matchingCacheEntry = palloc0(sizeof(OperatorCacheEntry));
3380 		matchingCacheEntry->typeId = typeId;
3381 		matchingCacheEntry->accessMethodId = accessMethodId;
3382 		matchingCacheEntry->strategyNumber = strategyNumber;
3383 		matchingCacheEntry->operatorId = operatorId;
3384 		matchingCacheEntry->operatorClassInputType = operatorClassInputType;
3385 		matchingCacheEntry->typeType = typeType;
3386 
3387 		OperatorCache = lappend(OperatorCache, matchingCacheEntry);
3388 
3389 		MemoryContextSwitchTo(oldContext);
3390 	}
3391 
3392 	return matchingCacheEntry;
3393 }
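

/*
 * Cache behavior (illustrative): the first call for a given triple, say
 * (INT4OID, BTREE_AM_OID, BTLessEqualStrategyNumber), performs the catalog
 * lookups and appends an entry allocated in CacheMemoryContext; later calls
 * with the same triple return that entry without further catalog access. A
 * linear list suffices here since only a handful of distinct triples are
 * seen per session.
 */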
3394 
3395 
3396 /*
3397  * GetOperatorByType returns the operator oid for the given type, access method,
3398  * and strategy number.
3399  */
3400 static Oid
3401 GetOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber)
3402 {
3403 	/* Get default operator class from pg_opclass */
3404 	Oid operatorClassId = GetDefaultOpClass(typeId, accessMethodId);
3405 
3406 	Oid operatorFamily = get_opclass_family(operatorClassId);
3407 	Oid operatorClassInputType = get_opclass_input_type(operatorClassId);
3408 
3409 	/* Lookup for the operator with the desired input type in the family */
3410 	Oid operatorId = get_opfamily_member(operatorFamily, operatorClassInputType,
3411 										 operatorClassInputType, strategyNumber);
3412 	return operatorId;
3413 }
3414 
3415 
3416 /*
3417  * BinaryOpExpression checks whether the given expression is a binary operator
3418  * clause. If so, it returns true and sets leftOperand and rightOperand to the
3419  * left and right hand sides of the operator; both are stripped of implicit
3420  * coercions by strip_implicit_coercions().
3421  */
3422 bool
3423 BinaryOpExpression(Expr *clause, Node **leftOperand, Node **rightOperand)
3424 {
3425 	if (!is_opclause(clause) || list_length(((OpExpr *) clause)->args) != 2)
3426 	{
3427 		if (leftOperand != NULL)
3428 		{
3429 			*leftOperand = NULL;
3430 		}
3431 		if (rightOperand != NULL)
3432 		{
3433 			*rightOperand = NULL;
3434 		}
3435 		return false;
3436 	}
3437 	if (leftOperand != NULL)
3438 	{
3439 		*leftOperand = get_leftop(clause);
3440 		Assert(*leftOperand != NULL);
3441 		*leftOperand = strip_implicit_coercions(*leftOperand);
3442 	}
3443 	if (rightOperand != NULL)
3444 	{
3445 		*rightOperand = get_rightop(clause);
3446 		Assert(*rightOperand != NULL);
3447 		*rightOperand = strip_implicit_coercions(*rightOperand);
3448 	}
3449 	return true;
3450 }
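

/*
 * Usage sketch (illustrative): for a clause like "l_orderkey = 5",
 * BinaryOpExpression(clause, &left, &right) returns true and sets left to
 * the Var and right to the Const, both stripped of implicit coercions. For
 * a non-opclause such as "abs(l_orderkey)" it returns false and sets the
 * requested operands to NULL.
 */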
3451 
3452 
3453 /*
3454  * SimpleOpExpression checks that the given expression is a simple operator
3455  * expression, that is, a binary operator expression whose operands are a Var
3456  * and a non-null Const.
3457  */
3458 bool
3459 SimpleOpExpression(Expr *clause)
3460 {
3461 	Const *constantClause = NULL;
3462 
3463 	Node *leftOperand;
3464 	Node *rightOperand;
3465 	if (!BinaryOpExpression(clause, &leftOperand, &rightOperand))
3466 	{
3467 		return false;
3468 	}
3469 
3470 	if (IsA(rightOperand, Const) && IsA(leftOperand, Var))
3471 	{
3472 		constantClause = (Const *) rightOperand;
3473 	}
3474 	else if (IsA(leftOperand, Const) && IsA(rightOperand, Var))
3475 	{
3476 		constantClause = (Const *) leftOperand;
3477 	}
3478 	else
3479 	{
3480 		return false;
3481 	}
3482 
3483 	if (constantClause->constisnull)
3484 	{
3485 		return false;
3486 	}
3487 
3488 	return true;
3489 }
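

/*
 * Examples (illustrative): "l_orderkey = 5" and "5 = l_orderkey" are simple
 * operator expressions; "l_orderkey = o_orderkey" (two Vars), "l_orderkey =
 * NULL" (null constant) and "l_orderkey + 1 = 5" (non-Var operand) are not.
 */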
3490 
3491 
3492 /*
3493  * MakeInt4Column creates a column of int4 type with invalid table id and max
3494  * attribute number.
3495  */
3496 Var *
3497 MakeInt4Column()
3498 {
3499 	Index tableId = 0;
3500 	AttrNumber columnAttributeNumber = RESERVED_HASHED_COLUMN_ID;
3501 	Oid columnType = INT4OID;
3502 	int32 columnTypeMod = -1;
3503 	Oid columnCollationOid = InvalidOid;
3504 	Index columnLevelSup = 0;
3505 
3506 	Var *int4Column = makeVar(tableId, columnAttributeNumber, columnType,
3507 							  columnTypeMod, columnCollationOid, columnLevelSup);
3508 	return int4Column;
3509 }
3510 
3511 
3512 /* Updates the base constraint with the given min/max values. */
3513 void
3514 UpdateConstraint(Node *baseConstraint, ShardInterval *shardInterval)
3515 {
3516 	BoolExpr *andExpr = (BoolExpr *) baseConstraint;
3517 	Node *lessThanExpr = (Node *) linitial(andExpr->args);
3518 	Node *greaterThanExpr = (Node *) lsecond(andExpr->args);
3519 
3520 	Node *minNode = get_rightop((Expr *) greaterThanExpr); /* right op */
3521 	Node *maxNode = get_rightop((Expr *) lessThanExpr);    /* right op */
3522 
3523 	Assert(shardInterval != NULL);
3524 	Assert(shardInterval->minValueExists);
3525 	Assert(shardInterval->maxValueExists);
3526 	Assert(minNode != NULL);
3527 	Assert(maxNode != NULL);
3528 	Assert(IsA(minNode, Const));
3529 	Assert(IsA(maxNode, Const));
3530 
3531 	Const *minConstant = (Const *) minNode;
3532 	Const *maxConstant = (Const *) maxNode;
3533 
3534 	minConstant->constvalue = datumCopy(shardInterval->minValue,
3535 										shardInterval->valueByVal,
3536 										shardInterval->valueTypeLen);
3537 	maxConstant->constvalue = datumCopy(shardInterval->maxValue,
3538 										shardInterval->valueByVal,
3539 										shardInterval->valueTypeLen);
3540 
3541 	minConstant->constisnull = false;
3542 	maxConstant->constisnull = false;
3543 }
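

/*
 * Concrete example (illustrative, hypothetical hash shard bounds): for a
 * shard covering [-2147483648, -1073741825], the updated constraint reads
 *
 *     (hashed_column <= -1073741825) AND (hashed_column >= -2147483648)
 *
 * Note that linitial()/lsecond() above rely on the clause order produced by
 * BuildBaseConstraint(): the lessThanExpr comes first.
 */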
3544 
3545 
3546 /*
3547  * FragmentCombinationList first builds an ordered sequence of range tables that
3548  * join together. The function then iteratively adds fragments from each joined
3549  * range table, and forms fragment combinations (lists) that cover all tables.
3550  * While doing so, the function also performs join pruning to remove unnecessary
3551  * fragment pairs. Last, the function adds each fragment combination (list) to a
3552  * list, and returns this list.
3553  */
3554 static List *
3555 FragmentCombinationList(List *rangeTableFragmentsList, Query *jobQuery,
3556 						List *dependentJobList)
3557 {
3558 	List *fragmentCombinationList = NIL;
3559 	List *fragmentCombinationQueue = NIL;
3560 	List *emptyList = NIL;
3561 
3562 	/* find a sequence that joins the range tables in the list */
3563 	JoinSequenceNode *joinSequenceArray = JoinSequenceArray(rangeTableFragmentsList,
3564 															jobQuery,
3565 															dependentJobList);
3566 
3567 	/*
3568 	 * We use breadth-first search with pruning to create fragment combinations.
3569 	 * For this, we first queue the root node (an empty combination), and then
3570 	 * start traversing our search space.
3571 	 */
3572 	fragmentCombinationQueue = lappend(fragmentCombinationQueue, emptyList);
3573 	while (fragmentCombinationQueue != NIL)
3574 	{
3575 		ListCell *tableFragmentCell = NULL;
3576 		int32 joiningTableSequenceIndex = -1;
3577 
3578 		/* pop first element from the fragment queue */
3579 		List *fragmentCombination = linitial(fragmentCombinationQueue);
3580 		fragmentCombinationQueue = list_delete_first(fragmentCombinationQueue);
3581 
3582 		/*
3583 		 * If this combination covered all range tables in a join sequence, add
3584 		 * this combination to our result set.
3585 		 */
3586 		int32 joinSequenceIndex = list_length(fragmentCombination);
3587 		int32 rangeTableCount = list_length(rangeTableFragmentsList);
3588 		if (joinSequenceIndex == rangeTableCount)
3589 		{
3590 			fragmentCombinationList = lappend(fragmentCombinationList,
3591 											  fragmentCombination);
3592 			continue;
3593 		}
3594 
3595 		/* find the next range table to add to our search space */
3596 		uint32 tableId = joinSequenceArray[joinSequenceIndex].rangeTableId;
3597 		List *tableFragments = FindRangeTableFragmentsList(rangeTableFragmentsList,
3598 														   tableId);
3599 
3600 		/* resolve sequence index for the previous range table we join against */
3601 		int32 joiningTableId = joinSequenceArray[joinSequenceIndex].joiningRangeTableId;
3602 		if (joiningTableId != NON_PRUNABLE_JOIN)
3603 		{
3604 			for (int32 sequenceIndex = 0; sequenceIndex < rangeTableCount;
3605 				 sequenceIndex++)
3606 			{
3607 				JoinSequenceNode *joinSequenceNode = &joinSequenceArray[sequenceIndex];
3608 				if (joinSequenceNode->rangeTableId == joiningTableId)
3609 				{
3610 					joiningTableSequenceIndex = sequenceIndex;
3611 					break;
3612 				}
3613 			}
3614 
3615 			Assert(joiningTableSequenceIndex != -1);
3616 		}
3617 
3618 		/*
3619 		 * We walk over each range table fragment, and check if joining this
3620 		 * fragment with the existing fragment combination can be pruned away.
3621 		 * If it can't, we create a new fragment combination and add it to our
3622 		 * search space.
3623 		 */
3624 		foreach(tableFragmentCell, tableFragments)
3625 		{
3626 			RangeTableFragment *tableFragment = lfirst(tableFragmentCell);
3627 			bool joinPrunable = false;
3628 
3629 			if (joiningTableId != NON_PRUNABLE_JOIN)
3630 			{
3631 				RangeTableFragment *joiningTableFragment =
3632 					list_nth(fragmentCombination, joiningTableSequenceIndex);
3633 
3634 				joinPrunable = JoinPrunable(joiningTableFragment, tableFragment);
3635 			}
3636 
3637 			/* if join can't be pruned, extend fragment combination and search */
3638 			if (!joinPrunable)
3639 			{
3640 				List *newFragmentCombination = list_copy(fragmentCombination);
3641 				newFragmentCombination = lappend(newFragmentCombination, tableFragment);
3642 
3643 				fragmentCombinationQueue = lappend(fragmentCombinationQueue,
3644 												   newFragmentCombination);
3645 			}
3646 		}
3647 	}
3648 
3649 	return fragmentCombinationList;
3650 }
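

/*
 * Worked example (illustrative): with two range tables holding fragments
 * { f1, f2 } and { g1, g2 }, the queue evolves as
 *
 *     [] -> [f1] [f2] -> [f1,g1] [f1,g2] [f2,g1] [f2,g2]
 *
 * and each surviving complete combination becomes one SQL task. For an
 * equi-join on the partition column, JoinPrunable() typically eliminates
 * the mixed pairs, leaving one task per co-located fragment pair.
 */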
3651 
3652 
3653 /*
3654  * NodeIsRangeTblRefReferenceTable checks if the node is a RangeTblRef that
3655  * points to a reference table in the rangeTableList.
3656  */
3657 static bool
3658 NodeIsRangeTblRefReferenceTable(Node *node, List *rangeTableList)
3659 {
3660 	if (!IsA(node, RangeTblRef))
3661 	{
3662 		return false;
3663 	}
3664 	RangeTblRef *tableRef = castNode(RangeTblRef, node);
3665 	RangeTblEntry *rangeTableEntry = rt_fetch(tableRef->rtindex, rangeTableList);
3666 	CitusRTEKind rangeTableType = GetRangeTblKind(rangeTableEntry);
3667 	if (rangeTableType != CITUS_RTE_RELATION)
3668 	{
3669 		return false;
3670 	}
3671 	return IsCitusTableType(rangeTableEntry->relid, REFERENCE_TABLE);
3672 }
3673 
3674 
3675 /*
3676  * FetchEqualityAttrNumsForRTE fetches the attribute numbers from quals
3677  * which have an equality operator.
3678  */
3679 List *
3680 FetchEqualityAttrNumsForRTE(Node *node)
3681 {
3682 	if (node == NULL)
3683 	{
3684 		return NIL;
3685 	}
3686 	if (IsA(node, List))
3687 	{
3688 		return FetchEqualityAttrNumsForList((List *) node);
3689 	}
3690 	else if (IsA(node, OpExpr))
3691 	{
3692 		return FetchEqualityAttrNumsForRTEOpExpr((OpExpr *) node);
3693 	}
3694 	else if (IsA(node, BoolExpr))
3695 	{
3696 		return FetchEqualityAttrNumsForRTEBoolExpr((BoolExpr *) node);
3697 	}
3698 	return NIL;
3699 }
3700 
3701 
3702 /*
3703  * FetchEqualityAttrNumsForList fetches the attribute numbers of expressions
3704  * of the form "= constant" from the given node list.
3705  */
3706 static List *
3707 FetchEqualityAttrNumsForList(List *nodeList)
3708 {
3709 	List *attributeNums = NIL;
3710 	Node *node = NULL;
3711 	bool hasAtLeastOneEquality = false;
3712 	foreach_ptr(node, nodeList)
3713 	{
3714 		List *fetchedEqualityAttrNums =
3715 			FetchEqualityAttrNumsForRTE(node);
3716 		hasAtLeastOneEquality |= list_length(fetchedEqualityAttrNums) > 0;
3717 		attributeNums = list_concat(attributeNums, fetchedEqualityAttrNums);
3718 	}
3719 
3720 	/*
3721 	 * The given list is in the form of AND'ed expressions; hence one equality
3722 	 * is enough.
3723 	 * E.g.: dist.a = 5 AND dist.a > 10
3724 	 */
3725 	if (hasAtLeastOneEquality)
3726 	{
3727 		return attributeNums;
3728 	}
3729 	return NIL;
3730 }
3731 
3732 
3733 /*
3734  * FetchEqualityAttrNumsForRTEOpExpr fetches the attribute numbers of expressions
3735  * of the form "= constant" from the given opExpr.
3736  */
3737 static List *
3738 FetchEqualityAttrNumsForRTEOpExpr(OpExpr *opExpr)
3739 {
3740 	if (!OperatorImplementsEquality(opExpr->opno))
3741 	{
3742 		return NIL;
3743 	}
3744 
3745 	List *attributeNums = NIL;
3746 	Var *var = NULL;
3747 	if (VarConstOpExprClause(opExpr, &var, NULL))
3748 	{
3749 		attributeNums = lappend_int(attributeNums, var->varattno);
3750 	}
3751 	return attributeNums;
3752 }
3753 
3754 
3755 /*
3756  * FetchEqualityAttrNumsForRTEBoolExpr fetches the attribute numbers of expressions
3757  * of the form "= constant" from the given boolExpr.
3758  */
3759 static List *
3760 FetchEqualityAttrNumsForRTEBoolExpr(BoolExpr *boolExpr)
3761 {
3762 	if (boolExpr->boolop != AND_EXPR && boolExpr->boolop != OR_EXPR)
3763 	{
3764 		return NIL;
3765 	}
3766 
3767 	List *attributeNums = NIL;
3768 	bool hasEquality = true;
3769 	Node *arg = NULL;
3770 	foreach_ptr(arg, boolExpr->args)
3771 	{
3772 		List *attributeNumsInSubExpression = FetchEqualityAttrNumsForRTE(arg);
3773 		if (boolExpr->boolop == AND_EXPR)
3774 		{
3775 			hasEquality |= list_length(attributeNumsInSubExpression) > 0;
3776 		}
3777 		else if (boolExpr->boolop == OR_EXPR)
3778 		{
3779 			hasEquality &= list_length(attributeNumsInSubExpression) > 0;
3780 		}
3781 		attributeNums = list_concat(attributeNums, attributeNumsInSubExpression);
3782 	}
3783 	if (hasEquality)
3784 	{
3785 		return attributeNums;
3786 	}
3787 	return NIL;
3788 }
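

/*
 * Examples (illustrative): "a = 5 OR a = 6" yields { a, a } since every OR
 * arm has an equality; "a = 5 OR b > 6" yields NIL since one arm lacks one;
 * "a = 5 AND b > 6" yields { a } since a single equality suffices under AND.
 */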
3789 
3790 
3791 /*
3792  * JoinSequenceArray walks over the join nodes in the job query and constructs a join
3793  * sequence containing an entry for each joined table. The function then returns an
3794  * array of join sequence nodes, in which each node contains the id of a table in the
3795  * range table list and the id of a preceding table with which it is joined, if any.
3796  */
3797 static JoinSequenceNode *
3798 JoinSequenceArray(List *rangeTableFragmentsList, Query *jobQuery, List *dependentJobList)
3799 {
3800 	List *rangeTableList = jobQuery->rtable;
3801 	uint32 rangeTableCount = (uint32) list_length(rangeTableList);
3802 	uint32 sequenceNodeSize = sizeof(JoinSequenceNode);
3803 	uint32 joinedTableCount = 0;
3804 	ListCell *joinExprCell = NULL;
3805 	uint32 firstRangeTableId = 1;
3806 	JoinSequenceNode *joinSequenceArray = palloc0(rangeTableCount * sequenceNodeSize);
3807 
3808 	List *joinExprList = JoinExprList(jobQuery->jointree);
3809 
3810 	/* pick first range table as starting table for the join sequence */
3811 	if (list_length(joinExprList) > 0)
3812 	{
3813 		JoinExpr *firstExpr = (JoinExpr *) linitial(joinExprList);
3814 		RangeTblRef *leftTableRef = (RangeTblRef *) firstExpr->larg;
3815 		firstRangeTableId = leftTableRef->rtindex;
3816 	}
3817 	else
3818 	{
3819 		/* when there are no joins, the join sequence contains a node for the table */
3820 		firstRangeTableId = 1;
3821 	}
3822 
3823 	joinSequenceArray[joinedTableCount].rangeTableId = firstRangeTableId;
3824 	joinSequenceArray[joinedTableCount].joiningRangeTableId = NON_PRUNABLE_JOIN;
3825 	joinedTableCount++;
3826 
3827 	foreach(joinExprCell, joinExprList)
3828 	{
3829 		JoinExpr *joinExpr = (JoinExpr *) lfirst(joinExprCell);
3830 		RangeTblRef *rightTableRef = castNode(RangeTblRef, joinExpr->rarg);
3831 		uint32 nextRangeTableId = rightTableRef->rtindex;
3832 		Index existingRangeTableId = 0;
3833 		bool applyJoinPruning = false;
3834 
3835 		List *nextJoinClauseList = make_ands_implicit((Expr *) joinExpr->quals);
3836 		bool leftIsReferenceTable = NodeIsRangeTblRefReferenceTable(joinExpr->larg,
3837 																	rangeTableList);
3838 		bool rightIsReferenceTable = NodeIsRangeTblRefReferenceTable(joinExpr->rarg,
3839 																	 rangeTableList);
3840 		bool isReferenceJoin = IsSupportedReferenceJoin(joinExpr->jointype,
3841 														leftIsReferenceTable,
3842 														rightIsReferenceTable);
3843 
3844 		/*
3845 		 * If the next join clause list is empty, the user tried a cartesian
3846 		 * product between tables. We don't support this functionality for
3847 		 * non-reference joins, and error out.
3848 		 */
3849 		if (nextJoinClauseList == NIL && !isReferenceJoin)
3850 		{
3851 			ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3852 							errmsg("cannot perform distributed planning on this query"),
3853 							errdetail("Cartesian products are currently unsupported")));
3854 		}
3855 
3856 		/*
3857 		 * We now determine if we can apply join pruning between existing range
3858 		 * tables and this new one.
3859 		 */
3860 		Node *nextJoinClause = NULL;
3861 		foreach_ptr(nextJoinClause, nextJoinClauseList)
3862 		{
3863 			if (!NodeIsEqualsOpExpr(nextJoinClause))
3864 			{
3865 				continue;
3866 			}
3867 
3868 			OpExpr *nextJoinClauseOpExpr = castNode(OpExpr, nextJoinClause);
3869 
3870 			if (!IsJoinClause((Node *) nextJoinClauseOpExpr))
3871 			{
3872 				continue;
3873 			}
3874 
3875 			Var *leftColumn = LeftColumnOrNULL(nextJoinClauseOpExpr);
3876 			Var *rightColumn = RightColumnOrNULL(nextJoinClauseOpExpr);
3877 			if (leftColumn == NULL || rightColumn == NULL)
3878 			{
3879 				continue;
3880 			}
3881 
3882 			Index leftRangeTableId = leftColumn->varno;
3883 			Index rightRangeTableId = rightColumn->varno;
3884 
3885 			/*
3886 			 * We have a table from the existing join list joining with the next
3887 			 * table. First resolve the existing table's range table id.
3888 			 */
3889 			if (leftRangeTableId == nextRangeTableId)
3890 			{
3891 				existingRangeTableId = rightRangeTableId;
3892 			}
3893 			else
3894 			{
3895 				existingRangeTableId = leftRangeTableId;
3896 			}
3897 
3898 			/*
3899 			 * Then, we check if we can apply join pruning between the existing
3900 			 * range table and this new one. For this, columns need to have the
3901 			 * same type and be the partition column for their respective tables.
3902 			 */
3903 			if (leftColumn->vartype != rightColumn->vartype)
3904 			{
3905 				continue;
3906 			}
3907 
3908 			bool leftPartitioned = PartitionedOnColumn(leftColumn, rangeTableList,
3909 													   dependentJobList);
3910 			bool rightPartitioned = PartitionedOnColumn(rightColumn, rangeTableList,
3911 														dependentJobList);
3912 			if (leftPartitioned && rightPartitioned)
3913 			{
3914 				/* make sure this join clause references only simple columns */
3915 				CheckJoinBetweenColumns(nextJoinClauseOpExpr);
3916 
3917 				applyJoinPruning = true;
3918 				break;
3919 			}
3920 		}
3921 
3922 		/* set next joining range table's info in the join sequence */
3923 		JoinSequenceNode *nextJoinSequenceNode = &joinSequenceArray[joinedTableCount];
3924 		if (applyJoinPruning)
3925 		{
3926 			nextJoinSequenceNode->rangeTableId = nextRangeTableId;
3927 			nextJoinSequenceNode->joiningRangeTableId = (int32) existingRangeTableId;
3928 		}
3929 		else
3930 		{
3931 			nextJoinSequenceNode->rangeTableId = nextRangeTableId;
3932 			nextJoinSequenceNode->joiningRangeTableId = NON_PRUNABLE_JOIN;
3933 		}
3934 
3935 		joinedTableCount++;
3936 	}
3937 
3938 	return joinSequenceArray;
3939 }
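

/*
 * Worked example (illustrative): for
 *
 *     SELECT ... FROM t1 JOIN t2 ON (t1.key = t2.key)
 *                        JOIN t3 ON (t1.a = t3.b)
 *
 * where key is the partition column of both t1 and t2 but a and b are not
 * partition columns, the resulting array is
 *
 *     { (1, NON_PRUNABLE_JOIN), (2, 1), (3, NON_PRUNABLE_JOIN) }
 *
 * that is, only the t1-t2 join qualifies for join pruning.
 */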
3940 
3941 
3942 /*
3943  * PartitionedOnColumn finds the given column's range table entry, and checks if
3944  * that range table is partitioned on the given column. Note that since reference
3945  * tables do not have partition columns, the function returns false when the distributed
3946  * relation is a reference table.
3947  */
3948 static bool
3949 PartitionedOnColumn(Var *column, List *rangeTableList, List *dependentJobList)
3950 {
3951 	bool partitionedOnColumn = false;
3952 	Index rangeTableId = column->varno;
3953 	RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList);
3954 
3955 	CitusRTEKind rangeTableType = GetRangeTblKind(rangeTableEntry);
3956 	if (rangeTableType == CITUS_RTE_RELATION)
3957 	{
3958 		Oid relationId = rangeTableEntry->relid;
3959 		Var *partitionColumn = PartitionColumn(relationId, rangeTableId);
3960 
3961 		/* non-distributed tables do not have partition columns */
3962 		if (IsCitusTableType(relationId, CITUS_TABLE_WITH_NO_DIST_KEY))
3963 		{
3964 			return false;
3965 		}
3966 
3967 		if (partitionColumn->varattno == column->varattno)
3968 		{
3969 			partitionedOnColumn = true;
3970 		}
3971 	}
3972 	else if (rangeTableType == CITUS_RTE_REMOTE_QUERY)
3973 	{
3974 		Job *job = JobForRangeTable(dependentJobList, rangeTableEntry);
3975 		MapMergeJob *mapMergeJob = (MapMergeJob *) job;
3976 
3977 		/*
3978 		 * The column's current attribute number is its location in the target
3979 		 * list for the table represented by the remote query. We retrieve this
3980 		 * value from the target list to compare against the partition column
3981 		 * as stored in the job.
3982 		 */
3983 		List *targetEntryList = job->jobQuery->targetList;
3984 		int32 columnIndex = column->varattno - 1;
3985 		Assert(columnIndex >= 0);
3986 		Assert(columnIndex < list_length(targetEntryList));
3987 
3988 		TargetEntry *targetEntry = (TargetEntry *) list_nth(targetEntryList, columnIndex);
3989 		Var *remoteRelationColumn = (Var *) targetEntry->expr;
3990 		Assert(IsA(remoteRelationColumn, Var));
3991 
3992 		/* retrieve the partition column for the job */
3993 		Var *partitionColumn = mapMergeJob->partitionColumn;
3994 		if (partitionColumn->varattno == remoteRelationColumn->varattno)
3995 		{
3996 			partitionedOnColumn = true;
3997 		}
3998 	}
3999 
4000 	return partitionedOnColumn;
4001 }
4002 
4003 
4004 /* Checks that the join clause references only simple columns. */
4005 static void
4006 CheckJoinBetweenColumns(OpExpr *joinClause)
4007 {
4008 	List *argumentList = joinClause->args;
4009 	Node *leftArgument = (Node *) linitial(argumentList);
4010 	Node *rightArgument = (Node *) lsecond(argumentList);
4011 	Node *strippedLeftArgument = strip_implicit_coercions(leftArgument);
4012 	Node *strippedRightArgument = strip_implicit_coercions(rightArgument);
4013 
4014 	NodeTag leftArgumentType = nodeTag(strippedLeftArgument);
4015 	NodeTag rightArgumentType = nodeTag(strippedRightArgument);
4016 
4017 	if (leftArgumentType != T_Var || rightArgumentType != T_Var)
4018 	{
4019 		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4020 						errmsg("cannot perform local joins that involve expressions"),
4021 						errdetail("local joins can be performed between columns only")));
4022 	}
4023 }
4024 
4025 
4026 /*
4027  * FindRangeTableFragmentsList walks over the given list of range table fragments
4028  * and returns the one with the given table id.
4029  */
4030 static List *
4031 FindRangeTableFragmentsList(List *rangeTableFragmentsList, int tableId)
4032 {
4033 	List *foundTableFragments = NIL;
4034 	ListCell *rangeTableFragmentsCell = NULL;
4035 
4036 	foreach(rangeTableFragmentsCell, rangeTableFragmentsList)
4037 	{
4038 		List *tableFragments = (List *) lfirst(rangeTableFragmentsCell);
4039 		if (tableFragments != NIL)
4040 		{
4041 			RangeTableFragment *tableFragment =
4042 				(RangeTableFragment *) linitial(tableFragments);
4043 			if (tableFragment->rangeTableId == tableId)
4044 			{
4045 				foundTableFragments = tableFragments;
4046 				break;
4047 			}
4048 		}
4049 	}
4050 
4051 	return foundTableFragments;
4052 }
4053 
4054 
4055 /*
4056  * JoinPrunable checks if a join between the given left and right fragments can
4057  * be pruned away, without performing the actual join. To do this, the function
4058  * checks if we have a hash repartition join. If we do, the function determines
4059  * pruning based on partitionIds. Else if we have a merge repartition join, the
4060  * function checks if the two fragments have disjoint intervals.
4061  */
4062 static bool
4063 JoinPrunable(RangeTableFragment *leftFragment, RangeTableFragment *rightFragment)
4064 {
4065 	/*
4066 	 * If both range tables are remote queries, we then have a hash repartition
4067 	 * join. In that case, we can just prune away this join if left and right
4068 	 * hand side fragments have the same partitionId.
4069 	 */
4070 	if (leftFragment->fragmentType == CITUS_RTE_REMOTE_QUERY &&
4071 		rightFragment->fragmentType == CITUS_RTE_REMOTE_QUERY)
4072 	{
4073 		Task *leftMergeTask = (Task *) leftFragment->fragmentReference;
4074 		Task *rightMergeTask = (Task *) rightFragment->fragmentReference;
4075 
4077 		if (leftMergeTask->partitionId != rightMergeTask->partitionId)
4078 		{
4079 			ereport(DEBUG2, (errmsg("join prunable for task partitionId %u and %u",
4080 									leftMergeTask->partitionId,
4081 									rightMergeTask->partitionId)));
4082 			return true;
4083 		}
4084 		else
4085 		{
4086 			return false;
4087 		}
4088 	}
4089 
4091 	/*
4092 	 * We have a single (re)partition join. We now get shard intervals for both
4093 	 * fragments, and then check if these intervals overlap.
4094 	 */
4095 	ShardInterval *leftFragmentInterval = FragmentInterval(leftFragment);
4096 	ShardInterval *rightFragmentInterval = FragmentInterval(rightFragment);
4097 
4098 	bool overlap = ShardIntervalsOverlap(leftFragmentInterval, rightFragmentInterval);
4099 	if (!overlap)
4100 	{
4101 		if (IsLoggableLevel(DEBUG2))
4102 		{
4103 			StringInfo leftString = FragmentIntervalString(leftFragmentInterval);
4104 			StringInfo rightString = FragmentIntervalString(rightFragmentInterval);
4105 
4106 			ereport(DEBUG2, (errmsg("join prunable for intervals %s and %s",
4107 									leftString->data, rightString->data)));
4108 		}
4109 
4110 		return true;
4111 	}
4112 
4113 	return false;
4114 }
4115 
4116 
4117 /*
4118  * FragmentInterval takes the given fragment, and determines the range of data
4119  * covered by this fragment. The function then returns this range (interval).
4120  */
4121 static ShardInterval *
4122 FragmentInterval(RangeTableFragment *fragment)
4123 {
4124 	ShardInterval *fragmentInterval = NULL;
4125 	if (fragment->fragmentType == CITUS_RTE_RELATION)
4126 	{
4127 		Assert(CitusIsA(fragment->fragmentReference, ShardInterval));
4128 		fragmentInterval = (ShardInterval *) fragment->fragmentReference;
4129 	}
4130 	else if (fragment->fragmentType == CITUS_RTE_REMOTE_QUERY)
4131 	{
4132 		Assert(CitusIsA(fragment->fragmentReference, Task));
4133 
4134 		Task *mergeTask = (Task *) fragment->fragmentReference;
4135 		fragmentInterval = mergeTask->shardInterval;
4136 	}
4137 
4138 	return fragmentInterval;
4139 }
4140 
4141 
4142 /* Checks if the given shard intervals have overlapping ranges. */
4143 bool
4144 ShardIntervalsOverlap(ShardInterval *firstInterval, ShardInterval *secondInterval)
4145 {
4146 	CitusTableCacheEntry *intervalRelation =
4147 		GetCitusTableCacheEntry(firstInterval->relationId);
4148 
4149 	Assert(IsCitusTableTypeCacheEntry(intervalRelation, DISTRIBUTED_TABLE));
4150 
4151 	if (!(firstInterval->minValueExists && firstInterval->maxValueExists &&
4152 		  secondInterval->minValueExists && secondInterval->maxValueExists))
4153 	{
4154 		return true;
4155 	}
4156 
4157 	Datum firstMin = firstInterval->minValue;
4158 	Datum firstMax = firstInterval->maxValue;
4159 	Datum secondMin = secondInterval->minValue;
4160 	Datum secondMax = secondInterval->maxValue;
4161 
4162 	FmgrInfo *comparisonFunction = intervalRelation->shardIntervalCompareFunction;
4163 	Oid collation = intervalRelation->partitionColumn->varcollid;
4164 
4165 	return ShardIntervalsOverlapWithParams(firstMin, firstMax, secondMin, secondMax,
4166 										   comparisonFunction, collation);
4167 }
4168 
4169 
4170 /*
4171  * ShardIntervalsOverlapWithParams is a helper function which compares the input
4172  * shard min/max values, and returns true if the shards overlap.
4173  * The caller is responsible for ensuring the input shard min/max values are not NULL.
4174  */
4175 bool
4176 ShardIntervalsOverlapWithParams(Datum firstMin, Datum firstMax, Datum secondMin,
4177 								Datum secondMax, FmgrInfo *comparisonFunction,
4178 								Oid collation)
4179 {
4180 	/*
4181 	 * We need to have min/max values for both intervals first. Then, we assume
4182 	 * two intervals i1 = [min1, max1] and i2 = [min2, max2] do not overlap if
4183 	 * (max1 < min2) or (max2 < min1). For details, please see the explanation
4184 	 * on overlapping intervals at http://www.rgrjr.com/emacs/overlap.html.
4185 	 */
4186 	Datum firstDatum = FunctionCall2Coll(comparisonFunction, collation, firstMax,
4187 										 secondMin);
4188 	Datum secondDatum = FunctionCall2Coll(comparisonFunction, collation, secondMax,
4189 										  firstMin);
4190 	int firstComparison = DatumGetInt32(firstDatum);
4191 	int secondComparison = DatumGetInt32(secondDatum);
4192 
4193 	if (firstComparison < 0 || secondComparison < 0)
4194 	{
4195 		return false;
4196 	}
4197 
4198 	return true;
4199 }
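

/*
 * Example (illustrative): intervals [1, 10] and [11, 20] do not overlap,
 * since compare(10, 11) < 0; intervals [1, 10] and [10, 20] do overlap,
 * since compare(10, 10) and compare(20, 1) are both >= 0.
 */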
4200 
4201 
4202 /*
4203  * FragmentIntervalString takes the given fragment interval, and converts this
4204  * interval into its string representation for use in debug messages.
4205  */
4206 static StringInfo
4207 FragmentIntervalString(ShardInterval *fragmentInterval)
4208 {
4209 	Oid typeId = fragmentInterval->valueTypeId;
4210 	Oid outputFunctionId = InvalidOid;
4211 	bool typeVariableLength = false;
4212 
4213 	Assert(fragmentInterval->minValueExists);
4214 	Assert(fragmentInterval->maxValueExists);
4215 
4216 	FmgrInfo *outputFunction = (FmgrInfo *) palloc0(sizeof(FmgrInfo));
4217 	getTypeOutputInfo(typeId, &outputFunctionId, &typeVariableLength);
4218 	fmgr_info(outputFunctionId, outputFunction);
4219 
4220 	char *minValueString = OutputFunctionCall(outputFunction, fragmentInterval->minValue);
4221 	char *maxValueString = OutputFunctionCall(outputFunction, fragmentInterval->maxValue);
4222 
4223 	StringInfo fragmentIntervalString = makeStringInfo();
4224 	appendStringInfo(fragmentIntervalString, "[%s,%s]", minValueString, maxValueString);
4225 
4226 	return fragmentIntervalString;
4227 }
4228 
4229 
4230 /*
4231  * DataFetchTaskList builds a merge fetch task for every remote query result
4232  * in the given fragment list, appends these merge fetch tasks into a list,
4233  * and returns this list.
4234  */
4235 static List *
4236 DataFetchTaskList(uint64 jobId, uint32 taskIdIndex, List *fragmentList)
4237 {
4238 	List *dataFetchTaskList = NIL;
4239 	ListCell *fragmentCell = NULL;
4240 
4241 	foreach(fragmentCell, fragmentList)
4242 	{
4243 		RangeTableFragment *fragment = (RangeTableFragment *) lfirst(fragmentCell);
4244 		if (fragment->fragmentType == CITUS_RTE_REMOTE_QUERY)
4245 		{
4246 			Task *mergeTask = (Task *) fragment->fragmentReference;
4247 			char *undefinedQueryString = NULL;
4248 
4249 			/* create merge fetch task and have it depend on the merge task */
4250 			Task *mergeFetchTask = CreateBasicTask(jobId, taskIdIndex, MERGE_FETCH_TASK,
4251 												   undefinedQueryString);
4252 			mergeFetchTask->dependentTaskList = list_make1(mergeTask);
4253 
4254 			dataFetchTaskList = lappend(dataFetchTaskList, mergeFetchTask);
4255 			taskIdIndex++;
4256 		}
4257 	}
4258 
4259 	return dataFetchTaskList;
4260 }
4261 
4262 
4263 /* Helper function to return a datum array's external string representation. */
4264 static StringInfo
4265 DatumArrayString(Datum *datumArray, uint32 datumCount, Oid datumTypeId)
4266 {
4267 	int16 typeLength = 0;
4268 	bool typeByValue = false;
4269 	char typeAlignment = 0;
4270 
4271 	/* construct the array object from the given array */
4272 	get_typlenbyvalalign(datumTypeId, &typeLength, &typeByValue, &typeAlignment);
4273 	ArrayType *arrayObject = construct_array(datumArray, datumCount, datumTypeId,
4274 											 typeLength, typeByValue, typeAlignment);
4275 	Datum arrayObjectDatum = PointerGetDatum(arrayObject);
4276 
4277 	/* convert the array object to its string representation */
4278 	FmgrInfo *arrayOutFunction = (FmgrInfo *) palloc0(sizeof(FmgrInfo));
4279 	fmgr_info(F_ARRAY_OUT, arrayOutFunction);
4280 
4281 	Datum arrayStringDatum = FunctionCall1(arrayOutFunction, arrayObjectDatum);
4282 	char *arrayString = DatumGetCString(arrayStringDatum);
4283 
4284 	StringInfo arrayStringInfo = makeStringInfo();
4285 	appendStringInfo(arrayStringInfo, "%s", arrayString);
4286 
4287 	return arrayStringInfo;
4288 }
4289 
4290 
4291 /*
4292  * CreateBasicTask creates a task, initializes fields that are common to each task,
4293  * and returns the created task.
4294  */
4295 Task *
4296 CreateBasicTask(uint64 jobId, uint32 taskId, TaskType taskType, char *queryString)
4297 {
4298 	Task *task = CitusMakeNode(Task);
4299 	task->jobId = jobId;
4300 	task->taskId = taskId;
4301 	task->taskType = taskType;
4302 	task->replicationModel = REPLICATION_MODEL_INVALID;
4303 	SetTaskQueryString(task, queryString);
4304 
4305 	return task;
4306 }
4307 
4308 
4309 /*
4310  * BuildRelationShardList builds a list of RelationShard pairs for a task.
4311  * This represents the mapping of range table entries to shard IDs for a
4312  * task for the purposes of locking, deparsing, and connection management.
4313  */
4314 static List *
4315 BuildRelationShardList(List *rangeTableList, List *fragmentList)
4316 {
4317 	List *relationShardList = NIL;
4318 	ListCell *fragmentCell = NULL;
4319 
4320 	foreach(fragmentCell, fragmentList)
4321 	{
4322 		RangeTableFragment *fragment = (RangeTableFragment *) lfirst(fragmentCell);
4323 		Index rangeTableId = fragment->rangeTableId;
4324 		RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList);
4325 
4326 		CitusRTEKind fragmentType = fragment->fragmentType;
4327 		if (fragmentType == CITUS_RTE_RELATION)
4328 		{
4329 			ShardInterval *shardInterval = (ShardInterval *) fragment->fragmentReference;
4330 			RelationShard *relationShard = CitusMakeNode(RelationShard);
4331 
4332 			relationShard->relationId = rangeTableEntry->relid;
4333 			relationShard->shardId = shardInterval->shardId;
4334 
4335 			relationShardList = lappend(relationShardList, relationShard);
4336 		}
4337 	}
4338 
4339 	return relationShardList;
4340 }
4341 
4342 
4343 /*
4344  * UpdateRangeTableAlias walks over each fragment in the given fragment list,
4345  * and creates an alias that represents the fragment name to be used in the
4346  * query. The function then updates the corresponding range table entry with
4347  * this alias.
4348  */
4349 static void
4350 UpdateRangeTableAlias(List *rangeTableList, List *fragmentList)
4351 {
4352 	ListCell *fragmentCell = NULL;
4353 	foreach(fragmentCell, fragmentList)
4354 	{
4355 		RangeTableFragment *fragment = (RangeTableFragment *) lfirst(fragmentCell);
4356 		Index rangeTableId = fragment->rangeTableId;
4357 		RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList);
4358 
4359 		Alias *fragmentAlias = FragmentAlias(rangeTableEntry, fragment);
4360 		rangeTableEntry->alias = fragmentAlias;
4361 	}
4362 }
4363 
4364 
4365 /*
4366  * FragmentAlias creates an alias structure that captures the table fragment's
4367  * name on the worker node. Each fragment represents either a regular shard, or
4368  * a merge task.
4369  */
4370 static Alias *
4371 FragmentAlias(RangeTblEntry *rangeTableEntry, RangeTableFragment *fragment)
4372 {
4373 	char *aliasName = NULL;
4374 	char *schemaName = NULL;
4375 	char *fragmentName = NULL;
4376 
4377 	CitusRTEKind fragmentType = fragment->fragmentType;
4378 	if (fragmentType == CITUS_RTE_RELATION)
4379 	{
4380 		ShardInterval *shardInterval = (ShardInterval *) fragment->fragmentReference;
4381 		uint64 shardId = shardInterval->shardId;
4382 
4383 		Oid relationId = rangeTableEntry->relid;
4384 		char *relationName = get_rel_name(relationId);
4385 
4386 		Oid schemaId = get_rel_namespace(relationId);
4387 		schemaName = get_namespace_name(schemaId);
4388 
4389 		aliasName = relationName;
4390 
4391 		/*
4392 		 * Set shard name in alias to <relation_name>_<shard_id>.
4393 		 */
4394 		fragmentName = pstrdup(relationName);
4395 		AppendShardIdToName(&fragmentName, shardId);
4396 	}
4397 	else if (fragmentType == CITUS_RTE_REMOTE_QUERY)
4398 	{
4399 		Task *mergeTask = (Task *) fragment->fragmentReference;
4400 		uint64 jobId = mergeTask->jobId;
4401 		uint32 taskId = mergeTask->taskId;
4402 
4403 		StringInfo jobSchemaName = JobSchemaName(jobId);
4404 		StringInfo taskTableName = TaskTableName(taskId);
4405 
4406 		StringInfo aliasNameString = makeStringInfo();
4407 		appendStringInfo(aliasNameString, "%s.%s",
4408 						 jobSchemaName->data, taskTableName->data);
4409 
4410 		aliasName = aliasNameString->data;
4411 		fragmentName = taskTableName->data;
4412 		schemaName = jobSchemaName->data;
4413 	}
4414 
4415 	/*
4416 	 * We need to set the aliasname to the relation name, as pg_get_query_def() uses
4417 	 * the relation name to disambiguate column names from different tables.
4418 	 */
4419 	Alias *alias = rangeTableEntry->alias;
4420 	if (alias == NULL)
4421 	{
4422 		alias = makeNode(Alias);
4423 		alias->aliasname = aliasName;
4424 	}
4425 
4426 	ModifyRangeTblExtraData(rangeTableEntry, CITUS_RTE_SHARD,
4427 							schemaName, fragmentName, NIL);
4428 
4429 	return alias;
4430 }
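

/*
 * Example (illustrative, hypothetical identifiers): a shard of relation
 * "lineitem" with shard id 102009 gets fragment name "lineitem_102009" and
 * alias name "lineitem"; a merge task fragment instead gets the
 * schema-qualified name built from JobSchemaName() and TaskTableName().
 */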
4431 
4432 
4433 /*
4434  * AnchorShardId walks over each fragment in the given fragment list, finds the
4435  * fragment that corresponds to the given anchor range tableId, and returns this
4436  * fragment's shard identifier. Note that the given tableId must correspond to a
4437  * base relation.
4438  */
4439 static uint64
4440 AnchorShardId(List *fragmentList, uint32 anchorRangeTableId)
4441 {
4442 	uint64 anchorShardId = INVALID_SHARD_ID;
4443 	ListCell *fragmentCell = NULL;
4444 
4445 	foreach(fragmentCell, fragmentList)
4446 	{
4447 		RangeTableFragment *fragment = (RangeTableFragment *) lfirst(fragmentCell);
4448 		if (fragment->rangeTableId == anchorRangeTableId)
4449 		{
4450 			Assert(fragment->fragmentType == CITUS_RTE_RELATION);
4451 			Assert(CitusIsA(fragment->fragmentReference, ShardInterval));
4452 
4453 			ShardInterval *shardInterval = (ShardInterval *) fragment->fragmentReference;
4454 			anchorShardId = shardInterval->shardId;
4455 			break;
4456 		}
4457 	}
4458 
4459 	Assert(anchorShardId != INVALID_SHARD_ID);
4460 	return anchorShardId;
4461 }
4462 
4463 
4464 /*
4465  * PruneSqlTaskDependencies iterates over each sql task from the given sql task
4466  * list, and prunes away merge-fetch tasks, as the task assignment algorithm
4467  * ensures co-location of these tasks.
4468  */
4469 static List *
4470 PruneSqlTaskDependencies(List *sqlTaskList)
4471 {
4472 	ListCell *sqlTaskCell = NULL;
4473 	foreach(sqlTaskCell, sqlTaskList)
4474 	{
4475 		Task *sqlTask = (Task *) lfirst(sqlTaskCell);
4476 		List *dependentTaskList = sqlTask->dependentTaskList;
4477 		List *prunedDependentTaskList = NIL;
4478 
4479 		ListCell *dependentTaskCell = NULL;
4480 		foreach(dependentTaskCell, dependentTaskList)
4481 		{
4482 			Task *dataFetchTask = (Task *) lfirst(dependentTaskCell);
4483 
4484 			/*
4485 			 * If we have a merge fetch task, our task assignment algorithm makes
4486 			 * sure that the sql task is colocated with the anchor shard / merge
4487 			 * task. We can therefore prune out this data fetch task.
4488 			 */
4489 			if (dataFetchTask->taskType == MERGE_FETCH_TASK)
4490 			{
4491 				List *mergeFetchDependencyList = dataFetchTask->dependentTaskList;
4492 				Assert(list_length(mergeFetchDependencyList) == 1);
4493 
4494 				Task *mergeTaskReference = (Task *) linitial(mergeFetchDependencyList);
4495 				prunedDependentTaskList = lappend(prunedDependentTaskList,
4496 												  mergeTaskReference);
4497 
4498 				ereport(DEBUG2, (errmsg("pruning merge fetch taskId %d",
4499 										dataFetchTask->taskId),
4500 								 errdetail("Creating dependency on merge taskId %d",
4501 										   mergeTaskReference->taskId)));
4502 			}
4503 		}
4504 
4505 		sqlTask->dependentTaskList = prunedDependendTaskList;
4506 	}
4507 
4508 	return sqlTaskList;
4509 }
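
/*
 * Illustrative sketch of the pruning above (not actual code): before pruning,
 * a sql task s1 may depend on a merge fetch task f1 that in turn depends on a
 * merge task m1:
 *
 *     s1 <- f1 <- m1
 *
 * Because task assignment co-locates s1 with m1's output, the intermediate
 * fetch is unnecessary and the pruned dependency graph is simply s1 <- m1.
 */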


/*
 * MapTaskList creates a list of map tasks for the given MapMerge job. For this,
 * the function walks over each filter task (sql task) in the given filter task
 * list, and wraps this task with a map function call. The map function call
 * repartitions the filter task's output according to the MapMerge job's
 * parameters.
 */
static List *
MapTaskList(MapMergeJob *mapMergeJob, List *filterTaskList)
{
	List *mapTaskList = NIL;
	Query *filterQuery = mapMergeJob->job.jobQuery;
	ListCell *filterTaskCell = NULL;
	Var *partitionColumn = mapMergeJob->partitionColumn;

	uint32 partitionColumnResNo = 0;
	List *groupClauseList = filterQuery->groupClause;
	if (groupClauseList != NIL)
	{
		List *targetEntryList = filterQuery->targetList;
		List *groupTargetEntryList = GroupTargetEntryList(groupClauseList,
														  targetEntryList);
		TargetEntry *groupByTargetEntry = (TargetEntry *) linitial(groupTargetEntryList);

		partitionColumnResNo = groupByTargetEntry->resno;
	}
	else
	{
		partitionColumnResNo = PartitionColumnIndex(partitionColumn,
													filterQuery->targetList);
	}

	foreach(filterTaskCell, filterTaskList)
	{
		Task *filterTask = (Task *) lfirst(filterTaskCell);
		StringInfo mapQueryString = CreateMapQueryString(mapMergeJob, filterTask,
														 partitionColumnResNo);

		/* convert filter query task into map task */
		Task *mapTask = filterTask;
		SetTaskQueryString(mapTask, mapQueryString->data);
		mapTask->taskType = MAP_TASK;

		mapTaskList = lappend(mapTaskList, mapTask);
	}

	return mapTaskList;
}
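
/*
 * Example (hypothetical query): for a filter query such as
 *
 *     SELECT l_orderkey, count(*) FROM lineitem GROUP BY l_orderkey;
 *
 * the GROUP BY branch above picks the first grouping target entry, so the map
 * function repartitions the filter task's output on resno 1 (l_orderkey).
 * Without a GROUP BY, PartitionColumnIndex() below locates the partition
 * column's position in the target list instead.
 */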


/*
 * PartitionColumnIndex finds the 1-based index (resno) of the given target var
 * among the Var entries of the given target list, and errors out if the var
 * cannot be found.
 */
static int
PartitionColumnIndex(Var *targetVar, List *targetList)
{
	TargetEntry *targetEntry = NULL;
	int resNo = 1;
	foreach_ptr(targetEntry, targetList)
	{
		if (IsA(targetEntry->expr, Var))
		{
			Var *candidateVar = (Var *) targetEntry->expr;
			if (candidateVar->varattno == targetVar->varattno &&
				candidateVar->varno == targetVar->varno)
			{
				return resNo;
			}
			resNo++;
		}
	}

	ereport(ERROR, (errmsg("unexpected state: varno %d, varattno %d could not be "
						   "found in the target list",
						   targetVar->varno, targetVar->varattno)));
	return resNo;
}


/*
 * CreateMapQueryString creates and returns the map query string for the given
 * filter task.
 */
static StringInfo
CreateMapQueryString(MapMergeJob *mapMergeJob, Task *filterTask,
					 uint32 partitionColumnIndex)
{
	uint64 jobId = filterTask->jobId;
	uint32 taskId = filterTask->taskId;

	/* wrap repartition query string around filter query string */
	StringInfo mapQueryString = makeStringInfo();
	char *filterQueryString = TaskQueryString(filterTask);
	char *filterQueryEscapedText = quote_literal_cstr(filterQueryString);
	PartitionType partitionType = mapMergeJob->partitionType;

	Var *partitionColumn = mapMergeJob->partitionColumn;
	Oid partitionColumnType = partitionColumn->vartype;
	char *partitionColumnTypeFullName = format_type_be_qualified(partitionColumnType);
	int32 partitionColumnTypeMod = partitionColumn->vartypmod;

	ShardInterval **intervalArray = mapMergeJob->sortedShardIntervalArray;
	uint32 intervalCount = mapMergeJob->partitionCount;

	if (partitionType == DUAL_HASH_PARTITION_TYPE)
	{
		partitionColumnType = INT4OID;
		partitionColumnTypeMod = get_typmodin(INT4OID);
		intervalArray = GenerateSyntheticShardIntervalArray(intervalCount);
	}
	else if (partitionType == SINGLE_HASH_PARTITION_TYPE)
	{
		partitionColumnType = INT4OID;
		partitionColumnTypeMod = get_typmodin(INT4OID);
	}

	ArrayType *splitPointObject = SplitPointObject(intervalArray, intervalCount);
	StringInfo splitPointString = ArrayObjectToString(splitPointObject,
													  partitionColumnType,
													  partitionColumnTypeMod);

	char *partitionCommand = NULL;
	if (partitionType == RANGE_PARTITION_TYPE)
	{
		partitionCommand = RANGE_PARTITION_COMMAND;
	}
	else
	{
		partitionCommand = HASH_PARTITION_COMMAND;
	}

	char *partitionColumnIndexText = ConvertIntToString(partitionColumnIndex);
	appendStringInfo(mapQueryString, partitionCommand, jobId, taskId,
					 filterQueryEscapedText, partitionColumnIndexText,
					 partitionColumnTypeFullName, splitPointString->data);
	return mapQueryString;
}


/*
 * GenerateSyntheticShardIntervalArray returns a shard interval pointer array
 * which has a uniform hash distribution for the given input partitionCount.
 *
 * The function only fills in the min/max values of the shard intervals. Thus,
 * it should not be used for general purpose operations.
 */
ShardInterval **
GenerateSyntheticShardIntervalArray(int partitionCount)
{
	ShardInterval **shardIntervalArray = palloc0(partitionCount *
												 sizeof(ShardInterval *));
	uint64 hashTokenIncrement = HASH_TOKEN_COUNT / partitionCount;

	for (int shardIndex = 0; shardIndex < partitionCount; ++shardIndex)
	{
		ShardInterval *shardInterval = CitusMakeNode(ShardInterval);

		/* calculate the split of the hash space */
		int32 shardMinHashToken = PG_INT32_MIN + (shardIndex * hashTokenIncrement);
		int32 shardMaxHashToken = shardMinHashToken + (hashTokenIncrement - 1);

		shardInterval->relationId = InvalidOid;
		shardInterval->minValueExists = true;
		shardInterval->minValue = Int32GetDatum(shardMinHashToken);

		shardInterval->maxValueExists = true;
		shardInterval->maxValue = Int32GetDatum(shardMaxHashToken);

		shardInterval->shardId = INVALID_SHARD_ID;
		shardInterval->valueTypeId = INT4OID;

		shardIntervalArray[shardIndex] = shardInterval;
	}

	return shardIntervalArray;
}
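
/*
 * Worked example, assuming HASH_TOKEN_COUNT spans the full 32-bit hash space
 * (2^32): for partitionCount = 4, hashTokenIncrement is 1073741824 and the
 * synthetic intervals become
 *
 *     [-2147483648, -1073741825]
 *     [-1073741824,          -1]
 *     [          0,  1073741823]
 *     [ 1073741824,  2147483647]
 *
 * which cover the int32 hash range uniformly, without gaps or overlaps.
 */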


/*
 * RowModifyLevelForQuery returns the RowModifyLevel required for the given query.
 */
RowModifyLevel
RowModifyLevelForQuery(Query *query)
{
	CmdType commandType = query->commandType;

	if (commandType == CMD_SELECT)
	{
		if (query->hasModifyingCTE)
		{
			/* skip checking for INSERT as those CTEs are recursively planned */
			CommonTableExpr *cte = NULL;
			foreach_ptr(cte, query->cteList)
			{
				Query *cteQuery = (Query *) cte->ctequery;

				if (cteQuery->commandType == CMD_UPDATE ||
					cteQuery->commandType == CMD_DELETE)
				{
					return ROW_MODIFY_NONCOMMUTATIVE;
				}
			}
		}

		return ROW_MODIFY_READONLY;
	}

	if (commandType == CMD_INSERT)
	{
		if (query->onConflict == NULL)
		{
			return ROW_MODIFY_COMMUTATIVE;
		}
		else
		{
			return ROW_MODIFY_NONCOMMUTATIVE;
		}
	}

	if (commandType == CMD_UPDATE ||
		commandType == CMD_DELETE)
	{
		return ROW_MODIFY_NONCOMMUTATIVE;
	}

	return ROW_MODIFY_NONE;
}
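
/*
 * For instance, the branches above classify queries as follows (illustrative):
 *
 *     SELECT * FROM t;                                 -> ROW_MODIFY_READONLY
 *     WITH c AS (UPDATE t ... RETURNING *) SELECT ...; -> ROW_MODIFY_NONCOMMUTATIVE
 *     INSERT INTO t VALUES (1);                        -> ROW_MODIFY_COMMUTATIVE
 *     INSERT INTO t ... ON CONFLICT DO NOTHING;        -> ROW_MODIFY_NONCOMMUTATIVE
 *     UPDATE t SET ...;                                -> ROW_MODIFY_NONCOMMUTATIVE
 */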


/*
 * ArrayObjectToString converts an SQL array object to its string representation,
 * appending an explicit cast to the array type.
 */
StringInfo
ArrayObjectToString(ArrayType *arrayObject, Oid columnType, int32 columnTypeMod)
{
	Datum arrayDatum = PointerGetDatum(arrayObject);
	Oid outputFunctionId = InvalidOid;
	bool typeVariableLength = false;

	Oid arrayOutType = get_array_type(columnType);
	if (arrayOutType == InvalidOid)
	{
		char *columnTypeName = format_type_be(columnType);
		ereport(ERROR, (errmsg("cannot range repartition table on column type %s",
							   columnTypeName)));
	}

	FmgrInfo *arrayOutFunction = (FmgrInfo *) palloc0(sizeof(FmgrInfo));
	getTypeOutputInfo(arrayOutType, &outputFunctionId, &typeVariableLength);
	fmgr_info(outputFunctionId, arrayOutFunction);

	char *arrayOutputText = OutputFunctionCall(arrayOutFunction, arrayDatum);
	char *arrayOutputEscapedText = quote_literal_cstr(arrayOutputText);

	/* add an explicit cast to array's string representation */
	char *arrayOutTypeName = format_type_with_typemod(arrayOutType, columnTypeMod);

	StringInfo arrayString = makeStringInfo();
	appendStringInfo(arrayString, "%s::%s",
					 arrayOutputEscapedText, arrayOutTypeName);

	return arrayString;
}
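
/*
 * Example with hypothetical values: for an int4 split point array holding
 * {0, 1073741824}, the function returns the string
 *
 *     '{0,1073741824}'::integer[]
 *
 * i.e. the quoted array literal followed by an explicit cast to the name of
 * the array type.
 */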


/*
 * MergeTaskList creates a list of merge tasks for the given MapMerge job. While
 * doing this, the function also establishes dependencies between each merge
 * task and its downstream map task dependencies by creating "map fetch" tasks.
 */
static List *
MergeTaskList(MapMergeJob *mapMergeJob, List *mapTaskList, uint32 taskIdIndex)
{
	List *mergeTaskList = NIL;
	uint64 jobId = mapMergeJob->job.jobId;
	uint32 partitionCount = mapMergeJob->partitionCount;

	/* build column name and column type arrays (table schema) */
	Query *filterQuery = mapMergeJob->job.jobQuery;
	List *targetEntryList = filterQuery->targetList;

	/* if all map tasks were pruned away, return NIL for merge tasks */
	if (mapTaskList == NIL)
	{
		return NIL;
	}

	/*
	 * XXX: We currently ignore the 0th partition bucket that range partitioning
	 * generates. This bucket holds all values less than the minimum value or
	 * NULLs, both of which we can currently ignore. However, when we support
	 * range re-partitioned OUTER joins, we will need these rows for the
	 * relation whose rows are retained in the OUTER join.
	 */
	uint32 initialPartitionId = 0;
	if (mapMergeJob->partitionType == RANGE_PARTITION_TYPE)
	{
		initialPartitionId = 1;
		partitionCount = partitionCount + 1;
	}
	else if (mapMergeJob->partitionType == SINGLE_HASH_PARTITION_TYPE)
	{
		initialPartitionId = 0;
	}

	/* build merge tasks and their associated "map output fetch" tasks */
	for (uint32 partitionId = initialPartitionId; partitionId < partitionCount;
		 partitionId++)
	{
		Task *mergeTask = NULL;
		List *mapOutputFetchTaskList = NIL;
		ListCell *mapTaskCell = NULL;
		uint32 mergeTaskId = taskIdIndex;

		Query *reduceQuery = mapMergeJob->reduceQuery;
		if (reduceQuery == NULL)
		{
			uint32 columnCount = (uint32) list_length(targetEntryList);
			StringInfo columnNames = ColumnNameArrayString(columnCount, jobId);
			StringInfo columnTypes = ColumnTypeArrayString(targetEntryList);

			StringInfo mergeQueryString = makeStringInfo();
			appendStringInfo(mergeQueryString, MERGE_FILES_INTO_TABLE_COMMAND,
							 jobId, taskIdIndex, columnNames->data, columnTypes->data);

			/* create merge task */
			mergeTask = CreateBasicTask(jobId, mergeTaskId, MERGE_TASK,
										mergeQueryString->data);
		}
		mergeTask->partitionId = partitionId;
		taskIdIndex++;

		/* create tasks to fetch map outputs to this merge task */
		foreach(mapTaskCell, mapTaskList)
		{
			Task *mapTask = (Task *) lfirst(mapTaskCell);

			/* find the node name/port for map task's execution */
			List *mapTaskPlacementList = mapTask->taskPlacementList;

			ShardPlacement *mapTaskPlacement = linitial(mapTaskPlacementList);
			char *mapTaskNodeName = mapTaskPlacement->nodeName;
			uint32 mapTaskNodePort = mapTaskPlacement->nodePort;

			/*
			 * We will use the first node even if the replication factor is
			 * greater than one. When the replication factor is greater than
			 * one and there is a connection problem to the node that executed
			 * the map task, the fetch task execution will error out.
			 */
			StringInfo mapFetchQueryString = makeStringInfo();
			appendStringInfo(mapFetchQueryString, MAP_OUTPUT_FETCH_COMMAND,
							 mapTask->jobId, mapTask->taskId, partitionId,
							 mergeTaskId, /* fetch results to merge task */
							 mapTaskNodeName, mapTaskNodePort);

			Task *mapOutputFetchTask = CreateBasicTask(jobId, taskIdIndex,
													   MAP_OUTPUT_FETCH_TASK,
													   mapFetchQueryString->data);
			mapOutputFetchTask->partitionId = partitionId;
			mapOutputFetchTask->upstreamTaskId = mergeTaskId;
			mapOutputFetchTask->dependentTaskList = list_make1(mapTask);
			taskIdIndex++;

			mapOutputFetchTaskList = lappend(mapOutputFetchTaskList, mapOutputFetchTask);
		}

		/* merge task depends on completion of fetch tasks */
		mergeTask->dependentTaskList = mapOutputFetchTaskList;

		/* if range or single-hash repartitioned, each merge task represents an interval */
		if (mapMergeJob->partitionType == RANGE_PARTITION_TYPE)
		{
			int32 mergeTaskIntervalId = partitionId - 1;
			ShardInterval **mergeTaskIntervals = mapMergeJob->sortedShardIntervalArray;
			Assert(mergeTaskIntervalId >= 0);

			mergeTask->shardInterval = mergeTaskIntervals[mergeTaskIntervalId];
		}
		else if (mapMergeJob->partitionType == SINGLE_HASH_PARTITION_TYPE)
		{
			int32 mergeTaskIntervalId = partitionId;
			ShardInterval **mergeTaskIntervals = mapMergeJob->sortedShardIntervalArray;
			Assert(mergeTaskIntervalId >= 0);

			mergeTask->shardInterval = mergeTaskIntervals[mergeTaskIntervalId];
		}

		mergeTaskList = lappend(mergeTaskList, mergeTask);
	}

	return mergeTaskList;
}
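
/*
 * Sketch of the resulting task graph (illustrative): with two map tasks m1 and
 * m2 and a given partition p, the loop above builds
 *
 *     mergeTask(p) <- { fetch(m1, p), fetch(m2, p) }
 *
 * where each map output fetch task depends on its map task and is later placed
 * on the merge task's node (see AssignDataFetchDependencies below).
 */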


/*
 * ColumnNameArrayString creates a list of column names for a merged table, and
 * outputs this list of column names in their (array) string representation.
 */
static StringInfo
ColumnNameArrayString(uint32 columnCount, uint64 generatingJobId)
{
	Datum *columnNameArray = palloc0(columnCount * sizeof(Datum));
	uint32 columnNameIndex = 0;

	/* build list of intermediate column names, generated by given jobId */
	List *columnNameList = DerivedColumnNameList(columnCount, generatingJobId);

	ListCell *columnNameCell = NULL;
	foreach(columnNameCell, columnNameList)
	{
		Value *columnNameValue = (Value *) lfirst(columnNameCell);
		char *columnNameString = strVal(columnNameValue);
		Datum columnName = CStringGetDatum(columnNameString);

		columnNameArray[columnNameIndex] = columnName;
		columnNameIndex++;
	}

	StringInfo columnNameArrayString = DatumArrayString(columnNameArray, columnCount,
														CSTRINGOID);

	return columnNameArrayString;
}


/*
 * ColumnTypeArrayString resolves a list of column types for a merged table, and
 * outputs this list of column types in their (array) string representation.
 */
static StringInfo
ColumnTypeArrayString(List *targetEntryList)
{
	ListCell *targetEntryCell = NULL;

	uint32 columnCount = (uint32) list_length(targetEntryList);
	Datum *columnTypeArray = palloc0(columnCount * sizeof(Datum));
	uint32 columnTypeIndex = 0;

	foreach(targetEntryCell, targetEntryList)
	{
		TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
		Node *columnExpression = (Node *) targetEntry->expr;
		Oid columnTypeId = exprType(columnExpression);
		int32 columnTypeMod = exprTypmod(columnExpression);

		char *columnTypeName = format_type_with_typemod(columnTypeId, columnTypeMod);
		Datum columnType = CStringGetDatum(columnTypeName);

		columnTypeArray[columnTypeIndex] = columnType;
		columnTypeIndex++;
	}

	StringInfo columnTypeArrayString = DatumArrayString(columnTypeArray, columnCount,
														CSTRINGOID);

	return columnTypeArrayString;
}
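
/*
 * For example, with a hypothetical target list whose entries have types int4
 * and text, format_type_with_typemod() yields "integer" and "text", and the
 * resulting array string is {integer,text}.
 */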


/*
 * AssignTaskList assigns locations to given tasks based on dependencies between
 * tasks and configured task assignment policies. The function also handles the
 * case where multiple SQL tasks depend on the same merge task, and makes sure
 * that this group of multiple SQL tasks and the merge task are assigned to the
 * same location.
 */
static List *
AssignTaskList(List *sqlTaskList)
{
	List *assignedSqlTaskList = NIL;
	bool hasAnchorShardId = false;
	ListCell *sqlTaskCell = NULL;
	List *primarySqlTaskList = NIL;
	ListCell *primarySqlTaskCell = NULL;
	ListCell *constrainedSqlTaskCell = NULL;

	/* no tasks to assign */
	if (sqlTaskList == NIL)
	{
		return NIL;
	}

	Task *firstSqlTask = (Task *) linitial(sqlTaskList);
	if (firstSqlTask->anchorShardId != INVALID_SHARD_ID)
	{
		hasAnchorShardId = true;
	}

	/*
	 * If these SQL tasks don't depend on any merge tasks, we can assign each
	 * one independently of the other. We therefore go ahead and assign these
	 * SQL tasks using the "anchor shard based" assignment algorithms.
	 */
	bool hasMergeTaskDependencies = HasMergeTaskDependencies(sqlTaskList);
	if (!hasMergeTaskDependencies)
	{
		Assert(hasAnchorShardId);

		assignedSqlTaskList = AssignAnchorShardTaskList(sqlTaskList);

		return assignedSqlTaskList;
	}

	/*
	 * SQL tasks can depend on merge tasks in one of two ways: (1) each SQL task
	 * depends on merge task(s) that no other SQL task depends upon, (2) several
	 * SQL tasks depend on the same merge task(s) and all need to be assigned to
	 * the same worker node. To handle the second case, we first pick a primary
	 * SQL task among those that depend on the same merge task, and assign it.
	 */
	foreach(sqlTaskCell, sqlTaskList)
	{
		Task *sqlTask = (Task *) lfirst(sqlTaskCell);
		List *mergeTaskList = FindDependentMergeTaskList(sqlTask);

		Task *firstMergeTask = (Task *) linitial(mergeTaskList);
		if (!firstMergeTask->assignmentConstrained)
		{
			firstMergeTask->assignmentConstrained = true;

			primarySqlTaskList = lappend(primarySqlTaskList, sqlTask);
		}
	}

	if (hasAnchorShardId)
	{
		primarySqlTaskList = AssignAnchorShardTaskList(primarySqlTaskList);
	}
	else
	{
		primarySqlTaskList = AssignDualHashTaskList(primarySqlTaskList);
	}

	/* propagate SQL task assignments to the merge tasks we depend upon */
	foreach(primarySqlTaskCell, primarySqlTaskList)
	{
		Task *sqlTask = (Task *) lfirst(primarySqlTaskCell);
		List *mergeTaskList = FindDependentMergeTaskList(sqlTask);

		ListCell *mergeTaskCell = NULL;
		foreach(mergeTaskCell, mergeTaskList)
		{
			Task *mergeTask = (Task *) lfirst(mergeTaskCell);
			Assert(mergeTask->taskPlacementList == NIL);

			mergeTask->taskPlacementList = list_copy(sqlTask->taskPlacementList);
		}

		assignedSqlTaskList = lappend(assignedSqlTaskList, sqlTask);
	}

	/*
	 * If we had a set of SQL tasks depending on the same merge task, we only
	 * assigned one SQL task from that set. We call the assigned SQL task the
	 * primary, and note that the remaining SQL tasks are constrained by the
	 * primary's task assignment. We propagate the primary's task assignment in
	 * each set to the remaining (constrained) tasks.
	 */
	List *constrainedSqlTaskList = TaskListDifference(sqlTaskList, primarySqlTaskList);

	foreach(constrainedSqlTaskCell, constrainedSqlTaskList)
	{
		Task *sqlTask = (Task *) lfirst(constrainedSqlTaskCell);
		List *mergeTaskList = FindDependentMergeTaskList(sqlTask);
		List *mergeTaskPlacementList = NIL;

		ListCell *mergeTaskCell = NULL;
		foreach(mergeTaskCell, mergeTaskList)
		{
			Task *mergeTask = (Task *) lfirst(mergeTaskCell);

			/*
			 * If we have more than one merge task, all of them should have the
			 * same task placement list.
			 */
			mergeTaskPlacementList = mergeTask->taskPlacementList;
			Assert(mergeTaskPlacementList != NIL);

			ereport(DEBUG3, (errmsg("propagating assignment from merge task %d "
									"to constrained sql task %d",
									mergeTask->taskId, sqlTask->taskId)));
		}

		sqlTask->taskPlacementList = list_copy(mergeTaskPlacementList);

		assignedSqlTaskList = lappend(assignedSqlTaskList, sqlTask);
	}

	return assignedSqlTaskList;
}
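
/*
 * Example (illustrative): suppose sql tasks s1 and s2 both depend on merge
 * task m1. The first loop above marks m1 as assignmentConstrained and picks
 * s1 as the primary. s1 is then assigned a placement, m1 copies s1's
 * placement list, and finally s2 (the constrained task) copies m1's placement
 * list, so s1, s2 and m1 all end up on the same worker node.
 */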


/*
 * HasMergeTaskDependencies checks if sql tasks in the given sql task list have
 * any dependencies on merge tasks. If they do, the function returns true.
 */
static bool
HasMergeTaskDependencies(List *sqlTaskList)
{
	bool hasMergeTaskDependencies = false;
	Task *sqlTask = (Task *) linitial(sqlTaskList);
	List *dependentTaskList = sqlTask->dependentTaskList;

	ListCell *dependentTaskCell = NULL;
	foreach(dependentTaskCell, dependentTaskList)
	{
		Task *dependentTask = (Task *) lfirst(dependentTaskCell);
		if (dependentTask->taskType == MERGE_TASK)
		{
			hasMergeTaskDependencies = true;
			break;
		}
	}

	return hasMergeTaskDependencies;
}


/* Return true if two tasks are equal, false otherwise. */
bool
TasksEqual(const Task *a, const Task *b)
{
	Assert(CitusIsA(a, Task));
	Assert(CitusIsA(b, Task));

	if (a->taskType != b->taskType)
	{
		return false;
	}
	if (a->jobId != b->jobId)
	{
		return false;
	}
	if (a->taskId != b->taskId)
	{
		return false;
	}

	return true;
}


/* Return true if the given task is a member of the task list, false otherwise. */
bool
TaskListMember(const List *taskList, const Task *task)
{
	const ListCell *taskCell = NULL;

	foreach(taskCell, taskList)
	{
		if (TasksEqual((Task *) lfirst(taskCell), task))
		{
			return true;
		}
	}

	return false;
}


/*
 * TaskListDifference returns a list that contains all the tasks in taskList1
 * that are not in taskList2. The returned list is freshly allocated via
 * palloc(), but the cells themselves point to the same objects as the cells
 * of the input lists.
 */
List *
TaskListDifference(const List *list1, const List *list2)
{
	const ListCell *taskCell = NULL;
	List *resultList = NIL;

	if (list2 == NIL)
	{
		return list_copy(list1);
	}

	foreach(taskCell, list1)
	{
		if (!TaskListMember(list2, lfirst(taskCell)))
		{
			resultList = lappend(resultList, lfirst(taskCell));
		}
	}

	return resultList;
}


/*
 * AssignAnchorShardTaskList assigns locations to the given tasks based on the
 * configured task assignment policy. The distributed executor later sends these
 * tasks to their assigned locations for remote execution.
 */
List *
AssignAnchorShardTaskList(List *taskList)
{
	List *assignedTaskList = NIL;

	/* choose task assignment policy based on config value */
	if (TaskAssignmentPolicy == TASK_ASSIGNMENT_GREEDY)
	{
		assignedTaskList = GreedyAssignTaskList(taskList);
	}
	else if (TaskAssignmentPolicy == TASK_ASSIGNMENT_FIRST_REPLICA)
	{
		assignedTaskList = FirstReplicaAssignTaskList(taskList);
	}
	else if (TaskAssignmentPolicy == TASK_ASSIGNMENT_ROUND_ROBIN)
	{
		assignedTaskList = RoundRobinAssignTaskList(taskList);
	}

	Assert(assignedTaskList != NIL);
	return assignedTaskList;
}


/*
 * GreedyAssignTaskList uses a greedy algorithm similar to Hadoop's, and assigns
 * locations to the given tasks. The ideal assignment algorithm balances three
 * properties: (a) determinism, (b) even load distribution, and (c) consistency
 * across similar task lists. To maintain these properties, the algorithm sorts
 * all its input lists.
 */
static List *
GreedyAssignTaskList(List *taskList)
{
	List *assignedTaskList = NIL;
	uint32 assignedTaskCount = 0;
	uint32 taskCount = list_length(taskList);

	/* get the worker node list and sort the list */
	List *workerNodeList = ActiveReadableNodeList();
	workerNodeList = SortList(workerNodeList, CompareWorkerNodes);

	/*
	 * We first sort tasks by their anchor shard id. We then walk over each task
	 * in the sorted list, get the task's anchor shard id, and look up the shard
	 * placements (locations) for this shard id. Next, we sort the placements by
	 * their insertion time, and append them to a new list.
	 */
	taskList = SortList(taskList, CompareTasksByShardId);
	List *activeShardPlacementLists = ActiveShardPlacementLists(taskList);

	while (assignedTaskCount < taskCount)
	{
		ListCell *workerNodeCell = NULL;
		uint32 loopStartTaskCount = assignedTaskCount;

		/* walk over each node and check if we can assign a task to it */
		foreach(workerNodeCell, workerNodeList)
		{
			WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);

			Task *assignedTask = GreedyAssignTask(workerNode, taskList,
												  activeShardPlacementLists);
			if (assignedTask != NULL)
			{
				assignedTaskList = lappend(assignedTaskList, assignedTask);
				assignedTaskCount++;
			}
		}

		/* if we could not assign any new tasks, avoid looping forever */
		if (assignedTaskCount == loopStartTaskCount)
		{
			uint32 remainingTaskCount = taskCount - assignedTaskCount;
			ereport(ERROR, (errmsg("failed to assign %u task(s) to worker nodes",
								   remainingTaskCount)));
		}
	}

	return assignedTaskList;
}


/*
 * GreedyAssignTask tries to assign a task to the given worker node. To do this,
 * the function walks over tasks' anchor shard ids, and finds the first set of
 * nodes the shards were replicated to. If any of these replica nodes and the
 * given worker node match, the corresponding task is assigned to that node. If
 * not, the function goes on to search the second set of replicas and so forth.
 *
 * Note that this function has side-effects; when the function assigns a new
 * task, it overwrites the corresponding task list pointer.
 */
static Task *
GreedyAssignTask(WorkerNode *workerNode, List *taskList, List *activeShardPlacementLists)
{
	Task *assignedTask = NULL;
	List *taskPlacementList = NIL;
	ShardPlacement *primaryPlacement = NULL;
	uint32 rotatePlacementListBy = 0;
	uint32 replicaIndex = 0;
	uint32 replicaCount = ShardReplicationFactor;
	const char *workerName = workerNode->workerName;
	const uint32 workerPort = workerNode->workerPort;

	while ((assignedTask == NULL) && (replicaIndex < replicaCount))
	{
		/* walk over all tasks and try to assign one */
		ListCell *taskCell = NULL;
		ListCell *placementListCell = NULL;

		forboth(taskCell, taskList, placementListCell, activeShardPlacementLists)
		{
			Task *task = (Task *) lfirst(taskCell);
			List *placementList = (List *) lfirst(placementListCell);

			/* check if we already assigned this task */
			if (task == NULL)
			{
				continue;
			}

			/* check if we have enough replicas */
			uint32 placementCount = list_length(placementList);
			if (placementCount <= replicaIndex)
			{
				continue;
			}

			ShardPlacement *placement = (ShardPlacement *) list_nth(placementList,
																	replicaIndex);
			if ((strncmp(placement->nodeName, workerName, WORKER_LENGTH) == 0) &&
				(placement->nodePort == workerPort))
			{
				/* we found a task to assign to the given worker node */
				assignedTask = task;
				taskPlacementList = placementList;
				rotatePlacementListBy = replicaIndex;

				/* overwrite task list to signal that this task is assigned */
				SetListCellPtr(taskCell, NULL);
				break;
			}
		}

		/* go over the next set of shard replica placements */
		replicaIndex++;
	}

	/* if we found a task placement list, rotate and assign task placements */
	if (assignedTask != NULL)
	{
		taskPlacementList = LeftRotateList(taskPlacementList, rotatePlacementListBy);
		assignedTask->taskPlacementList = taskPlacementList;

		primaryPlacement = (ShardPlacement *) linitial(assignedTask->taskPlacementList);
		ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", assignedTask->taskId,
								primaryPlacement->nodeName,
								primaryPlacement->nodePort)));
	}

	return assignedTask;
}
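
/*
 * Worked example with hypothetical placements: given workers w1 and w2, a
 * task t1 with placements (w1, w2) and a task t2 with placements (w2, w1),
 * the first replica pass (replicaIndex = 0) assigns t1 to w1 and t2 to w2,
 * spreading the load evenly; only when no first-replica match exists does
 * the search fall through to the second replicas.
 */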


/*
 * FirstReplicaAssignTaskList assigns locations to the given tasks simply by
 * looking at placements for a given shard. A particular task's assignments are
 * then ordered by the insertion order of the relevant placements rows. In other
 * words, a task for a specific shard is simply assigned to the first replica
 * for that shard. This algorithm is extremely simple and intended for use when
 * a customer has placed shards carefully and wants strong guarantees about
 * which shards will be used by what nodes (i.e. for stronger memory residency
 * guarantees).
 */
List *
FirstReplicaAssignTaskList(List *taskList)
{
	/* No additional reordering needs to take place for this algorithm */
	ReorderFunction reorderFunction = NULL;

	taskList = ReorderAndAssignTaskList(taskList, reorderFunction);

	return taskList;
}


/*
 * RoundRobinAssignTaskList uses a round-robin algorithm to assign locations to
 * the given tasks. An ideal round-robin implementation requires keeping shared
 * state for task assignments; we instead approximate this by relying on the
 * sequentially increasing jobId. For each task, we mod its jobId by the number
 * of active shard placements, and ensure that we rotate between these
 * placements across subsequent queries.
 */
List *
RoundRobinAssignTaskList(List *taskList)
{
	taskList = ReorderAndAssignTaskList(taskList, RoundRobinReorder);

	return taskList;
}


/*
 * RoundRobinReorder implements the core of the round-robin assignment policy.
 * It takes a placement list and rotates a copy of it based on the latest stable
 * transaction id provided by PostgreSQL.
 *
 * We prefer to use the transactionId as the seed for the rotation so that the
 * replicas on the same worker node are reused within the same transaction. This
 * becomes more important when we're reading from (the same or multiple)
 * reference tables within a transaction. With this approach, we prevent reads
 * from expanding the set of worker nodes that participate in a distributed
 * transaction.
 *
 * Note that we prefer PostgreSQL's transactionId over the distributed
 * transactionId that Citus generates, since the distributed transactionId is
 * generated during execution whereas task assignment happens during planning.
 */
List *
RoundRobinReorder(List *placementList)
{
	TransactionId transactionId = GetMyProcLocalTransactionId();
	uint32 activePlacementCount = list_length(placementList);
	uint32 roundRobinIndex = (transactionId % activePlacementCount);

	placementList = LeftRotateList(placementList, roundRobinIndex);

	return placementList;
}
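
/*
 * Worked example with hypothetical values: with three active placements
 * (p0, p1, p2) and a local transaction id of 7, roundRobinIndex is
 * 7 % 3 = 1 and the rotated list becomes (p1, p2, p0), making the second
 * replica the primary placement for this transaction.
 */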


/*
 * ReorderAndAssignTaskList finds the placements for a task based on its anchor
 * shard id and then sorts them by insertion time. If reorderFunction is given,
 * it is used to reorder the placements list in a custom fashion (for instance,
 * by rotation or shuffling). Returns the task list with placements assigned.
 */
static List *
ReorderAndAssignTaskList(List *taskList, ReorderFunction reorderFunction)
{
	List *assignedTaskList = NIL;
	ListCell *taskCell = NULL;
	ListCell *placementListCell = NULL;
	uint32 unAssignedTaskCount = 0;

	if (taskList == NIL)
	{
		return NIL;
	}

	/*
	 * We first sort tasks by their anchor shard id. We then sort placements for
	 * each anchor shard by the placement's insertion time. Note that we sort
	 * these lists just to make our policy more deterministic.
	 */
	taskList = SortList(taskList, CompareTasksByShardId);
	List *activeShardPlacementLists = ActiveShardPlacementLists(taskList);

	forboth(taskCell, taskList, placementListCell, activeShardPlacementLists)
	{
		Task *task = (Task *) lfirst(taskCell);
		List *placementList = (List *) lfirst(placementListCell);

		/* inactive placements are already filtered out */
		uint32 activePlacementCount = list_length(placementList);
		if (activePlacementCount > 0)
		{
			if (reorderFunction != NULL)
			{
				placementList = reorderFunction(placementList);
			}
			task->taskPlacementList = placementList;

			ShardPlacement *primaryPlacement = (ShardPlacement *) linitial(
				task->taskPlacementList);
			ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", task->taskId,
									primaryPlacement->nodeName,
									primaryPlacement->nodePort)));

			assignedTaskList = lappend(assignedTaskList, task);
		}
		else
		{
			unAssignedTaskCount++;
		}
	}

	/* if we have unassigned tasks, error out */
	if (unAssignedTaskCount > 0)
	{
		ereport(ERROR, (errmsg("failed to assign %u task(s) to worker nodes",
							   unAssignedTaskCount)));
	}

	return assignedTaskList;
}


/* Helper function to compare two tasks by their anchor shardId. */
static int
CompareTasksByShardId(const void *leftElement, const void *rightElement)
{
	const Task *leftTask = *((const Task **) leftElement);
	const Task *rightTask = *((const Task **) rightElement);

	uint64 leftShardId = leftTask->anchorShardId;
	uint64 rightShardId = rightTask->anchorShardId;

	/* we compare 64-bit integers, instead of casting their difference to int */
	if (leftShardId > rightShardId)
	{
		return 1;
	}
	else if (leftShardId < rightShardId)
	{
		return -1;
	}
	else
	{
		return 0;
	}
}


/*
 * ActiveShardPlacementLists finds the active shard placement list for each task in
 * the given task list, sorts each shard placement list by shard creation time,
 * and adds the sorted placement list into a new list of lists. The function also
 * ensures a one-to-one mapping between each placement list in the new list of
 * lists and each task in the given task list.
 */
static List *
ActiveShardPlacementLists(List *taskList)
{
	List *shardPlacementLists = NIL;
	ListCell *taskCell = NULL;

	foreach(taskCell, taskList)
	{
		Task *task = (Task *) lfirst(taskCell);
		uint64 anchorShardId = task->anchorShardId;
		List *shardPlacementList = ActiveShardPlacementList(anchorShardId);

		/* filter out shard placements that reside in inactive nodes */
		List *activeShardPlacementList = ActivePlacementList(shardPlacementList);
		if (activeShardPlacementList == NIL)
		{
			ereport(ERROR,
					(errmsg("no active placements were found for shard " UINT64_FORMAT,
							anchorShardId)));
		}

		/* sort shard placements by their creation time */
		activeShardPlacementList = SortList(activeShardPlacementList,
											CompareShardPlacements);
		shardPlacementLists = lappend(shardPlacementLists, activeShardPlacementList);
	}

	return shardPlacementLists;
}


/*
 * CompareShardPlacements compares two shard placements by their tuple oid; this
 * oid reflects the tuple's insertion order into pg_dist_placement.
 */
int
CompareShardPlacements(const void *leftElement, const void *rightElement)
{
	const ShardPlacement *leftPlacement = *((const ShardPlacement **) leftElement);
	const ShardPlacement *rightPlacement = *((const ShardPlacement **) rightElement);

	uint64 leftPlacementId = leftPlacement->placementId;
	uint64 rightPlacementId = rightPlacement->placementId;

	if (leftPlacementId < rightPlacementId)
	{
		return -1;
	}
	else if (leftPlacementId > rightPlacementId)
	{
		return 1;
	}
	else
	{
		return 0;
	}
}


/*
 * ActivePlacementList walks over shard placements in the given list, and finds
 * the corresponding worker node for each placement. The function then checks if
 * that worker node is active, and if it is, appends the placement to a new list.
 * Finally, the function returns the new placement list.
 */
static List *
ActivePlacementList(List *placementList)
{
	List *activePlacementList = NIL;
	ListCell *placementCell = NULL;

	foreach(placementCell, placementList)
	{
		ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell);

		/* check if the worker node for this shard placement is active */
		WorkerNode *workerNode = FindWorkerNode(placement->nodeName, placement->nodePort);
		if (workerNode != NULL && workerNode->isActive)
		{
			activePlacementList = lappend(activePlacementList, placement);
		}
	}

	return activePlacementList;
}


/*
 * LeftRotateList returns a copy of the given list that has been cyclically
 * shifted to the left by the given rotation count. For this, the function
 * repeatedly moves the list's first element to the end of the list, and
 * then returns the newly rotated list.
 */
static List *
LeftRotateList(List *list, uint32 rotateCount)
{
	List *rotatedList = list_copy(list);

	for (uint32 rotateIndex = 0; rotateIndex < rotateCount; rotateIndex++)
	{
		void *firstElement = linitial(rotatedList);

		rotatedList = list_delete_first(rotatedList);
		rotatedList = lappend(rotatedList, firstElement);
	}

	return rotatedList;
}
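
/*
 * For example, rotating the list (a, b, c, d) by 2 yields (c, d, a, b);
 * rotating by 0 or by the list's full length returns the elements in their
 * original order.
 */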


/*
 * FindDependentMergeTaskList walks over the given task's dependent task list,
 * finds the merge tasks in the list, and returns those found tasks in a new
 * list.
 */
static List *
FindDependentMergeTaskList(Task *sqlTask)
{
	List *dependentMergeTaskList = NIL;
	List *dependentTaskList = sqlTask->dependentTaskList;

	ListCell *dependentTaskCell = NULL;
	foreach(dependentTaskCell, dependentTaskList)
	{
		Task *dependentTask = (Task *) lfirst(dependentTaskCell);
		if (dependentTask->taskType == MERGE_TASK)
		{
			dependentMergeTaskList = lappend(dependentMergeTaskList, dependentTask);
		}
	}

	return dependentMergeTaskList;
}


/*
 * AssignDualHashTaskList uses a round-robin algorithm to assign locations to
 * tasks; these tasks don't have any anchor shards and instead operate on (hash
 * repartitioned) merged tables.
 */
static List *
AssignDualHashTaskList(List *taskList)
{
	List *assignedTaskList = NIL;
	ListCell *taskCell = NULL;
	Task *firstTask = (Task *) linitial(taskList);
	uint64 jobId = firstTask->jobId;
	uint32 assignedTaskIndex = 0;

	/*
	 * We start assigning tasks at an index determined by the jobId. This way,
	 * if subsequent jobs have a small number of tasks, we won't allocate the
	 * tasks to the same worker repeatedly.
	 */
	List *workerNodeList = ActiveReadableNodeList();
	uint32 workerNodeCount = (uint32) list_length(workerNodeList);
	uint32 beginningNodeIndex = jobId % workerNodeCount;

	/* sort worker node list and task list for deterministic results */
	workerNodeList = SortList(workerNodeList, CompareWorkerNodes);
	taskList = SortList(taskList, CompareTasksByTaskId);

	foreach(taskCell, taskList)
	{
		Task *task = (Task *) lfirst(taskCell);
		List *taskPlacementList = NIL;

		for (uint32 replicaIndex = 0; replicaIndex < ShardReplicationFactor;
			 replicaIndex++)
		{
			uint32 assignmentOffset = beginningNodeIndex + assignedTaskIndex +
									  replicaIndex;
			uint32 assignmentIndex = assignmentOffset % workerNodeCount;
			WorkerNode *workerNode = list_nth(workerNodeList, assignmentIndex);

			ShardPlacement *taskPlacement = CitusMakeNode(ShardPlacement);
			SetPlacementNodeMetadata(taskPlacement, workerNode);

			taskPlacementList = lappend(taskPlacementList, taskPlacement);
		}

		task->taskPlacementList = taskPlacementList;

		ShardPlacement *primaryPlacement = (ShardPlacement *) linitial(
			task->taskPlacementList);
		ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", task->taskId,
								primaryPlacement->nodeName,
								primaryPlacement->nodePort)));

		assignedTaskList = lappend(assignedTaskList, task);
		assignedTaskIndex++;
	}

	return assignedTaskList;
}
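
/*
 * Worked example with hypothetical values: with jobId = 7 and 3 worker nodes,
 * beginningNodeIndex is 7 % 3 = 1. Assuming ShardReplicationFactor = 2, the
 * first task gets placements on nodes (1 + 0 + 0) % 3 = 1 and
 * (1 + 0 + 1) % 3 = 2, the second task on nodes 2 and 0, and so on, rotating
 * through the sorted worker list.
 */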


/*
 * SetPlacementNodeMetadata sets nodename, nodeport, nodeid and groupid for the placement.
 */
void
SetPlacementNodeMetadata(ShardPlacement *placement, WorkerNode *workerNode)
{
	placement->nodeName = pstrdup(workerNode->workerName);
	placement->nodePort = workerNode->workerPort;
	placement->nodeId = workerNode->nodeId;
	placement->groupId = workerNode->groupId;
}


/*
 * CompareTasksByTaskId is a helper function to compare two tasks by their taskId.
 */
int
CompareTasksByTaskId(const void *leftElement, const void *rightElement)
{
	const Task *leftTask = *((const Task **) leftElement);
	const Task *rightTask = *((const Task **) rightElement);

	uint32 leftTaskId = leftTask->taskId;
	uint32 rightTaskId = rightTask->taskId;

	int taskIdDiff = leftTaskId - rightTaskId;
	return taskIdDiff;
}


/*
 * AssignDataFetchDependencies walks over tasks in the given sql or merge task
 * list. The function then propagates worker node assignments from each sql or
 * merge task to the task's data fetch dependencies.
 */
static void
AssignDataFetchDependencies(List *taskList)
{
	ListCell *taskCell = NULL;
	foreach(taskCell, taskList)
	{
		Task *task = (Task *) lfirst(taskCell);
		List *dependentTaskList = task->dependentTaskList;
		ListCell *dependentTaskCell = NULL;

		Assert(task->taskPlacementList != NIL);
		Assert(task->taskType == READ_TASK || task->taskType == MERGE_TASK);

		foreach(dependentTaskCell, dependentTaskList)
		{
			Task *dependentTask = (Task *) lfirst(dependentTaskCell);
			if (dependentTask->taskType == MAP_OUTPUT_FETCH_TASK)
			{
				dependentTask->taskPlacementList = task->taskPlacementList;
			}
		}
	}
}


/*
 * TaskListHighestTaskId walks over tasks in the given task list, finds the task
 * that has the largest taskId, and returns that taskId.
 *
 * Note: This function assumes that the dependent tasks' taskIds are assigned
 * before the taskIds for the given task list.
 */
static uint32
TaskListHighestTaskId(List *taskList)
{
	uint32 highestTaskId = 0;
	ListCell *taskCell = NULL;

	foreach(taskCell, taskList)
	{
		Task *task = (Task *) lfirst(taskCell);
		if (task->taskId > highestTaskId)
		{
			highestTaskId = task->taskId;
		}
	}

	return highestTaskId;
}