1 /*-------------------------------------------------------------------------
2 *
3 * multi_physical_planner.c
4 * Routines for creating physical plans from given multi-relational algebra
5 * trees.
6 *
7 * Copyright (c) Citus Data, Inc.
8 *
9 * $Id$
10 *
11 *-------------------------------------------------------------------------
12 */
13
14 #include "postgres.h"
15
16 #include "distributed/pg_version_constants.h"
17
18 #include <math.h>
19 #include <stdint.h>
20
21 #include "miscadmin.h"
22
23 #include "access/genam.h"
24 #include "access/hash.h"
25 #include "access/heapam.h"
26 #include "access/nbtree.h"
27 #include "access/skey.h"
28 #include "access/xlog.h"
29 #include "catalog/pg_aggregate.h"
30 #include "catalog/pg_am.h"
31 #include "catalog/pg_operator.h"
32 #include "catalog/pg_type.h"
33 #include "commands/defrem.h"
34 #include "commands/sequence.h"
35 #include "distributed/backend_data.h"
36 #include "distributed/listutils.h"
37 #include "distributed/citus_nodefuncs.h"
38 #include "distributed/citus_nodes.h"
39 #include "distributed/citus_ruleutils.h"
40 #include "distributed/colocation_utils.h"
41 #include "distributed/deparse_shard_query.h"
42 #include "distributed/coordinator_protocol.h"
43 #include "distributed/metadata_cache.h"
44 #include "distributed/multi_router_planner.h"
45 #include "distributed/multi_join_order.h"
46 #include "distributed/multi_logical_optimizer.h"
47 #include "distributed/multi_logical_planner.h"
48 #include "distributed/multi_partitioning_utils.h"
49 #include "distributed/multi_physical_planner.h"
50 #include "distributed/log_utils.h"
51 #include "distributed/pg_dist_partition.h"
52 #include "distributed/pg_dist_shard.h"
53 #include "distributed/query_pushdown_planning.h"
54 #include "distributed/query_utils.h"
55 #include "distributed/shardinterval_utils.h"
56 #include "distributed/shard_pruning.h"
57 #include "distributed/string_utils.h"
58
59 #include "distributed/worker_manager.h"
60 #include "distributed/worker_protocol.h"
61 #include "distributed/version_compat.h"
62 #include "nodes/makefuncs.h"
63 #include "nodes/nodeFuncs.h"
64 #include "optimizer/clauses.h"
65 #include "nodes/pathnodes.h"
66 #include "optimizer/optimizer.h"
67 #include "optimizer/restrictinfo.h"
68 #include "optimizer/tlist.h"
69 #include "parser/parse_relation.h"
70 #include "parser/parsetree.h"
71 #include "rewrite/rewriteManip.h"
72 #include "utils/builtins.h"
73 #include "utils/catcache.h"
74 #include "utils/datum.h"
75 #include "utils/fmgroids.h"
76 #include "utils/guc.h"
77 #include "utils/lsyscache.h"
78 #include "utils/memutils.h"
79 #include "utils/rel.h"
80 #include "utils/typcache.h"
81
/*
 * RepartitionJoinBucketCountPerNode determines the bucket amount during
 * repartitions. Defaults to 8. NOTE(review): not read within this part of
 * the file; presumably consumed by the partition-count logic elsewhere.
 */
int RepartitionJoinBucketCountPerNode = 8;

/* Policy to use when assigning tasks to worker nodes */
int TaskAssignmentPolicy = TASK_ASSIGNMENT_GREEDY;

/*
 * NOTE(review): EnableUniqueJobIds is not read in this part of the file;
 * presumably it controls whether generated job ids are made unique — TODO
 * confirm against its users. Defaults to true.
 */
bool EnableUniqueJobIds = true;


/*
 * OperatorCache is used for caching operator identifiers for given typeId,
 * accessMethodId and strategyNumber. It is initialized to empty list as
 * there are no items in the cache.
 *
 * NOTE(review): presumably populated by LookupOperatorByType(), which takes
 * the same (typeId, accessMethodId, strategyNumber) key — TODO confirm.
 */
static List *OperatorCache = NIL;
96
97
/*
 * AddAnyValueAggregatesContext is the context passed down in the
 * AddAnyValueAggregates mutator.
 */
typedef struct AddAnyValueAggregatesContext
{
	/* SortGroupClauses corresponding to the GROUP BY clause */
	List *groupClauseList;

	/* TargetEntry's to which the GROUP BY clauses refer */
	List *groupByTargetEntryList;

	/*
	 * haveNonVarGrouping is true if there are expressions in the
	 * GROUP BY target entries. We use this as an optimisation to
	 * skip expensive checks when possible.
	 */
	bool haveNonVarGrouping;
} AddAnyValueAggregatesContext;
114
115
116 /* Local functions forward declarations for job creation */
117 static Job * BuildJobTree(MultiTreeRoot *multiTree);
118 static MultiNode * LeftMostNode(MultiTreeRoot *multiTree);
119 static Oid RangePartitionJoinBaseRelationId(MultiJoin *joinNode);
120 static MultiTable * FindTableNode(MultiNode *multiNode, int rangeTableId);
121 static Query * BuildJobQuery(MultiNode *multiNode, List *dependentJobList);
122 static List * BaseRangeTableList(MultiNode *multiNode);
123 static List * QueryTargetList(MultiNode *multiNode);
124 static List * TargetEntryList(List *expressionList);
125 static Node * AddAnyValueAggregates(Node *node, AddAnyValueAggregatesContext *context);
126 static List * QueryGroupClauseList(MultiNode *multiNode);
127 static List * QuerySelectClauseList(MultiNode *multiNode);
128 static List * QueryFromList(List *rangeTableList);
129 static Node * QueryJoinTree(MultiNode *multiNode, List *dependentJobList,
130 List **rangeTableList);
131 static void SetJoinRelatedColumnsCompat(RangeTblEntry *rangeTableEntry,
132 Oid leftRelId,
133 Oid rightRelId,
134 List *leftColumnVars,
135 List *rightColumnVars);
136 static RangeTblEntry * JoinRangeTableEntry(JoinExpr *joinExpr, List *dependentJobList,
137 List *rangeTableList);
138 static int ExtractRangeTableId(Node *node);
139 static void ExtractColumns(RangeTblEntry *callingRTE, int rangeTableId,
140 List **columnNames, List **columnVars);
141 static RangeTblEntry * ConstructCallingRTE(RangeTblEntry *rangeTableEntry,
142 List *dependentJobList);
143 static Query * BuildSubqueryJobQuery(MultiNode *multiNode);
144 static void UpdateAllColumnAttributes(Node *columnContainer, List *rangeTableList,
145 List *dependentJobList);
146 static void UpdateColumnAttributes(Var *column, List *rangeTableList,
147 List *dependentJobList);
148 static Index NewTableId(Index originalTableId, List *rangeTableList);
149 static AttrNumber NewColumnId(Index originalTableId, AttrNumber originalColumnId,
150 RangeTblEntry *newRangeTableEntry, List *dependentJobList);
151 static Job * JobForRangeTable(List *jobList, RangeTblEntry *rangeTableEntry);
152 static Job * JobForTableIdList(List *jobList, List *searchedTableIdList);
153 static List * ChildNodeList(MultiNode *multiNode);
154 static Job * BuildJob(Query *jobQuery, List *dependentJobList);
155 static MapMergeJob * BuildMapMergeJob(Query *jobQuery, List *dependentJobList,
156 Var *partitionKey, PartitionType partitionType,
157 Oid baseRelationId,
158 BoundaryNodeJobType boundaryNodeJobType);
159 static uint32 HashPartitionCount(void);
160 static ArrayType * SplitPointObject(ShardInterval **shardIntervalArray,
161 uint32 shardIntervalCount);
162
163 /* Local functions forward declarations for task list creation and helper functions */
164 static Job * BuildJobTreeTaskList(Job *jobTree,
165 PlannerRestrictionContext *plannerRestrictionContext);
166 static bool IsInnerTableOfOuterJoin(RelationRestriction *relationRestriction);
167 static void ErrorIfUnsupportedShardDistribution(Query *query);
168 static Task * QueryPushdownTaskCreate(Query *originalQuery, int shardIndex,
169 RelationRestrictionContext *restrictionContext,
170 uint32 taskId,
171 TaskType taskType,
172 bool modifyRequiresCoordinatorEvaluation,
173 DeferredErrorMessage **planningError);
174 static bool ShardIntervalsEqual(FmgrInfo *comparisonFunction,
175 Oid collation,
176 ShardInterval *firstInterval,
177 ShardInterval *secondInterval);
178 static List * SqlTaskList(Job *job);
179 static bool DependsOnHashPartitionJob(Job *job);
180 static uint32 AnchorRangeTableId(List *rangeTableList);
181 static List * BaseRangeTableIdList(List *rangeTableList);
182 static List * AnchorRangeTableIdList(List *rangeTableList, List *baseRangeTableIdList);
183 static void AdjustColumnOldAttributes(List *expressionList);
184 static List * RangeTableFragmentsList(List *rangeTableList, List *whereClauseList,
185 List *dependentJobList);
186 static OperatorCacheEntry * LookupOperatorByType(Oid typeId, Oid accessMethodId,
187 int16 strategyNumber);
188 static Oid GetOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber);
189 static List * FragmentCombinationList(List *rangeTableFragmentsList, Query *jobQuery,
190 List *dependentJobList);
191 static JoinSequenceNode * JoinSequenceArray(List *rangeTableFragmentsList,
192 Query *jobQuery, List *dependentJobList);
193 static bool PartitionedOnColumn(Var *column, List *rangeTableList,
194 List *dependentJobList);
195 static void CheckJoinBetweenColumns(OpExpr *joinClause);
196 static List * FindRangeTableFragmentsList(List *rangeTableFragmentsList, int taskId);
197 static bool JoinPrunable(RangeTableFragment *leftFragment,
198 RangeTableFragment *rightFragment);
199 static ShardInterval * FragmentInterval(RangeTableFragment *fragment);
200 static StringInfo FragmentIntervalString(ShardInterval *fragmentInterval);
201 static List * DataFetchTaskList(uint64 jobId, uint32 taskIdIndex, List *fragmentList);
202 static StringInfo DatumArrayString(Datum *datumArray, uint32 datumCount, Oid datumTypeId);
203 static List * BuildRelationShardList(List *rangeTableList, List *fragmentList);
204 static void UpdateRangeTableAlias(List *rangeTableList, List *fragmentList);
205 static Alias * FragmentAlias(RangeTblEntry *rangeTableEntry,
206 RangeTableFragment *fragment);
207 static uint64 AnchorShardId(List *fragmentList, uint32 anchorRangeTableId);
208 static List * PruneSqlTaskDependencies(List *sqlTaskList);
209 static List * AssignTaskList(List *sqlTaskList);
210 static bool HasMergeTaskDependencies(List *sqlTaskList);
211 static List * GreedyAssignTaskList(List *taskList);
212 static Task * GreedyAssignTask(WorkerNode *workerNode, List *taskList,
213 List *activeShardPlacementLists);
214 static List * ReorderAndAssignTaskList(List *taskList,
215 ReorderFunction reorderFunction);
216 static int CompareTasksByShardId(const void *leftElement, const void *rightElement);
217 static List * ActiveShardPlacementLists(List *taskList);
218 static List * ActivePlacementList(List *placementList);
219 static List * LeftRotateList(List *list, uint32 rotateCount);
220 static List * FindDependentMergeTaskList(Task *sqlTask);
221 static List * AssignDualHashTaskList(List *taskList);
222 static void AssignDataFetchDependencies(List *taskList);
223 static uint32 TaskListHighestTaskId(List *taskList);
224 static List * MapTaskList(MapMergeJob *mapMergeJob, List *filterTaskList);
225 static StringInfo CreateMapQueryString(MapMergeJob *mapMergeJob, Task *filterTask,
226 uint32 partitionColumnIndex);
227 static List * MergeTaskList(MapMergeJob *mapMergeJob, List *mapTaskList,
228 uint32 taskIdIndex);
229 static StringInfo ColumnNameArrayString(uint32 columnCount, uint64 generatingJobId);
230 static StringInfo ColumnTypeArrayString(List *targetEntryList);
231 static bool CoPlacedShardIntervals(ShardInterval *firstInterval,
232 ShardInterval *secondInterval);
233
234 static List * FetchEqualityAttrNumsForRTEOpExpr(OpExpr *opExpr);
235 static List * FetchEqualityAttrNumsForRTEBoolExpr(BoolExpr *boolExpr);
236 static List * FetchEqualityAttrNumsForList(List *nodeList);
237 static int PartitionColumnIndex(Var *targetVar, List *targetList);
238 #if PG_VERSION_NUM >= PG_VERSION_13
239 static List * GetColumnOriginalIndexes(Oid relationId);
240 #endif
241
242
243 /*
244 * CreatePhysicalDistributedPlan is the entry point for physical plan generation. The
245 * function builds the physical plan; this plan includes the list of tasks to be
246 * executed on worker nodes, and the final query to run on the master node.
247 */
248 DistributedPlan *
CreatePhysicalDistributedPlan(MultiTreeRoot * multiTree,PlannerRestrictionContext * plannerRestrictionContext)249 CreatePhysicalDistributedPlan(MultiTreeRoot *multiTree,
250 PlannerRestrictionContext *plannerRestrictionContext)
251 {
252 /* build the worker job tree and check that we only have one job in the tree */
253 Job *workerJob = BuildJobTree(multiTree);
254
255 /* create the tree of executable tasks for the worker job */
256 workerJob = BuildJobTreeTaskList(workerJob, plannerRestrictionContext);
257
258 /* build the final merge query to execute on the master */
259 List *masterDependentJobList = list_make1(workerJob);
260 Query *combineQuery = BuildJobQuery((MultiNode *) multiTree, masterDependentJobList);
261
262 DistributedPlan *distributedPlan = CitusMakeNode(DistributedPlan);
263 distributedPlan->workerJob = workerJob;
264 distributedPlan->combineQuery = combineQuery;
265 distributedPlan->modLevel = ROW_MODIFY_READONLY;
266 distributedPlan->expectResults = true;
267
268 return distributedPlan;
269 }
270
271
272 /*
273 * ModifyLocalTableJob returns true if the given task contains
274 * a modification of local table.
275 */
276 bool
ModifyLocalTableJob(Job * job)277 ModifyLocalTableJob(Job *job)
278 {
279 if (job == NULL)
280 {
281 return false;
282 }
283 List *taskList = job->taskList;
284 if (list_length(taskList) != 1)
285 {
286 return false;
287 }
288 Task *singleTask = (Task *) linitial(taskList);
289 return singleTask->isLocalTableModification;
290 }
291
292
/*
 * BuildJobTree builds the physical job tree from the given logical plan tree.
 * The function walks over the logical plan from the bottom up, finds boundaries
 * for jobs, and creates the query structure for each job. The function also
 * sets dependencies between jobs, and then returns the top level worker job.
 *
 * The walk starts at the left-most (deepest) node and moves to the parent on
 * every iteration; it stops once the current node has no parent, so the tree
 * root itself is never examined as "currentNode". Boundaries acted upon:
 *
 *   - a MultiJoin using a single hash, single range or dual partition join
 *     rule (JOIN_MAP_MERGE_JOB): a MapMergeJob is built for each child of the
 *     join that is a MultiPartition node;
 *   - a MultiCollect whose parent is not a MultiPartition
 *     (TOP_LEVEL_WORKER_JOB): the top level worker job is built, depending on
 *     the map-merge jobs collected so far.
 *
 * Returns NULL if no TOP_LEVEL_WORKER_JOB boundary is reached. If several are
 * reached, the job built last is returned.
 */
static Job *
BuildJobTree(MultiTreeRoot *multiTree)
{
	/* start building the tree from the deepest left node */
	MultiNode *leftMostNode = LeftMostNode(multiTree);
	MultiNode *currentNode = leftMostNode;
	MultiNode *parentNode = ParentNode(currentNode);

	/* jobs built so far that the next job up the tree will depend on */
	List *loopDependentJobList = NIL;
	Job *topLevelJob = NULL;

	while (parentNode != NULL)
	{
		CitusNodeTag currentNodeType = CitusNodeTag(currentNode);
		CitusNodeTag parentNodeType = CitusNodeTag(parentNode);
		BoundaryNodeJobType boundaryNodeJobType = JOB_INVALID_FIRST;

		/* we first check if this node forms the boundary for a remote job */
		if (currentNodeType == T_MultiJoin)
		{
			MultiJoin *joinNode = (MultiJoin *) currentNode;
			if (joinNode->joinRuleType == SINGLE_HASH_PARTITION_JOIN ||
				joinNode->joinRuleType == SINGLE_RANGE_PARTITION_JOIN ||
				joinNode->joinRuleType == DUAL_PARTITION_JOIN)
			{
				boundaryNodeJobType = JOIN_MAP_MERGE_JOB;
			}
		}
		else if (currentNodeType == T_MultiPartition &&
				 parentNodeType == T_MultiExtendedOp)
		{
			/*
			 * NOTE(review): no branch below acts on SUBQUERY_MAP_MERGE_JOB,
			 * so such a boundary currently builds no job in this function.
			 */
			boundaryNodeJobType = SUBQUERY_MAP_MERGE_JOB;
		}
		else if (currentNodeType == T_MultiCollect &&
				 parentNodeType != T_MultiPartition)
		{
			boundaryNodeJobType = TOP_LEVEL_WORKER_JOB;
		}

		/*
		 * If this node is at the boundary for a repartition or top level worker
		 * job, we build the corresponding job(s) and set their dependencies.
		 */
		if (boundaryNodeJobType == JOIN_MAP_MERGE_JOB)
		{
			MultiJoin *joinNode = (MultiJoin *) currentNode;
			MultiNode *leftChildNode = joinNode->binaryNode.leftChildNode;
			MultiNode *rightChildNode = joinNode->binaryNode.rightChildNode;

			PartitionType partitionType = PARTITION_INVALID_FIRST;

			/* stays InvalidOid for dual partition joins */
			Oid baseRelationId = InvalidOid;

			if (joinNode->joinRuleType == SINGLE_RANGE_PARTITION_JOIN)
			{
				partitionType = RANGE_PARTITION_TYPE;
				baseRelationId = RangePartitionJoinBaseRelationId(joinNode);
			}
			else if (joinNode->joinRuleType == SINGLE_HASH_PARTITION_JOIN)
			{
				partitionType = SINGLE_HASH_PARTITION_TYPE;
				baseRelationId = RangePartitionJoinBaseRelationId(joinNode);
			}
			else if (joinNode->joinRuleType == DUAL_PARTITION_JOIN)
			{
				partitionType = DUAL_HASH_PARTITION_TYPE;
			}

			if (CitusIsA(leftChildNode, MultiPartition))
			{
				MultiPartition *partitionNode = (MultiPartition *) leftChildNode;
				MultiNode *queryNode = GrandChildNode((MultiUnaryNode *) partitionNode);
				Var *partitionKey = partitionNode->partitionColumn;

				/*
				 * build query and partition job; the left side consumes every
				 * job built so far
				 */
				List *dependentJobList = list_copy(loopDependentJobList);
				Query *jobQuery = BuildJobQuery(queryNode, dependentJobList);

				MapMergeJob *mapMergeJob = BuildMapMergeJob(jobQuery, dependentJobList,
															partitionKey, partitionType,
															baseRelationId,
															JOIN_MAP_MERGE_JOB);

				/*
				 * reset dependent job list: from here on only the new
				 * map-merge job is a dependency (the NIL store is redundant
				 * given the assignment right after it)
				 */
				loopDependentJobList = NIL;
				loopDependentJobList = list_make1(mapMergeJob);
			}

			if (CitusIsA(rightChildNode, MultiPartition))
			{
				MultiPartition *partitionNode = (MultiPartition *) rightChildNode;
				MultiNode *queryNode = GrandChildNode((MultiUnaryNode *) partitionNode);
				Var *partitionKey = partitionNode->partitionColumn;

				/*
				 * The right query and right partition job do not depend on any
				 * jobs since our logical plan tree is left deep.
				 */
				Query *jobQuery = BuildJobQuery(queryNode, NIL);
				MapMergeJob *mapMergeJob = BuildMapMergeJob(jobQuery, NIL,
															partitionKey, partitionType,
															baseRelationId,
															JOIN_MAP_MERGE_JOB);

				/* append to the dependent job list for on-going dependencies */
				loopDependentJobList = lappend(loopDependentJobList, mapMergeJob);
			}
		}
		else if (boundaryNodeJobType == TOP_LEVEL_WORKER_JOB)
		{
			MultiNode *childNode = ChildNode((MultiUnaryNode *) currentNode);
			List *dependentJobList = list_copy(loopDependentJobList);
			bool subqueryPushdown = false;

			List *subqueryMultiTableList = SubqueryMultiTableList(childNode);
			int subqueryCount = list_length(subqueryMultiTableList);

			if (subqueryCount > 0)
			{
				subqueryPushdown = true;
			}

			/*
			 * Build top level query. If subquery pushdown is set, we use
			 * sligthly different version of BuildJobQuery(). They are similar
			 * but we don't need some parts of BuildJobQuery() for subquery
			 * pushdown such as updating column attributes etc.
			 */
			if (subqueryPushdown)
			{
				Query *topLevelQuery = BuildSubqueryJobQuery(childNode);

				topLevelJob = BuildJob(topLevelQuery, dependentJobList);
				topLevelJob->subqueryPushdown = true;
			}
			else
			{
				Query *topLevelQuery = BuildJobQuery(childNode, dependentJobList);

				topLevelJob = BuildJob(topLevelQuery, dependentJobList);
			}
		}

		/* walk up the tree */
		currentNode = parentNode;
		parentNode = ParentNode(currentNode);
	}

	return topLevelJob;
}
447
448
449 /*
450 * LeftMostNode finds the deepest left node in the left-deep logical plan tree.
451 * We build the physical plan by traversing the logical plan from the bottom up;
452 * and this function helps us find the bottom of the logical tree.
453 */
454 static MultiNode *
LeftMostNode(MultiTreeRoot * multiTree)455 LeftMostNode(MultiTreeRoot *multiTree)
456 {
457 MultiNode *currentNode = (MultiNode *) multiTree;
458 MultiNode *leftChildNode = ChildNode((MultiUnaryNode *) multiTree);
459
460 while (leftChildNode != NULL)
461 {
462 currentNode = leftChildNode;
463
464 if (UnaryOperator(currentNode))
465 {
466 leftChildNode = ChildNode((MultiUnaryNode *) currentNode);
467 }
468 else if (BinaryOperator(currentNode))
469 {
470 MultiBinaryNode *binaryNode = (MultiBinaryNode *) currentNode;
471 leftChildNode = binaryNode->leftChildNode;
472 }
473 }
474
475 return currentNode;
476 }
477
478
479 /*
480 * RangePartitionJoinBaseRelationId finds partition node from join node, and
481 * returns base relation id of this node. Note that this function assumes that
482 * given join node is range partition join type.
483 */
484 static Oid
RangePartitionJoinBaseRelationId(MultiJoin * joinNode)485 RangePartitionJoinBaseRelationId(MultiJoin *joinNode)
486 {
487 MultiPartition *partitionNode = NULL;
488
489 MultiNode *leftChildNode = joinNode->binaryNode.leftChildNode;
490 MultiNode *rightChildNode = joinNode->binaryNode.rightChildNode;
491
492 if (CitusIsA(leftChildNode, MultiPartition))
493 {
494 partitionNode = (MultiPartition *) leftChildNode;
495 }
496 else if (CitusIsA(rightChildNode, MultiPartition))
497 {
498 partitionNode = (MultiPartition *) rightChildNode;
499 }
500
501 Index baseTableId = partitionNode->splitPointTableId;
502 MultiTable *baseTable = FindTableNode((MultiNode *) joinNode, baseTableId);
503 Oid baseRelationId = baseTable->relationId;
504
505 return baseRelationId;
506 }
507
508
509 /*
510 * FindTableNode walks over the given logical plan tree, and returns the table
511 * node that corresponds to the given range tableId.
512 */
513 static MultiTable *
FindTableNode(MultiNode * multiNode,int rangeTableId)514 FindTableNode(MultiNode *multiNode, int rangeTableId)
515 {
516 MultiTable *foundTableNode = NULL;
517 List *tableNodeList = FindNodesOfType(multiNode, T_MultiTable);
518 ListCell *tableNodeCell = NULL;
519
520 foreach(tableNodeCell, tableNodeList)
521 {
522 MultiTable *tableNode = (MultiTable *) lfirst(tableNodeCell);
523 if (tableNode->rangeTableId == rangeTableId)
524 {
525 foundTableNode = tableNode;
526 break;
527 }
528 }
529
530 Assert(foundTableNode != NULL);
531 return foundTableNode;
532 }
533
534
/*
 * BuildJobQuery traverses the given logical plan tree, determines the job that
 * corresponds to this part of the tree, and builds the query structure for that
 * particular job. The function assumes that jobs this particular job depends on
 * have already been built, as their output is needed to build the query.
 *
 * multiNode is the root of the (sub)tree to turn into a query and must not be
 * a MultiCollect node. dependentJobList holds the already-built jobs whose
 * output this query reads; it is passed to QueryJoinTree() and, for worker
 * queries, to UpdateAllColumnAttributes().
 *
 * Returns a newly allocated SELECT Query. If a MultiExtendedOp node is found
 * under multiNode, its target list is deep-copied, while its sort, DISTINCT,
 * window, limit and HAVING fields are referenced as-is (not copied).
 */
static Query *
BuildJobQuery(MultiNode *multiNode, List *dependentJobList)
{
	bool updateColumnAttributes = false;
	List *targetList = NIL;
	List *sortClauseList = NIL;
	Node *limitCount = NULL;
	Node *limitOffset = NULL;
#if PG_VERSION_NUM >= PG_VERSION_13
	LimitOption limitOption = LIMIT_OPTION_DEFAULT;
#endif
	Node *havingQual = NULL;
	bool hasDistinctOn = false;
	List *distinctClause = NIL;
	bool isRepartitionJoin = false;
	bool hasWindowFuncs = false;
	List *windowClause = NIL;

	/* we start building jobs from below the collect node */
	Assert(!CitusIsA(multiNode, MultiCollect));

	/*
	 * First check if we are building a master/worker query. If we are building
	 * a worker query, we update the column attributes for target entries, select
	 * and join columns. Because if underlying query includes repartition joins,
	 * then we create multiple queries from a join. In this case, range table lists
	 * and column lists are subject to change.
	 *
	 * Note that we don't do this for master queries, as column attributes for
	 * master target entries are already set during the master/worker split.
	 */
	MultiNode *parentNode = ParentNode(multiNode);
	if (parentNode != NULL)
	{
		updateColumnAttributes = true;
	}

	/*
	 * If we are building this query on a repartitioned subquery job then we
	 * don't need to update column attributes. Only the first dependent job is
	 * inspected; if it is a MapMergeJob the query is treated as reading from
	 * a repartition, whether or not it has a reduce query.
	 */
	if (dependentJobList != NIL)
	{
		Job *job = (Job *) linitial(dependentJobList);
		if (CitusIsA(job, MapMergeJob))
		{
			MapMergeJob *mapMergeJob = (MapMergeJob *) job;
			isRepartitionJoin = true;
			if (mapMergeJob->reduceQuery)
			{
				updateColumnAttributes = false;
			}
		}
	}

	/*
	 * If we have an extended operator, then we copy the operator's target list.
	 * Otherwise, we use the target list based on the MultiProject node at this
	 * level in the query tree.
	 */
	List *extendedOpNodeList = FindNodesOfType(multiNode, T_MultiExtendedOp);
	if (extendedOpNodeList != NIL)
	{
		MultiExtendedOp *extendedOp = (MultiExtendedOp *) linitial(extendedOpNodeList);

		/* copied because UpdateAllColumnAttributes() may modify it below */
		targetList = copyObject(extendedOp->targetList);
		distinctClause = extendedOp->distinctClause;
		hasDistinctOn = extendedOp->hasDistinctOn;
		hasWindowFuncs = extendedOp->hasWindowFuncs;
		windowClause = extendedOp->windowClause;
	}
	else
	{
		targetList = QueryTargetList(multiNode);
	}

	/*
	 * build the join tree and the range table list; QueryJoinTree() receives
	 * the range table list by reference and may extend it
	 */
	List *rangeTableList = BaseRangeTableList(multiNode);
	Node *joinRoot = QueryJoinTree(multiNode, dependentJobList, &rangeTableList);

	/* update the column attributes for target entries */
	if (updateColumnAttributes)
	{
		UpdateAllColumnAttributes((Node *) targetList, rangeTableList, dependentJobList);
	}

	/* extract limit count/offset and sort clauses */
	if (extendedOpNodeList != NIL)
	{
		MultiExtendedOp *extendedOp = (MultiExtendedOp *) linitial(extendedOpNodeList);

		limitCount = extendedOp->limitCount;
		limitOffset = extendedOp->limitOffset;
#if PG_VERSION_NUM >= PG_VERSION_13
		limitOption = extendedOp->limitOption;
#endif
		sortClauseList = extendedOp->sortClauseList;
		havingQual = extendedOp->havingQual;
	}

	/* build group clauses */
	List *groupClauseList = QueryGroupClauseList(multiNode);


	/* build the where clause list using select predicates */
	List *selectClauseList = QuerySelectClauseList(multiNode);

	/* set correct column attributes for select and having clauses */
	if (updateColumnAttributes)
	{
		UpdateAllColumnAttributes((Node *) selectClauseList, rangeTableList,
								  dependentJobList);
		UpdateAllColumnAttributes(havingQual, rangeTableList, dependentJobList);
	}

	/*
	 * Group by on primary key allows all columns to appear in the target
	 * list, but after re-partitioning we will be querying an intermediate
	 * table that does not have the primary key. We therefore wrap all the
	 * columns that do not appear in the GROUP BY in an any_value aggregate.
	 */
	if (groupClauseList != NIL && isRepartitionJoin)
	{
		targetList = (List *) WrapUngroupedVarsInAnyValueAggregate(
			(Node *) targetList, groupClauseList, targetList, true);

		/* note: uses the already-wrapped target list from the line above */
		havingQual = WrapUngroupedVarsInAnyValueAggregate(
			(Node *) havingQual, groupClauseList, targetList, false);
	}

	/*
	 * Build the From/Where construct. We keep the where-clause list implicitly
	 * AND'd, since both partition and join pruning depends on the clauses being
	 * expressed as a list.
	 */
	FromExpr *joinTree = makeNode(FromExpr);
	joinTree->quals = (Node *) list_copy(selectClauseList);
	joinTree->fromlist = list_make1(joinRoot);

	/* build the query structure for this job */
	Query *jobQuery = makeNode(Query);
	jobQuery->commandType = CMD_SELECT;
	jobQuery->querySource = QSRC_ORIGINAL;
	jobQuery->canSetTag = true;
	jobQuery->rtable = rangeTableList;
	jobQuery->targetList = targetList;
	jobQuery->jointree = joinTree;
	jobQuery->sortClause = sortClauseList;
	jobQuery->groupClause = groupClauseList;
	jobQuery->limitOffset = limitOffset;
	jobQuery->limitCount = limitCount;
#if PG_VERSION_NUM >= PG_VERSION_13
	jobQuery->limitOption = limitOption;
#endif
	jobQuery->havingQual = havingQual;

	/* recomputed because any_value wrapping above may have added aggregates */
	jobQuery->hasAggs = contain_aggs_of_level((Node *) targetList, 0) ||
						contain_aggs_of_level((Node *) havingQual, 0);
	jobQuery->distinctClause = distinctClause;
	jobQuery->hasDistinctOn = hasDistinctOn;
	jobQuery->windowClause = windowClause;
	jobQuery->hasWindowFuncs = hasWindowFuncs;
	jobQuery->hasSubLinks = checkExprHasSubLink((Node *) jobQuery);

	Assert(jobQuery->hasWindowFuncs == contain_window_function((Node *) jobQuery));

	return jobQuery;
}
707
708
709 /*
710 * BaseRangeTableList returns the list of range table entries for base tables in
711 * the query. These base tables stand in contrast to derived tables generated by
712 * repartition jobs. Note that this function only considers base tables relevant
713 * to the current query, and does not visit nodes under the collect node.
714 */
715 static List *
BaseRangeTableList(MultiNode * multiNode)716 BaseRangeTableList(MultiNode *multiNode)
717 {
718 List *baseRangeTableList = NIL;
719 List *pendingNodeList = list_make1(multiNode);
720
721 while (pendingNodeList != NIL)
722 {
723 MultiNode *currMultiNode = (MultiNode *) linitial(pendingNodeList);
724 CitusNodeTag nodeType = CitusNodeTag(currMultiNode);
725 pendingNodeList = list_delete_first(pendingNodeList);
726
727 if (nodeType == T_MultiTable)
728 {
729 /*
730 * We represent subqueries as MultiTables, and so for base table
731 * entries we skip the subquery ones.
732 */
733 MultiTable *multiTable = (MultiTable *) currMultiNode;
734 if (multiTable->relationId != SUBQUERY_RELATION_ID &&
735 multiTable->relationId != SUBQUERY_PUSHDOWN_RELATION_ID)
736 {
737 RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
738 rangeTableEntry->inFromCl = true;
739 rangeTableEntry->eref = multiTable->referenceNames;
740 rangeTableEntry->alias = multiTable->alias;
741 rangeTableEntry->relid = multiTable->relationId;
742 rangeTableEntry->inh = multiTable->includePartitions;
743
744 SetRangeTblExtraData(rangeTableEntry, CITUS_RTE_RELATION, NULL, NULL,
745 list_make1_int(multiTable->rangeTableId),
746 NIL, NIL, NIL, NIL);
747
748 baseRangeTableList = lappend(baseRangeTableList, rangeTableEntry);
749 }
750 }
751
752 /* do not visit nodes that belong to remote queries */
753 if (nodeType != T_MultiCollect)
754 {
755 List *childNodeList = ChildNodeList(currMultiNode);
756 pendingNodeList = list_concat(pendingNodeList, childNodeList);
757 }
758 }
759
760 return baseRangeTableList;
761 }
762
763
764 /*
765 * DerivedRangeTableEntry builds a range table entry for the derived table. This
766 * derived table either represents the output of a repartition job; or the data
767 * on worker nodes in case of the master node query.
768 */
769 RangeTblEntry *
DerivedRangeTableEntry(MultiNode * multiNode,List * columnList,List * tableIdList,List * funcColumnNames,List * funcColumnTypes,List * funcColumnTypeMods,List * funcCollations)770 DerivedRangeTableEntry(MultiNode *multiNode, List *columnList, List *tableIdList,
771 List *funcColumnNames, List *funcColumnTypes,
772 List *funcColumnTypeMods, List *funcCollations)
773 {
774 RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
775 rangeTableEntry->inFromCl = true;
776 rangeTableEntry->eref = makeNode(Alias);
777 rangeTableEntry->eref->colnames = columnList;
778
779 SetRangeTblExtraData(rangeTableEntry, CITUS_RTE_REMOTE_QUERY, NULL, NULL, tableIdList,
780 funcColumnNames, funcColumnTypes, funcColumnTypeMods,
781 funcCollations);
782
783 return rangeTableEntry;
784 }
785
786
787 /*
788 * DerivedColumnNameList builds a column name list for derived (intermediate)
789 * tables. These column names are then used when building the create stament
790 * query string for derived tables.
791 */
792 List *
DerivedColumnNameList(uint32 columnCount,uint64 generatingJobId)793 DerivedColumnNameList(uint32 columnCount, uint64 generatingJobId)
794 {
795 List *columnNameList = NIL;
796
797 for (uint32 columnIndex = 0; columnIndex < columnCount; columnIndex++)
798 {
799 StringInfo columnName = makeStringInfo();
800
801 appendStringInfo(columnName, "intermediate_column_");
802 appendStringInfo(columnName, UINT64_FORMAT "_", generatingJobId);
803 appendStringInfo(columnName, "%u", columnIndex);
804
805 Value *columnValue = makeString(columnName->data);
806 columnNameList = lappend(columnNameList, columnValue);
807 }
808
809 return columnNameList;
810 }
811
812
813 /*
814 * QueryTargetList returns the target entry list for the projected columns
815 * needed to evaluate the operators above the given multiNode. To do this,
816 * the function retrieves a list of all MultiProject nodes below the given
817 * node and picks the columns from the top-most MultiProject node, as this
818 * will be the minimal list of columns needed. Note that this function relies
819 * on a pre-order traversal of the operator tree by the function FindNodesOfType.
820 */
821 static List *
QueryTargetList(MultiNode * multiNode)822 QueryTargetList(MultiNode *multiNode)
823 {
824 List *projectNodeList = FindNodesOfType(multiNode, T_MultiProject);
825 if (list_length(projectNodeList) == 0)
826 {
827 /*
828 * The physical planner assumes that all worker queries would have
829 * target list entries based on the fact that at least the column
830 * on the JOINs have to be on the target list. However, there is
831 * an exception to that if there is a cartesian product join and
832 * there is no additional target list entries belong to one side
833 * of the JOIN. Once we support cartesian product join, we should
834 * remove this error.
835 */
836 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
837 errmsg("cannot perform distributed planning on this query"),
838 errdetail("Cartesian products are currently unsupported")));
839 }
840
841 MultiProject *topProjectNode = (MultiProject *) linitial(projectNodeList);
842 List *columnList = topProjectNode->columnList;
843 List *queryTargetList = TargetEntryList(columnList);
844
845 Assert(queryTargetList != NIL);
846 return queryTargetList;
847 }
848
849
850 /*
851 * TargetEntryList creates a target entry for each expression in the given list,
852 * and returns the newly created target entries in a list.
853 */
854 static List *
TargetEntryList(List * expressionList)855 TargetEntryList(List *expressionList)
856 {
857 List *targetEntryList = NIL;
858 ListCell *expressionCell = NULL;
859
860 foreach(expressionCell, expressionList)
861 {
862 Expr *expression = (Expr *) lfirst(expressionCell);
863
864 TargetEntry *targetEntry = makeTargetEntry(expression,
865 list_length(targetEntryList) + 1,
866 NULL, false);
867 targetEntryList = lappend(targetEntryList, targetEntry);
868 }
869
870 return targetEntryList;
871 }
872
873
874 /*
875 * WrapUngroupedVarsInAnyValueAggregate finds Var nodes in the expression
876 * that do not refer to any GROUP BY column and wraps them in an any_value
877 * aggregate. These columns are allowed when the GROUP BY is on a primary
878 * key of a relation, but not if we wrap the relation in a subquery.
879 * However, since we still know the value is unique, any_value gives the
880 * right result.
881 */
882 Node *
WrapUngroupedVarsInAnyValueAggregate(Node * expression,List * groupClauseList,List * targetList,bool checkExpressionEquality)883 WrapUngroupedVarsInAnyValueAggregate(Node *expression, List *groupClauseList,
884 List *targetList, bool checkExpressionEquality)
885 {
886 if (expression == NULL)
887 {
888 return NULL;
889 }
890
891 AddAnyValueAggregatesContext context;
892 context.groupClauseList = groupClauseList;
893 context.groupByTargetEntryList = GroupTargetEntryList(groupClauseList, targetList);
894 context.haveNonVarGrouping = false;
895
896 if (checkExpressionEquality)
897 {
898 /*
899 * If the GROUP BY contains non-Var expressions, we need to do an expensive
900 * subexpression equality check.
901 */
902 TargetEntry *targetEntry = NULL;
903 foreach_ptr(targetEntry, context.groupByTargetEntryList)
904 {
905 if (!IsA(targetEntry->expr, Var))
906 {
907 context.haveNonVarGrouping = true;
908 break;
909 }
910 }
911 }
912
913 /* put the result in the same memory context */
914 MemoryContext nodeContext = GetMemoryChunkContext(expression);
915 MemoryContext oldContext = MemoryContextSwitchTo(nodeContext);
916
917 Node *result = expression_tree_mutator(expression, AddAnyValueAggregates,
918 &context);
919
920 MemoryContextSwitchTo(oldContext);
921
922 return result;
923 }
924
925
926 /*
927 * AddAnyValueAggregates wraps all vars that do not appear in the GROUP BY
928 * clause or are inside an aggregate function in an any_value aggregate
929 * function. This is needed because postgres allows columns that are not
930 * in the GROUP BY to appear on the target list as long as the primary key
931 * of the table is in the GROUP BY, but we sometimes wrap the join tree
932 * in a subquery in which case the primary key information is lost.
933 *
934 * This function copies parts of the node tree, but may contain references
935 * to the original node tree.
936 *
937 * The implementation is derived from / inspired by
938 * check_ungrouped_columns_walker.
939 */
940 static Node *
AddAnyValueAggregates(Node * node,AddAnyValueAggregatesContext * context)941 AddAnyValueAggregates(Node *node, AddAnyValueAggregatesContext *context)
942 {
943 if (node == NULL)
944 {
945 return node;
946 }
947
948 if (IsA(node, Aggref) || IsA(node, GroupingFunc))
949 {
950 /* any column is allowed to appear in an aggregate or grouping */
951 return node;
952 }
953 else if (IsA(node, Var))
954 {
955 Var *var = (Var *) node;
956
957 /*
958 * Check whether this Var appears in the GROUP BY.
959 */
960 TargetEntry *groupByTargetEntry = NULL;
961 foreach_ptr(groupByTargetEntry, context->groupByTargetEntryList)
962 {
963 if (!IsA(groupByTargetEntry->expr, Var))
964 {
965 continue;
966 }
967
968 Var *groupByVar = (Var *) groupByTargetEntry->expr;
969
970 /* we should only be doing this at the top level of the query */
971 Assert(groupByVar->varlevelsup == 0);
972
973 if (var->varno == groupByVar->varno &&
974 var->varattno == groupByVar->varattno)
975 {
976 /* this Var is in the GROUP BY, do not wrap it */
977 return node;
978 }
979 }
980
981 /*
982 * We have found a Var that does not appear in the GROUP BY.
983 * Wrap it in an any_value aggregate.
984 */
985 Aggref *agg = makeNode(Aggref);
986 agg->aggfnoid = CitusAnyValueFunctionId();
987 agg->aggtype = var->vartype;
988 agg->args = list_make1(makeTargetEntry((Expr *) var, 1, NULL, false));
989 agg->aggkind = AGGKIND_NORMAL;
990 agg->aggtranstype = InvalidOid;
991 agg->aggargtypes = list_make1_oid(var->vartype);
992 agg->aggsplit = AGGSPLIT_SIMPLE;
993 agg->aggcollid = exprCollation((Node *) var);
994 return (Node *) agg;
995 }
996 else if (context->haveNonVarGrouping)
997 {
998 /*
999 * The GROUP BY contains at least one expression. Check whether the
1000 * current expression is equal to one of the GROUP BY expressions.
1001 * Otherwise, continue to descend into subexpressions.
1002 */
1003 TargetEntry *groupByTargetEntry = NULL;
1004 foreach_ptr(groupByTargetEntry, context->groupByTargetEntryList)
1005 {
1006 if (equal(node, groupByTargetEntry->expr))
1007 {
1008 /* do not descend into mutator, all Vars are safe */
1009 return node;
1010 }
1011 }
1012 }
1013
1014 return expression_tree_mutator(node, AddAnyValueAggregates, context);
1015 }
1016
1017
1018 /*
1019 * QueryGroupClauseList extracts the group clause list from the logical plan. If
1020 * no grouping clauses exist, the function returns an empty list.
1021 */
1022 static List *
QueryGroupClauseList(MultiNode * multiNode)1023 QueryGroupClauseList(MultiNode *multiNode)
1024 {
1025 List *groupClauseList = NIL;
1026 List *pendingNodeList = list_make1(multiNode);
1027
1028 while (pendingNodeList != NIL)
1029 {
1030 MultiNode *currMultiNode = (MultiNode *) linitial(pendingNodeList);
1031 CitusNodeTag nodeType = CitusNodeTag(currMultiNode);
1032 pendingNodeList = list_delete_first(pendingNodeList);
1033
1034 /* extract the group clause list from the extended operator */
1035 if (nodeType == T_MultiExtendedOp)
1036 {
1037 MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) currMultiNode;
1038 groupClauseList = extendedOpNode->groupClauseList;
1039 }
1040
1041 /* add children only if this node isn't a multi collect and multi table */
1042 if (nodeType != T_MultiCollect && nodeType != T_MultiTable)
1043 {
1044 List *childNodeList = ChildNodeList(currMultiNode);
1045 pendingNodeList = list_concat(pendingNodeList, childNodeList);
1046 }
1047 }
1048
1049 return groupClauseList;
1050 }
1051
1052
1053 /*
1054 * QuerySelectClauseList traverses the given logical plan tree, and extracts all
1055 * select clauses from the select nodes. Note that this function does not walk
1056 * below a collect node; the clauses below the collect node apply to a remote
1057 * query, and they would have been captured by the remote job we depend upon.
1058 */
1059 static List *
QuerySelectClauseList(MultiNode * multiNode)1060 QuerySelectClauseList(MultiNode *multiNode)
1061 {
1062 List *selectClauseList = NIL;
1063 List *pendingNodeList = list_make1(multiNode);
1064
1065 while (pendingNodeList != NIL)
1066 {
1067 MultiNode *currMultiNode = (MultiNode *) linitial(pendingNodeList);
1068 CitusNodeTag nodeType = CitusNodeTag(currMultiNode);
1069 pendingNodeList = list_delete_first(pendingNodeList);
1070
1071 /* extract select clauses from the multi select node */
1072 if (nodeType == T_MultiSelect)
1073 {
1074 MultiSelect *selectNode = (MultiSelect *) currMultiNode;
1075 List *clauseList = copyObject(selectNode->selectClauseList);
1076 selectClauseList = list_concat(selectClauseList, clauseList);
1077 }
1078
1079 /* add children only if this node isn't a multi collect */
1080 if (nodeType != T_MultiCollect)
1081 {
1082 List *childNodeList = ChildNodeList(currMultiNode);
1083 pendingNodeList = list_concat(pendingNodeList, childNodeList);
1084 }
1085 }
1086
1087 return selectClauseList;
1088 }
1089
1090
1091 /*
1092 * Create a tree of JoinExpr and RangeTblRef nodes for the job query from
1093 * a given multiNode. If the tree contains MultiCollect or MultiJoin nodes,
1094 * add corresponding entries to the range table list. We need to construct
1095 * the entries at the same time as the tree to know the appropriate rtindex.
1096 */
1097 static Node *
QueryJoinTree(MultiNode * multiNode,List * dependentJobList,List ** rangeTableList)1098 QueryJoinTree(MultiNode *multiNode, List *dependentJobList, List **rangeTableList)
1099 {
1100 CitusNodeTag nodeType = CitusNodeTag(multiNode);
1101
1102 switch (nodeType)
1103 {
1104 case T_MultiJoin:
1105 {
1106 MultiJoin *joinNode = (MultiJoin *) multiNode;
1107 MultiBinaryNode *binaryNode = (MultiBinaryNode *) multiNode;
1108 ListCell *columnCell = NULL;
1109 JoinExpr *joinExpr = makeNode(JoinExpr);
1110 joinExpr->jointype = joinNode->joinType;
1111 joinExpr->isNatural = false;
1112 joinExpr->larg = QueryJoinTree(binaryNode->leftChildNode, dependentJobList,
1113 rangeTableList);
1114 joinExpr->rarg = QueryJoinTree(binaryNode->rightChildNode, dependentJobList,
1115 rangeTableList);
1116 joinExpr->usingClause = NIL;
1117 joinExpr->alias = NULL;
1118 joinExpr->rtindex = list_length(*rangeTableList) + 1;
1119
1120 /*
1121 * PostgreSQL's optimizer may mark left joins as anti-joins, when there
1122 * is a right-hand-join-key-is-null restriction, but there is no logic
1123 * in ruleutils to deparse anti-joins, so we cannot construct a task
1124 * query containing anti-joins. We therefore translate anti-joins back
1125 * into left-joins. At some point, we may also want to use different
1126 * join pruning logic for anti-joins.
1127 *
1128 * This approach would not work for anti-joins introduced via NOT EXISTS
1129 * sublinks, but currently such queries are prevented by error checks in
1130 * the logical planner.
1131 */
1132 if (joinExpr->jointype == JOIN_ANTI)
1133 {
1134 joinExpr->jointype = JOIN_LEFT;
1135 }
1136
1137 /* fix the column attributes in ON (...) clauses */
1138 List *columnList = pull_var_clause_default((Node *) joinNode->joinClauseList);
1139 foreach(columnCell, columnList)
1140 {
1141 Var *column = (Var *) lfirst(columnCell);
1142 UpdateColumnAttributes(column, *rangeTableList, dependentJobList);
1143
1144 /* adjust our column old attributes for partition pruning to work */
1145 column->varnosyn = column->varno;
1146 column->varattnosyn = column->varattno;
1147 }
1148
1149 /* make AND clauses explicit after fixing them */
1150 joinExpr->quals = (Node *) make_ands_explicit(joinNode->joinClauseList);
1151
1152 RangeTblEntry *rangeTableEntry = JoinRangeTableEntry(joinExpr,
1153 dependentJobList,
1154 *rangeTableList);
1155 *rangeTableList = lappend(*rangeTableList, rangeTableEntry);
1156
1157 return (Node *) joinExpr;
1158 }
1159
1160 case T_MultiTable:
1161 {
1162 MultiTable *rangeTableNode = (MultiTable *) multiNode;
1163 MultiUnaryNode *unaryNode = (MultiUnaryNode *) multiNode;
1164
1165 if (unaryNode->childNode != NULL)
1166 {
1167 /* MultiTable is actually a subquery, return the query tree below */
1168 Node *childNode = QueryJoinTree(unaryNode->childNode, dependentJobList,
1169 rangeTableList);
1170
1171 return childNode;
1172 }
1173 else
1174 {
1175 RangeTblRef *rangeTableRef = makeNode(RangeTblRef);
1176 uint32 rangeTableId = rangeTableNode->rangeTableId;
1177 rangeTableRef->rtindex = NewTableId(rangeTableId, *rangeTableList);
1178
1179 return (Node *) rangeTableRef;
1180 }
1181 }
1182
1183 case T_MultiCollect:
1184 {
1185 List *tableIdList = OutputTableIdList(multiNode);
1186 Job *dependentJob = JobForTableIdList(dependentJobList, tableIdList);
1187 List *dependentTargetList = dependentJob->jobQuery->targetList;
1188
1189 /* compute column names for the derived table */
1190 uint32 columnCount = (uint32) list_length(dependentTargetList);
1191 List *columnNameList = DerivedColumnNameList(columnCount,
1192 dependentJob->jobId);
1193
1194 List *funcColumnNames = NIL;
1195 List *funcColumnTypes = NIL;
1196 List *funcColumnTypeMods = NIL;
1197 List *funcCollations = NIL;
1198
1199 TargetEntry *targetEntry = NULL;
1200 foreach_ptr(targetEntry, dependentTargetList)
1201 {
1202 Node *expr = (Node *) targetEntry->expr;
1203
1204 char *name = targetEntry->resname;
1205 if (name == NULL)
1206 {
1207 name = pstrdup("unnamed");
1208 }
1209
1210 funcColumnNames = lappend(funcColumnNames, makeString(name));
1211
1212 funcColumnTypes = lappend_oid(funcColumnTypes, exprType(expr));
1213 funcColumnTypeMods = lappend_int(funcColumnTypeMods, exprTypmod(expr));
1214 funcCollations = lappend_oid(funcCollations, exprCollation(expr));
1215 }
1216
1217 RangeTblEntry *rangeTableEntry = DerivedRangeTableEntry(multiNode,
1218 columnNameList,
1219 tableIdList,
1220 funcColumnNames,
1221 funcColumnTypes,
1222 funcColumnTypeMods,
1223 funcCollations);
1224
1225 RangeTblRef *rangeTableRef = makeNode(RangeTblRef);
1226
1227 rangeTableRef->rtindex = list_length(*rangeTableList) + 1;
1228 *rangeTableList = lappend(*rangeTableList, rangeTableEntry);
1229
1230 return (Node *) rangeTableRef;
1231 }
1232
1233 case T_MultiCartesianProduct:
1234 {
1235 MultiBinaryNode *binaryNode = (MultiBinaryNode *) multiNode;
1236
1237 JoinExpr *joinExpr = makeNode(JoinExpr);
1238 joinExpr->jointype = JOIN_INNER;
1239 joinExpr->isNatural = false;
1240 joinExpr->larg = QueryJoinTree(binaryNode->leftChildNode, dependentJobList,
1241 rangeTableList);
1242 joinExpr->rarg = QueryJoinTree(binaryNode->rightChildNode, dependentJobList,
1243 rangeTableList);
1244 joinExpr->usingClause = NIL;
1245 joinExpr->alias = NULL;
1246 joinExpr->quals = NULL;
1247 joinExpr->rtindex = list_length(*rangeTableList) + 1;
1248
1249 RangeTblEntry *rangeTableEntry = JoinRangeTableEntry(joinExpr,
1250 dependentJobList,
1251 *rangeTableList);
1252 *rangeTableList = lappend(*rangeTableList, rangeTableEntry);
1253
1254 return (Node *) joinExpr;
1255 }
1256
1257 case T_MultiTreeRoot:
1258 case T_MultiSelect:
1259 case T_MultiProject:
1260 case T_MultiExtendedOp:
1261 case T_MultiPartition:
1262 {
1263 MultiUnaryNode *unaryNode = (MultiUnaryNode *) multiNode;
1264
1265 Assert(UnaryOperator(multiNode));
1266
1267 Node *childNode = QueryJoinTree(unaryNode->childNode, dependentJobList,
1268 rangeTableList);
1269
1270 return childNode;
1271 }
1272
1273 default:
1274 {
1275 ereport(ERROR, (errmsg("unrecognized multi-node type: %d", nodeType)));
1276 }
1277 }
1278 }
1279
1280
1281 /*
1282 * JoinRangeTableEntry builds a range table entry for a fully initialized JoinExpr node.
1283 * The column names and vars are determined using expandRTE, analogous to
1284 * transformFromClauseItem.
1285 */
1286 static RangeTblEntry *
JoinRangeTableEntry(JoinExpr * joinExpr,List * dependentJobList,List * rangeTableList)1287 JoinRangeTableEntry(JoinExpr *joinExpr, List *dependentJobList, List *rangeTableList)
1288 {
1289 RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
1290 List *leftColumnNames = NIL;
1291 List *leftColumnVars = NIL;
1292 List *joinedColumnNames = NIL;
1293 List *joinedColumnVars = NIL;
1294 int leftRangeTableId = ExtractRangeTableId(joinExpr->larg);
1295 RangeTblEntry *leftRTE = rt_fetch(leftRangeTableId, rangeTableList);
1296 List *rightColumnNames = NIL;
1297 List *rightColumnVars = NIL;
1298 int rightRangeTableId = ExtractRangeTableId(joinExpr->rarg);
1299 RangeTblEntry *rightRTE = rt_fetch(rightRangeTableId, rangeTableList);
1300
1301 rangeTableEntry->rtekind = RTE_JOIN;
1302 rangeTableEntry->relid = InvalidOid;
1303 rangeTableEntry->inFromCl = true;
1304 rangeTableEntry->alias = joinExpr->alias;
1305 rangeTableEntry->jointype = joinExpr->jointype;
1306 rangeTableEntry->subquery = NULL;
1307 rangeTableEntry->eref = makeAlias("unnamed_join", NIL);
1308
1309 RangeTblEntry *leftCallingRTE = ConstructCallingRTE(leftRTE, dependentJobList);
1310 RangeTblEntry *rightCallingRte = ConstructCallingRTE(rightRTE, dependentJobList);
1311 ExtractColumns(leftCallingRTE, leftRangeTableId,
1312 &leftColumnNames, &leftColumnVars);
1313 ExtractColumns(rightCallingRte, rightRangeTableId,
1314 &rightColumnNames, &rightColumnVars);
1315 Oid leftRelId = leftCallingRTE->relid;
1316 Oid rightRelId = rightCallingRte->relid;
1317 joinedColumnNames = list_concat(joinedColumnNames, leftColumnNames);
1318 joinedColumnNames = list_concat(joinedColumnNames, rightColumnNames);
1319 joinedColumnVars = list_concat(joinedColumnVars, leftColumnVars);
1320 joinedColumnVars = list_concat(joinedColumnVars, rightColumnVars);
1321
1322 rangeTableEntry->eref->colnames = joinedColumnNames;
1323 rangeTableEntry->joinaliasvars = joinedColumnVars;
1324
1325 SetJoinRelatedColumnsCompat(rangeTableEntry, leftRelId, rightRelId, leftColumnVars,
1326 rightColumnVars);
1327
1328 return rangeTableEntry;
1329 }
1330
1331
1332 /*
1333 * SetJoinRelatedColumnsCompat sets join related fields on the given range table entry.
1334 * Currently it sets joinleftcols/joinrightcols which are introduced with postgres 13.
1335 * For more info see postgres commit: 9ce77d75c5ab094637cc4a446296dc3be6e3c221
1336 */
1337 static void
SetJoinRelatedColumnsCompat(RangeTblEntry * rangeTableEntry,Oid leftRelId,Oid rightRelId,List * leftColumnVars,List * rightColumnVars)1338 SetJoinRelatedColumnsCompat(RangeTblEntry *rangeTableEntry, Oid leftRelId, Oid rightRelId,
1339 List *leftColumnVars, List *rightColumnVars)
1340 {
1341 #if PG_VERSION_NUM >= PG_VERSION_13
1342
1343 /* We don't have any merged columns so set it to 0 */
1344 rangeTableEntry->joinmergedcols = 0;
1345
1346 if (OidIsValid(leftRelId))
1347 {
1348 rangeTableEntry->joinleftcols = GetColumnOriginalIndexes(leftRelId);
1349 }
1350 else
1351 {
1352 int leftColsSize = list_length(leftColumnVars);
1353 rangeTableEntry->joinleftcols = GeneratePositiveIntSequenceList(leftColsSize);
1354 }
1355
1356 if (OidIsValid(rightRelId))
1357 {
1358 rangeTableEntry->joinrightcols = GetColumnOriginalIndexes(rightRelId);
1359 }
1360 else
1361 {
1362 int rightColsSize = list_length(rightColumnVars);
1363 rangeTableEntry->joinrightcols = GeneratePositiveIntSequenceList(rightColsSize);
1364 }
1365
1366 #endif
1367 }
1368
1369
1370 #if PG_VERSION_NUM >= PG_VERSION_13
1371
1372 /*
1373 * GetColumnOriginalIndexes gets the original indexes of columns by taking column drops into account.
1374 */
1375 static List *
GetColumnOriginalIndexes(Oid relationId)1376 GetColumnOriginalIndexes(Oid relationId)
1377 {
1378 List *originalIndexes = NIL;
1379 Relation relation = table_open(relationId, AccessShareLock);
1380 TupleDesc tupleDescriptor = RelationGetDescr(relation);
1381 for (int columnIndex = 0; columnIndex < tupleDescriptor->natts; columnIndex++)
1382 {
1383 Form_pg_attribute currentColumn = TupleDescAttr(tupleDescriptor, columnIndex);
1384 if (currentColumn->attisdropped)
1385 {
1386 continue;
1387 }
1388 originalIndexes = lappend_int(originalIndexes, columnIndex + 1);
1389 }
1390 table_close(relation, NoLock);
1391 return originalIndexes;
1392 }
1393
1394
1395 #endif
1396
1397 /*
1398 * ExtractRangeTableId gets the range table id from a node that could
1399 * either be a JoinExpr or RangeTblRef.
1400 */
1401 static int
ExtractRangeTableId(Node * node)1402 ExtractRangeTableId(Node *node)
1403 {
1404 int rangeTableId = 0;
1405
1406 if (IsA(node, JoinExpr))
1407 {
1408 JoinExpr *joinExpr = (JoinExpr *) node;
1409 rangeTableId = joinExpr->rtindex;
1410 }
1411 else if (IsA(node, RangeTblRef))
1412 {
1413 RangeTblRef *rangeTableRef = (RangeTblRef *) node;
1414 rangeTableId = rangeTableRef->rtindex;
1415 }
1416
1417 Assert(rangeTableId > 0);
1418
1419 return rangeTableId;
1420 }
1421
1422
1423 /*
1424 * ExtractColumns gets a list of column names and vars for a given range
1425 * table entry using expandRTE.
1426 */
1427 static void
ExtractColumns(RangeTblEntry * callingRTE,int rangeTableId,List ** columnNames,List ** columnVars)1428 ExtractColumns(RangeTblEntry *callingRTE, int rangeTableId,
1429 List **columnNames, List **columnVars)
1430 {
1431 int subLevelsUp = 0;
1432 int location = -1;
1433 bool includeDroppedColumns = false;
1434 expandRTE(callingRTE, rangeTableId, subLevelsUp, location, includeDroppedColumns,
1435 columnNames, columnVars);
1436 }
1437
1438
1439 /*
1440 * ConstructCallingRTE constructs a calling RTE from the given range table entry and
1441 * dependentJobList in case of repartition joins. Since the range table entries in a job
1442 * query are mocked RTE_FUNCTION entries, this construction is needed to form an RTE
1443 * that expandRTE can handle.
1444 */
1445 static RangeTblEntry *
ConstructCallingRTE(RangeTblEntry * rangeTableEntry,List * dependentJobList)1446 ConstructCallingRTE(RangeTblEntry *rangeTableEntry, List *dependentJobList)
1447 {
1448 RangeTblEntry *callingRTE = NULL;
1449
1450 CitusRTEKind rangeTableKind = GetRangeTblKind(rangeTableEntry);
1451 if (rangeTableKind == CITUS_RTE_JOIN)
1452 {
1453 /*
1454 * For joins, we can call expandRTE directly.
1455 */
1456 callingRTE = rangeTableEntry;
1457 }
1458 else if (rangeTableKind == CITUS_RTE_RELATION)
1459 {
1460 /*
1461 * For distributed tables, we construct a regular table RTE to call
1462 * expandRTE, which will extract columns from the distributed table
1463 * schema.
1464 */
1465 callingRTE = makeNode(RangeTblEntry);
1466 callingRTE->rtekind = RTE_RELATION;
1467 callingRTE->eref = rangeTableEntry->eref;
1468 callingRTE->relid = rangeTableEntry->relid;
1469 callingRTE->inh = rangeTableEntry->inh;
1470 }
1471 else if (rangeTableKind == CITUS_RTE_REMOTE_QUERY)
1472 {
1473 Job *dependentJob = JobForRangeTable(dependentJobList, rangeTableEntry);
1474 Query *jobQuery = dependentJob->jobQuery;
1475
1476 /*
1477 * For re-partition jobs, we construct a subquery RTE to call expandRTE,
1478 * which will extract the columns from the target list of the job query.
1479 */
1480 callingRTE = makeNode(RangeTblEntry);
1481 callingRTE->rtekind = RTE_SUBQUERY;
1482 callingRTE->eref = rangeTableEntry->eref;
1483 callingRTE->subquery = jobQuery;
1484 }
1485 else
1486 {
1487 ereport(ERROR, (errmsg("unsupported Citus RTE kind: %d", rangeTableKind)));
1488 }
1489 return callingRTE;
1490 }
1491
1492
1493 /*
1494 * QueryFromList creates the from list construct that is used for building the
1495 * query's join tree. The function creates the from list by making a range table
1496 * reference for each entry in the given range table list.
1497 */
1498 static List *
QueryFromList(List * rangeTableList)1499 QueryFromList(List *rangeTableList)
1500 {
1501 List *fromList = NIL;
1502 int rangeTableCount = list_length(rangeTableList);
1503
1504 for (Index rangeTableIndex = 1; rangeTableIndex <= rangeTableCount; rangeTableIndex++)
1505 {
1506 RangeTblRef *rangeTableReference = makeNode(RangeTblRef);
1507 rangeTableReference->rtindex = rangeTableIndex;
1508
1509 fromList = lappend(fromList, rangeTableReference);
1510 }
1511
1512 return fromList;
1513 }
1514
1515
1516 /*
1517 * BuildSubqueryJobQuery traverses the given logical plan tree, finds MultiTable
1518 * which represents the subquery. It builds the query structure by adding this
1519 * subquery as it is to range table list of the query.
1520 *
1521 * Such as if user runs a query like this;
1522 *
1523 * SELECT avg(id) FROM (
1524 * SELECT ... FROM ()
1525 * )
1526 *
1527 * then this function will build this worker query as keeping subquery as it is;
1528 *
1529 * SELECT sum(id), count(id) FROM (
1530 * SELECT ... FROM ()
1531 * )
1532 */
1533 static Query *
BuildSubqueryJobQuery(MultiNode * multiNode)1534 BuildSubqueryJobQuery(MultiNode *multiNode)
1535 {
1536 List *targetList = NIL;
1537 List *sortClauseList = NIL;
1538 Node *havingQual = NULL;
1539 Node *limitCount = NULL;
1540 Node *limitOffset = NULL;
1541 bool hasAggregates = false;
1542 List *distinctClause = NIL;
1543 bool hasDistinctOn = false;
1544 bool hasWindowFuncs = false;
1545 List *windowClause = NIL;
1546
1547 /* we start building jobs from below the collect node */
1548 Assert(!CitusIsA(multiNode, MultiCollect));
1549
1550 List *subqueryMultiTableList = SubqueryMultiTableList(multiNode);
1551 Assert(list_length(subqueryMultiTableList) == 1);
1552
1553 MultiTable *multiTable = (MultiTable *) linitial(subqueryMultiTableList);
1554 Query *subquery = multiTable->subquery;
1555
1556 /* build subquery range table list */
1557 RangeTblEntry *rangeTableEntry = makeNode(RangeTblEntry);
1558 rangeTableEntry->rtekind = RTE_SUBQUERY;
1559 rangeTableEntry->inFromCl = true;
1560 rangeTableEntry->eref = multiTable->referenceNames;
1561 rangeTableEntry->alias = multiTable->alias;
1562 rangeTableEntry->subquery = subquery;
1563
1564 List *rangeTableList = list_make1(rangeTableEntry);
1565
1566 /*
1567 * If we have an extended operator, then we copy the operator's target list.
1568 * Otherwise, we use the target list based on the MultiProject node at this
1569 * level in the query tree.
1570 */
1571 List *extendedOpNodeList = FindNodesOfType(multiNode, T_MultiExtendedOp);
1572 if (extendedOpNodeList != NIL)
1573 {
1574 MultiExtendedOp *extendedOp = (MultiExtendedOp *) linitial(extendedOpNodeList);
1575 targetList = copyObject(extendedOp->targetList);
1576 }
1577 else
1578 {
1579 targetList = QueryTargetList(multiNode);
1580 }
1581
1582 /* extract limit count/offset, sort and having clauses */
1583 if (extendedOpNodeList != NIL)
1584 {
1585 MultiExtendedOp *extendedOp = (MultiExtendedOp *) linitial(extendedOpNodeList);
1586
1587 limitCount = extendedOp->limitCount;
1588 limitOffset = extendedOp->limitOffset;
1589 sortClauseList = extendedOp->sortClauseList;
1590 havingQual = extendedOp->havingQual;
1591 distinctClause = extendedOp->distinctClause;
1592 hasDistinctOn = extendedOp->hasDistinctOn;
1593 hasWindowFuncs = extendedOp->hasWindowFuncs;
1594 windowClause = extendedOp->windowClause;
1595 }
1596
1597 /* build group clauses */
1598 List *groupClauseList = QueryGroupClauseList(multiNode);
1599
1600 /* build the where clause list using select predicates */
1601 List *whereClauseList = QuerySelectClauseList(multiNode);
1602
1603 if (contain_aggs_of_level((Node *) targetList, 0) ||
1604 contain_aggs_of_level((Node *) havingQual, 0))
1605 {
1606 hasAggregates = true;
1607 }
1608
1609 /* distinct is not sent to worker query if there are top level aggregates */
1610 if (hasAggregates)
1611 {
1612 hasDistinctOn = false;
1613 distinctClause = NIL;
1614 }
1615
1616
1617 /*
1618 * Build the From/Where construct. We keep the where-clause list implicitly
1619 * AND'd, since both partition and join pruning depends on the clauses being
1620 * expressed as a list.
1621 */
1622 FromExpr *joinTree = makeNode(FromExpr);
1623 joinTree->quals = (Node *) whereClauseList;
1624 joinTree->fromlist = QueryFromList(rangeTableList);
1625
1626 /* build the query structure for this job */
1627 Query *jobQuery = makeNode(Query);
1628 jobQuery->commandType = CMD_SELECT;
1629 jobQuery->querySource = QSRC_ORIGINAL;
1630 jobQuery->canSetTag = true;
1631 jobQuery->rtable = rangeTableList;
1632 jobQuery->targetList = targetList;
1633 jobQuery->jointree = joinTree;
1634 jobQuery->sortClause = sortClauseList;
1635 jobQuery->groupClause = groupClauseList;
1636 jobQuery->limitOffset = limitOffset;
1637 jobQuery->limitCount = limitCount;
1638 jobQuery->havingQual = havingQual;
1639 jobQuery->hasAggs = hasAggregates;
1640 jobQuery->hasDistinctOn = hasDistinctOn;
1641 jobQuery->distinctClause = distinctClause;
1642 jobQuery->hasWindowFuncs = hasWindowFuncs;
1643 jobQuery->windowClause = windowClause;
1644 jobQuery->hasSubLinks = checkExprHasSubLink((Node *) jobQuery);
1645
1646 Assert(jobQuery->hasWindowFuncs == contain_window_function((Node *) jobQuery));
1647
1648 return jobQuery;
1649 }
1650
1651
1652 /*
1653 * UpdateAllColumnAttributes extracts column references from provided columnContainer
1654 * and calls UpdateColumnAttributes to updates the column's range table reference (varno) and
1655 * column attribute number for the range table (varattno).
1656 */
1657 static void
UpdateAllColumnAttributes(Node * columnContainer,List * rangeTableList,List * dependentJobList)1658 UpdateAllColumnAttributes(Node *columnContainer, List *rangeTableList,
1659 List *dependentJobList)
1660 {
1661 ListCell *columnCell = NULL;
1662 List *columnList = pull_var_clause_default(columnContainer);
1663 foreach(columnCell, columnList)
1664 {
1665 Var *column = (Var *) lfirst(columnCell);
1666 UpdateColumnAttributes(column, rangeTableList, dependentJobList);
1667 }
1668 }
1669
1670
1671 /*
1672 * UpdateColumnAttributes updates the column's range table reference (varno) and
1673 * column attribute number for the range table (varattno). The function uses the
1674 * newly built range table list to update the given column's attributes.
1675 */
1676 static void
UpdateColumnAttributes(Var * column,List * rangeTableList,List * dependentJobList)1677 UpdateColumnAttributes(Var *column, List *rangeTableList, List *dependentJobList)
1678 {
1679 Index originalTableId = column->varnosyn;
1680 AttrNumber originalColumnId = column->varattnosyn;
1681
1682 /* find the new table identifier */
1683 Index newTableId = NewTableId(originalTableId, rangeTableList);
1684 AttrNumber newColumnId = originalColumnId;
1685
1686 /* if this is a derived table, find the new column identifier */
1687 RangeTblEntry *newRangeTableEntry = rt_fetch(newTableId, rangeTableList);
1688 if (GetRangeTblKind(newRangeTableEntry) == CITUS_RTE_REMOTE_QUERY)
1689 {
1690 newColumnId = NewColumnId(originalTableId, originalColumnId,
1691 newRangeTableEntry, dependentJobList);
1692 }
1693
1694 column->varno = newTableId;
1695 column->varattno = newColumnId;
1696 }
1697
1698
1699 /*
1700 * NewTableId determines the new tableId for the query that is currently being
1701 * built. In this query, the original tableId represents the order of the table
1702 * in the initial parse tree. When queries involve repartitioning, we re-order
1703 * tables; and the new tableId corresponds to this new table order.
1704 */
1705 static Index
NewTableId(Index originalTableId,List * rangeTableList)1706 NewTableId(Index originalTableId, List *rangeTableList)
1707 {
1708 Index rangeTableIndex = 1;
1709 ListCell *rangeTableCell = NULL;
1710
1711 foreach(rangeTableCell, rangeTableList)
1712 {
1713 RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
1714 List *originalTableIdList = NIL;
1715
1716 ExtractRangeTblExtraData(rangeTableEntry, NULL, NULL, NULL, &originalTableIdList);
1717
1718 bool listMember = list_member_int(originalTableIdList, originalTableId);
1719 if (listMember)
1720 {
1721 return rangeTableIndex;
1722 }
1723
1724 rangeTableIndex++;
1725 }
1726
1727 ereport(ERROR, (errmsg("Unrecognized range table id %d", (int) originalTableId)));
1728
1729 return 0;
1730 }
1731
1732
1733 /*
1734 * NewColumnId determines the new columnId for the query that is currently being
1735 * built. In this query, the original columnId corresponds to the column in base
1736 * tables. When the current query is a partition job and generates intermediate
1737 * tables, the columns have a different order and the new columnId corresponds
1738 * to this order. Please note that this function assumes columnIds for dependent
1739 * jobs have already been updated.
1740 */
1741 static AttrNumber
NewColumnId(Index originalTableId,AttrNumber originalColumnId,RangeTblEntry * newRangeTableEntry,List * dependentJobList)1742 NewColumnId(Index originalTableId, AttrNumber originalColumnId,
1743 RangeTblEntry *newRangeTableEntry, List *dependentJobList)
1744 {
1745 AttrNumber newColumnId = 1;
1746 AttrNumber columnIndex = 1;
1747
1748 Job *dependentJob = JobForRangeTable(dependentJobList, newRangeTableEntry);
1749 List *targetEntryList = dependentJob->jobQuery->targetList;
1750
1751 ListCell *targetEntryCell = NULL;
1752 foreach(targetEntryCell, targetEntryList)
1753 {
1754 TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
1755 Expr *expression = targetEntry->expr;
1756
1757 Var *column = (Var *) expression;
1758 Assert(IsA(expression, Var));
1759
1760 /*
1761 * Check against the *old* values for this column, as the new values
1762 * would have been updated already.
1763 */
1764 if (column->varnosyn == originalTableId &&
1765 column->varattnosyn == originalColumnId)
1766 {
1767 newColumnId = columnIndex;
1768 break;
1769 }
1770
1771 columnIndex++;
1772 }
1773
1774 return newColumnId;
1775 }
1776
1777
1778 /*
1779 * JobForRangeTable returns the job that corresponds to the given range table
1780 * entry. The function walks over jobs in the given job list, and compares each
1781 * job's table list against the given range table entry's table list. When two
1782 * table lists match, the function returns the matching job. Note that we call
1783 * this function in practice when we need to determine which one of the jobs we
1784 * depend upon corresponds to given range table entry.
1785 */
1786 static Job *
JobForRangeTable(List * jobList,RangeTblEntry * rangeTableEntry)1787 JobForRangeTable(List *jobList, RangeTblEntry *rangeTableEntry)
1788 {
1789 List *searchedTableIdList = NIL;
1790 CitusRTEKind rangeTableKind;
1791
1792 ExtractRangeTblExtraData(rangeTableEntry, &rangeTableKind, NULL, NULL,
1793 &searchedTableIdList);
1794
1795 Assert(rangeTableKind == CITUS_RTE_REMOTE_QUERY);
1796
1797 Job *searchedJob = JobForTableIdList(jobList, searchedTableIdList);
1798
1799 return searchedJob;
1800 }
1801
1802
1803 /*
1804 * JobForTableIdList returns the job that corresponds to the given
1805 * tableIdList. The function walks over jobs in the given job list, and
1806 * compares each job's table list against the given table list. When the
1807 * two table lists match, the function returns the matching job.
1808 */
1809 static Job *
JobForTableIdList(List * jobList,List * searchedTableIdList)1810 JobForTableIdList(List *jobList, List *searchedTableIdList)
1811 {
1812 Job *searchedJob = NULL;
1813 ListCell *jobCell = NULL;
1814
1815 foreach(jobCell, jobList)
1816 {
1817 Job *job = (Job *) lfirst(jobCell);
1818 List *jobRangeTableList = job->jobQuery->rtable;
1819 List *jobTableIdList = NIL;
1820 ListCell *jobRangeTableCell = NULL;
1821
1822 foreach(jobRangeTableCell, jobRangeTableList)
1823 {
1824 RangeTblEntry *jobRangeTable = (RangeTblEntry *) lfirst(jobRangeTableCell);
1825 List *tableIdList = NIL;
1826
1827 ExtractRangeTblExtraData(jobRangeTable, NULL, NULL, NULL, &tableIdList);
1828
1829 /* copy the list since list_concat is destructive */
1830 tableIdList = list_copy(tableIdList);
1831 jobTableIdList = list_concat(jobTableIdList, tableIdList);
1832 }
1833
1834 /*
1835 * Check if the searched range table's tableIds and the current job's
1836 * tableIds are the same.
1837 */
1838 List *lhsDiff = list_difference_int(jobTableIdList, searchedTableIdList);
1839 List *rhsDiff = list_difference_int(searchedTableIdList, jobTableIdList);
1840 if (lhsDiff == NIL && rhsDiff == NIL)
1841 {
1842 searchedJob = job;
1843 break;
1844 }
1845 }
1846
1847 Assert(searchedJob != NULL);
1848 return searchedJob;
1849 }
1850
1851
1852 /* Returns the list of children for the given multi node. */
1853 static List *
ChildNodeList(MultiNode * multiNode)1854 ChildNodeList(MultiNode *multiNode)
1855 {
1856 List *childNodeList = NIL;
1857 bool isUnaryNode = UnaryOperator(multiNode);
1858 bool isBinaryNode = BinaryOperator(multiNode);
1859
1860 /* relation table nodes don't have any children */
1861 if (CitusIsA(multiNode, MultiTable))
1862 {
1863 MultiTable *multiTable = (MultiTable *) multiNode;
1864 if (multiTable->relationId != SUBQUERY_RELATION_ID)
1865 {
1866 return NIL;
1867 }
1868 }
1869
1870 if (isUnaryNode)
1871 {
1872 MultiUnaryNode *unaryNode = (MultiUnaryNode *) multiNode;
1873 childNodeList = list_make1(unaryNode->childNode);
1874 }
1875 else if (isBinaryNode)
1876 {
1877 MultiBinaryNode *binaryNode = (MultiBinaryNode *) multiNode;
1878 childNodeList = list_make2(binaryNode->leftChildNode,
1879 binaryNode->rightChildNode);
1880 }
1881
1882 return childNodeList;
1883 }
1884
1885
1886 /*
1887 * UniqueJobId allocates and returns a unique jobId for the job to be executed.
1888 *
1889 * The resulting job ID is built up as:
1890 * <16-bit group ID><24-bit process ID><1-bit secondary flag><23-bit local counter>
1891 *
1892 * When citus.enable_unique_job_ids is off then only the local counter is
1893 * included to get repeatable results.
1894 */
1895 uint64
UniqueJobId(void)1896 UniqueJobId(void)
1897 {
1898 static uint32 jobIdCounter = 0;
1899
1900 uint64 jobId = 0;
1901 uint64 processId = 0;
1902 uint64 localGroupId = 0;
1903
1904 jobIdCounter++;
1905
1906 if (EnableUniqueJobIds)
1907 {
1908 /*
1909 * Add the local group id information to the jobId to
1910 * prevent concurrent jobs on different groups to conflict.
1911 */
1912 localGroupId = GetLocalGroupId() & 0xFF;
1913 jobId = jobId | (localGroupId << 48);
1914
1915 /*
1916 * Add the current process ID to distinguish jobs by this
1917 * backends from jobs started by other backends. Process
1918 * IDs can have at most 24-bits on platforms supported by
1919 * Citus.
1920 */
1921 processId = MyProcPid & 0xFFFFFF;
1922 jobId = jobId | (processId << 24);
1923
1924 /*
1925 * Add an extra bit for secondaries to distinguish their
1926 * jobs from primaries.
1927 */
1928 if (RecoveryInProgress())
1929 {
1930 jobId = jobId | (1 << 23);
1931 }
1932 }
1933
1934 /*
1935 * Use the remaining 23 bits to distinguish jobs by the
1936 * same backend.
1937 */
1938 uint64 jobIdNumber = jobIdCounter & 0x1FFFFFF;
1939 jobId = jobId | jobIdNumber;
1940
1941 return jobId;
1942 }
1943
1944
1945 /* Builds a job from the given job query and dependent job list. */
1946 static Job *
BuildJob(Query * jobQuery,List * dependentJobList)1947 BuildJob(Query *jobQuery, List *dependentJobList)
1948 {
1949 Job *job = CitusMakeNode(Job);
1950 job->jobId = UniqueJobId();
1951 job->jobQuery = jobQuery;
1952 job->dependentJobList = dependentJobList;
1953 job->requiresCoordinatorEvaluation = false;
1954
1955 return job;
1956 }
1957
1958
1959 /*
1960 * BuildMapMergeJob builds a MapMerge job from the given query and dependent job
1961 * list. The function then copies and updates the logical plan's partition
1962 * column, and uses the join rule type to determine the physical repartitioning
1963 * method to apply.
1964 */
1965 static MapMergeJob *
BuildMapMergeJob(Query * jobQuery,List * dependentJobList,Var * partitionKey,PartitionType partitionType,Oid baseRelationId,BoundaryNodeJobType boundaryNodeJobType)1966 BuildMapMergeJob(Query *jobQuery, List *dependentJobList, Var *partitionKey,
1967 PartitionType partitionType, Oid baseRelationId,
1968 BoundaryNodeJobType boundaryNodeJobType)
1969 {
1970 List *rangeTableList = jobQuery->rtable;
1971 Var *partitionColumn = copyObject(partitionKey);
1972
1973 /* update the logical partition key's table and column identifiers */
1974 if (boundaryNodeJobType != SUBQUERY_MAP_MERGE_JOB)
1975 {
1976 UpdateColumnAttributes(partitionColumn, rangeTableList, dependentJobList);
1977 }
1978
1979 MapMergeJob *mapMergeJob = CitusMakeNode(MapMergeJob);
1980 mapMergeJob->job.jobId = UniqueJobId();
1981 mapMergeJob->job.jobQuery = jobQuery;
1982 mapMergeJob->job.dependentJobList = dependentJobList;
1983 mapMergeJob->partitionColumn = partitionColumn;
1984 mapMergeJob->sortedShardIntervalArrayLength = 0;
1985
1986 /*
1987 * We assume dual partition join defaults to hash partitioning, and single
1988 * partition join defaults to range partitioning. In practice, the join type
1989 * should have no impact on the physical repartitioning (hash/range) method.
1990 * If join type is not set, this means this job represents a subquery, and
1991 * uses hash partitioning.
1992 */
1993 if (partitionType == DUAL_HASH_PARTITION_TYPE)
1994 {
1995 uint32 partitionCount = HashPartitionCount();
1996
1997 mapMergeJob->partitionType = DUAL_HASH_PARTITION_TYPE;
1998 mapMergeJob->partitionCount = partitionCount;
1999 }
2000 else if (partitionType == SINGLE_HASH_PARTITION_TYPE || partitionType ==
2001 RANGE_PARTITION_TYPE)
2002 {
2003 CitusTableCacheEntry *cache = GetCitusTableCacheEntry(baseRelationId);
2004 int shardCount = cache->shardIntervalArrayLength;
2005 ShardInterval **cachedSortedShardIntervalArray =
2006 cache->sortedShardIntervalArray;
2007 bool hasUninitializedShardInterval =
2008 cache->hasUninitializedShardInterval;
2009
2010 ShardInterval **sortedShardIntervalArray =
2011 palloc0(sizeof(ShardInterval) * shardCount);
2012
2013 for (int shardIndex = 0; shardIndex < shardCount; shardIndex++)
2014 {
2015 sortedShardIntervalArray[shardIndex] =
2016 CopyShardInterval(cachedSortedShardIntervalArray[shardIndex]);
2017 }
2018
2019 if (hasUninitializedShardInterval)
2020 {
2021 ereport(ERROR, (errmsg("cannot range repartition shard with "
2022 "missing min/max values")));
2023 }
2024
2025 mapMergeJob->partitionType = partitionType;
2026 mapMergeJob->partitionCount = (uint32) shardCount;
2027 mapMergeJob->sortedShardIntervalArray = sortedShardIntervalArray;
2028 mapMergeJob->sortedShardIntervalArrayLength = shardCount;
2029 }
2030
2031 return mapMergeJob;
2032 }
2033
2034
2035 /*
2036 * HashPartitionCount returns the number of partition files we create for a hash
2037 * partition task. The function follows Hadoop's method for picking the number
2038 * of reduce tasks: 0.95 or 1.75 * node count * max reduces per node. We choose
2039 * the lower constant 0.95 so that all tasks can start immediately, but round it
2040 * to 1.0 so that we have a smooth number of partition tasks.
2041 */
2042 static uint32
HashPartitionCount(void)2043 HashPartitionCount(void)
2044 {
2045 uint32 groupCount = list_length(ActiveReadableNodeList());
2046 double maxReduceTasksPerNode = RepartitionJoinBucketCountPerNode;
2047
2048 uint32 partitionCount = (uint32) rint(groupCount * maxReduceTasksPerNode);
2049 return partitionCount;
2050 }
2051
2052
2053 /*
2054 * SplitPointObject walks over shard intervals in the given array, extracts each
2055 * shard interval's minimum value, sorts and inserts these minimum values into a
2056 * new array. This sorted array is then used by the MapMerge job.
2057 */
2058 static ArrayType *
SplitPointObject(ShardInterval ** shardIntervalArray,uint32 shardIntervalCount)2059 SplitPointObject(ShardInterval **shardIntervalArray, uint32 shardIntervalCount)
2060 {
2061 Oid typeId = InvalidOid;
2062 bool typeByValue = false;
2063 char typeAlignment = 0;
2064 int16 typeLength = 0;
2065
2066 /* allocate an array for shard min values */
2067 uint32 minDatumCount = shardIntervalCount;
2068 Datum *minDatumArray = palloc0(minDatumCount * sizeof(Datum));
2069
2070 for (uint32 intervalIndex = 0; intervalIndex < shardIntervalCount; intervalIndex++)
2071 {
2072 ShardInterval *shardInterval = shardIntervalArray[intervalIndex];
2073 minDatumArray[intervalIndex] = shardInterval->minValue;
2074 Assert(shardInterval->minValueExists);
2075
2076 /* resolve the datum type on the first pass */
2077 if (intervalIndex == 0)
2078 {
2079 typeId = shardInterval->valueTypeId;
2080 }
2081 }
2082
2083 /* construct the split point object from the sorted array */
2084 get_typlenbyvalalign(typeId, &typeLength, &typeByValue, &typeAlignment);
2085 ArrayType *splitPointObject = construct_array(minDatumArray, minDatumCount, typeId,
2086 typeLength, typeByValue, typeAlignment);
2087
2088 return splitPointObject;
2089 }
2090
2091
2092 /* ------------------------------------------------------------
2093 * Functions that relate to building and assigning tasks follow
2094 * ------------------------------------------------------------
2095 */
2096
2097 /*
2098 * BuildJobTreeTaskList takes in the given job tree and walks over jobs in this
2099 * tree bottom up. The function then creates tasks for each job in the tree,
2100 * sets dependencies between tasks and their downstream dependencies and assigns
2101 * tasks to worker nodes.
2102 */
2103 static Job *
BuildJobTreeTaskList(Job * jobTree,PlannerRestrictionContext * plannerRestrictionContext)2104 BuildJobTreeTaskList(Job *jobTree, PlannerRestrictionContext *plannerRestrictionContext)
2105 {
2106 List *flattenedJobList = NIL;
2107
2108 /*
2109 * We traverse the job tree in preorder, and append each visited job to our
2110 * flattened list. This way, each job in our list appears before the jobs it
2111 * depends on.
2112 */
2113 List *jobStack = list_make1(jobTree);
2114 while (jobStack != NIL)
2115 {
2116 Job *job = (Job *) llast(jobStack);
2117 flattenedJobList = lappend(flattenedJobList, job);
2118
2119 /* pop top element and push its children to the stack */
2120 jobStack = list_delete_ptr(jobStack, job);
2121 jobStack = list_union_ptr(jobStack, job->dependentJobList);
2122 }
2123
2124 /*
2125 * We walk the job list in reverse order to visit jobs bottom up. This way,
2126 * we can create dependencies between tasks bottom up, and assign them to
2127 * worker nodes accordingly.
2128 */
2129 uint32 flattenedJobCount = (int32) list_length(flattenedJobList);
2130 for (int32 jobIndex = (flattenedJobCount - 1); jobIndex >= 0; jobIndex--)
2131 {
2132 Job *job = (Job *) list_nth(flattenedJobList, jobIndex);
2133 List *sqlTaskList = NIL;
2134 ListCell *assignedSqlTaskCell = NULL;
2135
2136 /* create sql tasks for the job, and prune redundant data fetch tasks */
2137 if (job->subqueryPushdown)
2138 {
2139 bool isMultiShardQuery = false;
2140 List *prunedRelationShardList =
2141 TargetShardIntervalsForRestrictInfo(plannerRestrictionContext->
2142 relationRestrictionContext,
2143 &isMultiShardQuery, NULL);
2144
2145 DeferredErrorMessage *deferredErrorMessage = NULL;
2146 sqlTaskList = QueryPushdownSqlTaskList(job->jobQuery, job->jobId,
2147 plannerRestrictionContext->
2148 relationRestrictionContext,
2149 prunedRelationShardList, READ_TASK,
2150 false,
2151 &deferredErrorMessage);
2152 if (deferredErrorMessage != NULL)
2153 {
2154 RaiseDeferredErrorInternal(deferredErrorMessage, ERROR);
2155 }
2156 }
2157 else
2158 {
2159 sqlTaskList = SqlTaskList(job);
2160 }
2161
2162 sqlTaskList = PruneSqlTaskDependencies(sqlTaskList);
2163
2164 /*
2165 * We first assign sql and merge tasks to worker nodes. Next, we assign
2166 * sql tasks' data fetch dependencies.
2167 */
2168 List *assignedSqlTaskList = AssignTaskList(sqlTaskList);
2169 AssignDataFetchDependencies(assignedSqlTaskList);
2170
2171 /* if the parameters has not been resolved, record it */
2172 job->parametersInJobQueryResolved =
2173 !HasUnresolvedExternParamsWalker((Node *) job->jobQuery, NULL);
2174
2175 /*
2176 * Make final adjustments for the assigned tasks.
2177 *
2178 * First, update SELECT tasks' parameters resolved field.
2179 *
2180 * Second, assign merge task's data fetch dependencies.
2181 */
2182 foreach(assignedSqlTaskCell, assignedSqlTaskList)
2183 {
2184 Task *assignedSqlTask = (Task *) lfirst(assignedSqlTaskCell);
2185
2186 /* we don't support parameters in the physical planner */
2187 if (assignedSqlTask->taskType == READ_TASK)
2188 {
2189 assignedSqlTask->parametersInQueryStringResolved =
2190 job->parametersInJobQueryResolved;
2191 }
2192
2193 List *assignedMergeTaskList = FindDependentMergeTaskList(assignedSqlTask);
2194 AssignDataFetchDependencies(assignedMergeTaskList);
2195 }
2196
2197 /*
2198 * If we have a MapMerge job, the map tasks in this job wrap around the
2199 * SQL tasks and their assignments.
2200 */
2201 if (CitusIsA(job, MapMergeJob))
2202 {
2203 MapMergeJob *mapMergeJob = (MapMergeJob *) job;
2204 uint32 taskIdIndex = TaskListHighestTaskId(assignedSqlTaskList) + 1;
2205
2206 List *mapTaskList = MapTaskList(mapMergeJob, assignedSqlTaskList);
2207 List *mergeTaskList = MergeTaskList(mapMergeJob, mapTaskList, taskIdIndex);
2208
2209 mapMergeJob->mapTaskList = mapTaskList;
2210 mapMergeJob->mergeTaskList = mergeTaskList;
2211 }
2212 else
2213 {
2214 job->taskList = assignedSqlTaskList;
2215 }
2216 }
2217
2218 return jobTree;
2219 }
2220
2221
2222 /*
2223 * QueryPushdownSqlTaskList creates a list of SQL tasks to execute the given subquery
2224 * pushdown job. For this, it is being checked whether the query is router
2225 * plannable per target shard interval. For those router plannable worker
2226 * queries, we create a SQL task and append the task to the task list that is going
2227 * to be executed.
2228 */
2229 List *
QueryPushdownSqlTaskList(Query * query,uint64 jobId,RelationRestrictionContext * relationRestrictionContext,List * prunedRelationShardList,TaskType taskType,bool modifyRequiresCoordinatorEvaluation,DeferredErrorMessage ** planningError)2230 QueryPushdownSqlTaskList(Query *query, uint64 jobId,
2231 RelationRestrictionContext *relationRestrictionContext,
2232 List *prunedRelationShardList, TaskType taskType, bool
2233 modifyRequiresCoordinatorEvaluation,
2234 DeferredErrorMessage **planningError)
2235 {
2236 List *sqlTaskList = NIL;
2237 ListCell *restrictionCell = NULL;
2238 uint32 taskIdIndex = 1; /* 0 is reserved for invalid taskId */
2239 int shardCount = 0;
2240 bool *taskRequiredForShardIndex = NULL;
2241 ListCell *prunedRelationShardCell = NULL;
2242
2243 /* error if shards are not co-partitioned */
2244 ErrorIfUnsupportedShardDistribution(query);
2245
2246 if (list_length(relationRestrictionContext->relationRestrictionList) == 0)
2247 {
2248 *planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
2249 "cannot handle complex subqueries when the "
2250 "router executor is disabled",
2251 NULL, NULL);
2252 return NIL;
2253 }
2254
2255 /* defaults to be used if this is a reference table-only query */
2256 int minShardOffset = 0;
2257 int maxShardOffset = 0;
2258
2259 forboth(prunedRelationShardCell, prunedRelationShardList,
2260 restrictionCell, relationRestrictionContext->relationRestrictionList)
2261 {
2262 RelationRestriction *relationRestriction =
2263 (RelationRestriction *) lfirst(restrictionCell);
2264 Oid relationId = relationRestriction->relationId;
2265 List *prunedShardList = (List *) lfirst(prunedRelationShardCell);
2266 ListCell *shardIntervalCell = NULL;
2267
2268 CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
2269 if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
2270 {
2271 continue;
2272 }
2273
2274 /* we expect distributed tables to have the same shard count */
2275 if (shardCount > 0 && shardCount != cacheEntry->shardIntervalArrayLength)
2276 {
2277 *planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
2278 "shard counts of co-located tables do not "
2279 "match",
2280 NULL, NULL);
2281 return NIL;
2282 }
2283
2284 if (taskRequiredForShardIndex == NULL)
2285 {
2286 shardCount = cacheEntry->shardIntervalArrayLength;
2287 taskRequiredForShardIndex = (bool *) palloc0(shardCount);
2288
2289 /* there is a distributed table, find the shard range */
2290 minShardOffset = shardCount;
2291 maxShardOffset = -1;
2292 }
2293
2294 /*
2295 * For left joins we don't care about the shards pruned for the right hand side.
2296 * If the right hand side would prune to a smaller set we should still send it to
2297 * all tables of the left hand side. However if the right hand side is bigger than
2298 * the left hand side we don't have to send the query to any shard that is not
2299 * matching anything on the left hand side.
2300 *
2301 * Instead we will simply skip any RelationRestriction if it is an OUTER join and
2302 * the table is part of the non-outer side of the join.
2303 */
2304 if (IsInnerTableOfOuterJoin(relationRestriction))
2305 {
2306 continue;
2307 }
2308
2309 foreach(shardIntervalCell, prunedShardList)
2310 {
2311 ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
2312 int shardIndex = shardInterval->shardIndex;
2313
2314 taskRequiredForShardIndex[shardIndex] = true;
2315
2316 minShardOffset = Min(minShardOffset, shardIndex);
2317 maxShardOffset = Max(maxShardOffset, shardIndex);
2318 }
2319 }
2320
2321 /*
2322 * To avoid iterating through all shards indexes we keep the minimum and maximum
2323 * offsets of shards that were not pruned away. This optimisation is primarily
2324 * relevant for queries on range-distributed tables that, due to range filters,
2325 * prune to a small number of adjacent shards.
2326 *
2327 * In other cases, such as an OR condition on a hash-distributed table, we may
2328 * still visit most or all shards even if some of them were pruned away. However,
2329 * given that hash-distributed tables typically only have a few shards the
2330 * iteration is still very fast.
2331 */
2332 for (int shardOffset = minShardOffset; shardOffset <= maxShardOffset; shardOffset++)
2333 {
2334 if (taskRequiredForShardIndex != NULL && !taskRequiredForShardIndex[shardOffset])
2335 {
2336 /* this shard index is pruned away for all relations */
2337 continue;
2338 }
2339
2340 Task *subqueryTask = QueryPushdownTaskCreate(query, shardOffset,
2341 relationRestrictionContext,
2342 taskIdIndex,
2343 taskType,
2344 modifyRequiresCoordinatorEvaluation,
2345 planningError);
2346 if (*planningError != NULL)
2347 {
2348 return NIL;
2349 }
2350 subqueryTask->jobId = jobId;
2351 sqlTaskList = lappend(sqlTaskList, subqueryTask);
2352
2353 ++taskIdIndex;
2354 }
2355
2356 /* If it is a modify task with multiple tables */
2357 if (taskType == MODIFY_TASK && list_length(
2358 relationRestrictionContext->relationRestrictionList) > 1)
2359 {
2360 ListCell *taskCell = NULL;
2361 foreach(taskCell, sqlTaskList)
2362 {
2363 Task *task = (Task *) lfirst(taskCell);
2364 task->modifyWithSubquery = true;
2365 }
2366 }
2367
2368 return sqlTaskList;
2369 }
2370
2371
2372 /*
2373 * IsInnerTableOfOuterJoin tests based on the join information envoded in a
2374 * RelationRestriction if the table accessed for this relation is
2375 * a) in an outer join
2376 * b) on the inner part of said join
2377 *
2378 * The function returns true only if both conditions above hold true
2379 */
2380 static bool
IsInnerTableOfOuterJoin(RelationRestriction * relationRestriction)2381 IsInnerTableOfOuterJoin(RelationRestriction *relationRestriction)
2382 {
2383 RestrictInfo *joinInfo = NULL;
2384 foreach_ptr(joinInfo, relationRestriction->relOptInfo->joininfo)
2385 {
2386 if (joinInfo->outer_relids == NULL)
2387 {
2388 /* not an outer join */
2389 continue;
2390 }
2391
2392 /*
2393 * This join restriction info describes an outer join, we need to figure out if
2394 * our table is in the non outer part of this join. If that is the case this is a
2395 * non outer table of an outer join.
2396 */
2397 bool isInOuter = bms_is_member(relationRestriction->relOptInfo->relid,
2398 joinInfo->outer_relids);
2399 if (!isInOuter)
2400 {
2401 /* this table is joined in the inner part of an outer join */
2402 return true;
2403 }
2404 }
2405
2406 /* we have not found any join clause that satisfies both requirements */
2407 return false;
2408 }
2409
2410
2411 /*
2412 * ErrorIfUnsupportedShardDistribution gets list of relations in the given query
2413 * and checks if two conditions below hold for them, otherwise it errors out.
2414 * a. Every relation is distributed by range or hash. This means shards are
2415 * disjoint based on the partition column.
2416 * b. All relations have 1-to-1 shard partitioning between them. This means
2417 * shard count for every relation is same and for every shard in a relation
2418 * there is exactly one shard in other relations with same min/max values.
2419 */
2420 static void
ErrorIfUnsupportedShardDistribution(Query * query)2421 ErrorIfUnsupportedShardDistribution(Query *query)
2422 {
2423 Oid firstTableRelationId = InvalidOid;
2424 List *relationIdList = DistributedRelationIdList(query);
2425 List *nonReferenceRelations = NIL;
2426 ListCell *relationIdCell = NULL;
2427 uint32 relationIndex = 0;
2428 uint32 rangeDistributedRelationCount = 0;
2429 uint32 hashDistributedRelationCount = 0;
2430 uint32 appendDistributedRelationCount = 0;
2431
2432 foreach(relationIdCell, relationIdList)
2433 {
2434 Oid relationId = lfirst_oid(relationIdCell);
2435 if (IsCitusTableType(relationId, RANGE_DISTRIBUTED))
2436 {
2437 rangeDistributedRelationCount++;
2438 nonReferenceRelations = lappend_oid(nonReferenceRelations,
2439 relationId);
2440 }
2441 else if (IsCitusTableType(relationId, HASH_DISTRIBUTED))
2442 {
2443 hashDistributedRelationCount++;
2444 nonReferenceRelations = lappend_oid(nonReferenceRelations,
2445 relationId);
2446 }
2447 else if (IsCitusTableType(relationId, CITUS_TABLE_WITH_NO_DIST_KEY))
2448 {
2449 /* do not need to handle non-distributed tables */
2450 continue;
2451 }
2452 else
2453 {
2454 CitusTableCacheEntry *distTableEntry = GetCitusTableCacheEntry(relationId);
2455 if (distTableEntry->hasOverlappingShardInterval)
2456 {
2457 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
2458 errmsg("cannot push down this subquery"),
2459 errdetail("Currently append partitioned relations "
2460 "with overlapping shard intervals are "
2461 "not supported")));
2462 }
2463
2464 appendDistributedRelationCount++;
2465 }
2466 }
2467
2468 if ((rangeDistributedRelationCount > 0) && (hashDistributedRelationCount > 0))
2469 {
2470 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
2471 errmsg("cannot push down this subquery"),
2472 errdetail("A query including both range and hash "
2473 "partitioned relations are unsupported")));
2474 }
2475 else if ((rangeDistributedRelationCount > 0) && (appendDistributedRelationCount > 0))
2476 {
2477 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
2478 errmsg("cannot push down this subquery"),
2479 errdetail("A query including both range and append "
2480 "partitioned relations are unsupported")));
2481 }
2482 else if ((appendDistributedRelationCount > 0) && (hashDistributedRelationCount > 0))
2483 {
2484 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
2485 errmsg("cannot push down this subquery"),
2486 errdetail("A query including both append and hash "
2487 "partitioned relations are unsupported")));
2488 }
2489
2490 foreach(relationIdCell, nonReferenceRelations)
2491 {
2492 Oid relationId = lfirst_oid(relationIdCell);
2493 Oid currentRelationId = relationId;
2494
2495 /* get shard list of first relation and continue for the next relation */
2496 if (relationIndex == 0)
2497 {
2498 firstTableRelationId = relationId;
2499 relationIndex++;
2500
2501 continue;
2502 }
2503
2504 /* check if this table has 1-1 shard partitioning with first table */
2505 bool coPartitionedTables = CoPartitionedTables(firstTableRelationId,
2506 currentRelationId);
2507 if (!coPartitionedTables)
2508 {
2509 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
2510 errmsg("cannot push down this subquery"),
2511 errdetail("Shards of relations in subquery need to "
2512 "have 1-to-1 shard partitioning")));
2513 }
2514 }
2515 }
2516
2517
2518 /*
2519 * SubqueryTaskCreate creates a sql task by replacing the target
2520 * shardInterval's boundary value.
2521 */
2522 static Task *
QueryPushdownTaskCreate(Query * originalQuery,int shardIndex,RelationRestrictionContext * restrictionContext,uint32 taskId,TaskType taskType,bool modifyRequiresCoordinatorEvaluation,DeferredErrorMessage ** planningError)2523 QueryPushdownTaskCreate(Query *originalQuery, int shardIndex,
2524 RelationRestrictionContext *restrictionContext, uint32 taskId,
2525 TaskType taskType, bool modifyRequiresCoordinatorEvaluation,
2526 DeferredErrorMessage **planningError)
2527 {
2528 Query *taskQuery = copyObject(originalQuery);
2529
2530 StringInfo queryString = makeStringInfo();
2531 ListCell *restrictionCell = NULL;
2532 List *taskShardList = NIL;
2533 List *relationShardList = NIL;
2534 uint64 jobId = INVALID_JOB_ID;
2535 uint64 anchorShardId = INVALID_SHARD_ID;
2536 bool modifyWithSubselect = false;
2537 RangeTblEntry *resultRangeTable = NULL;
2538 Oid resultRelationOid = InvalidOid;
2539
2540 /*
2541 * If it is a modify query with sub-select, we need to set result relation shard's id
2542 * as anchor shard id.
2543 */
2544 if (UpdateOrDeleteQuery(originalQuery))
2545 {
2546 resultRangeTable = rt_fetch(originalQuery->resultRelation, originalQuery->rtable);
2547 resultRelationOid = resultRangeTable->relid;
2548 modifyWithSubselect = true;
2549 }
2550
2551 /*
2552 * Find the relevant shard out of each relation for this task.
2553 */
2554 foreach(restrictionCell, restrictionContext->relationRestrictionList)
2555 {
2556 RelationRestriction *relationRestriction =
2557 (RelationRestriction *) lfirst(restrictionCell);
2558 Oid relationId = relationRestriction->relationId;
2559 ShardInterval *shardInterval = NULL;
2560
2561 CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
2562 if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
2563 {
2564 /* non-distributed tables have only one shard */
2565 shardInterval = cacheEntry->sortedShardIntervalArray[0];
2566
2567 /* only use reference table as anchor shard if none exists yet */
2568 if (anchorShardId == INVALID_SHARD_ID)
2569 {
2570 anchorShardId = shardInterval->shardId;
2571 }
2572 }
2573 else if (UpdateOrDeleteQuery(originalQuery))
2574 {
2575 shardInterval = cacheEntry->sortedShardIntervalArray[shardIndex];
2576 if (!modifyWithSubselect || relationId == resultRelationOid)
2577 {
2578 /* for UPDATE/DELETE the shard in the result relation becomes the anchor shard */
2579 anchorShardId = shardInterval->shardId;
2580 }
2581 }
2582 else
2583 {
2584 /* for SELECT we pick an arbitrary shard as the anchor shard */
2585 shardInterval = cacheEntry->sortedShardIntervalArray[shardIndex];
2586 anchorShardId = shardInterval->shardId;
2587 }
2588
2589 ShardInterval *copiedShardInterval = CopyShardInterval(shardInterval);
2590
2591 taskShardList = lappend(taskShardList, list_make1(copiedShardInterval));
2592
2593 RelationShard *relationShard = CitusMakeNode(RelationShard);
2594 relationShard->relationId = copiedShardInterval->relationId;
2595 relationShard->shardId = copiedShardInterval->shardId;
2596
2597 relationShardList = lappend(relationShardList, relationShard);
2598 }
2599
2600 Assert(anchorShardId != INVALID_SHARD_ID);
2601
2602 List *taskPlacementList = PlacementsForWorkersContainingAllShards(taskShardList);
2603 if (list_length(taskPlacementList) == 0)
2604 {
2605 *planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
2606 "cannot find a worker that has active placements for all "
2607 "shards in the query",
2608 NULL, NULL);
2609
2610 return NULL;
2611 }
2612
2613 /*
2614 * Augment the relations in the query with the shard IDs.
2615 */
2616 UpdateRelationToShardNames((Node *) taskQuery, relationShardList);
2617
2618 /*
2619 * Ands are made implicit during shard pruning, as predicate comparison and
2620 * refutation depend on it being so. We need to make them explicit again so
2621 * that the query string is generated as (...) AND (...) as opposed to
2622 * (...), (...).
2623 */
2624 if (taskQuery->jointree->quals != NULL && IsA(taskQuery->jointree->quals, List))
2625 {
2626 taskQuery->jointree->quals = (Node *) make_ands_explicit(
2627 (List *) taskQuery->jointree->quals);
2628 }
2629
2630 Task *subqueryTask = CreateBasicTask(jobId, taskId, taskType, NULL);
2631
2632 if ((taskType == MODIFY_TASK && !modifyRequiresCoordinatorEvaluation) ||
2633 taskType == READ_TASK)
2634 {
2635 pg_get_query_def(taskQuery, queryString);
2636 ereport(DEBUG4, (errmsg("distributed statement: %s",
2637 ApplyLogRedaction(queryString->data))));
2638 SetTaskQueryString(subqueryTask, queryString->data);
2639 }
2640
2641 subqueryTask->dependentTaskList = NULL;
2642 subqueryTask->anchorShardId = anchorShardId;
2643 subqueryTask->taskPlacementList = taskPlacementList;
2644 subqueryTask->relationShardList = relationShardList;
2645
2646 return subqueryTask;
2647 }
2648
2649
2650 /*
2651 * CoPartitionedTables checks if given two distributed tables have 1-to-1 shard
2652 * placement matching. It first checks for the shard count, if tables don't have
2653 * same amount shard then it returns false. Note that, if any table does not
2654 * have any shard, it returns true. If two tables have same amount of shards,
2655 * we check colocationIds for hash distributed tables and shardInterval's min
2656 * max values for append and range distributed tables.
2657 */
2658 bool
CoPartitionedTables(Oid firstRelationId,Oid secondRelationId)2659 CoPartitionedTables(Oid firstRelationId, Oid secondRelationId)
2660 {
2661 if (firstRelationId == secondRelationId)
2662 {
2663 return true;
2664 }
2665
2666 CitusTableCacheEntry *firstTableCache = GetCitusTableCacheEntry(firstRelationId);
2667 CitusTableCacheEntry *secondTableCache = GetCitusTableCacheEntry(secondRelationId);
2668
2669 ShardInterval **sortedFirstIntervalArray = firstTableCache->sortedShardIntervalArray;
2670 ShardInterval **sortedSecondIntervalArray =
2671 secondTableCache->sortedShardIntervalArray;
2672 uint32 firstListShardCount = firstTableCache->shardIntervalArrayLength;
2673 uint32 secondListShardCount = secondTableCache->shardIntervalArrayLength;
2674 FmgrInfo *comparisonFunction = firstTableCache->shardIntervalCompareFunction;
2675
2676 /* reference tables are always & only copartitioned with reference tables */
2677 if (IsCitusTableTypeCacheEntry(firstTableCache, CITUS_TABLE_WITH_NO_DIST_KEY) &&
2678 IsCitusTableTypeCacheEntry(secondTableCache, CITUS_TABLE_WITH_NO_DIST_KEY))
2679 {
2680 return true;
2681 }
2682 else if (IsCitusTableTypeCacheEntry(firstTableCache, CITUS_TABLE_WITH_NO_DIST_KEY) ||
2683 IsCitusTableTypeCacheEntry(secondTableCache, CITUS_TABLE_WITH_NO_DIST_KEY))
2684 {
2685 return false;
2686 }
2687
2688 if (firstListShardCount != secondListShardCount)
2689 {
2690 return false;
2691 }
2692
2693 /* if there are not any shards just return true */
2694 if (firstListShardCount == 0)
2695 {
2696 return true;
2697 }
2698
2699 Assert(comparisonFunction != NULL);
2700
2701 /*
2702 * Check if the tables have the same colocation ID - if so, we know
2703 * they're colocated.
2704 */
2705 if (firstTableCache->colocationId != INVALID_COLOCATION_ID &&
2706 firstTableCache->colocationId == secondTableCache->colocationId)
2707 {
2708 return true;
2709 }
2710
2711 /*
2712 * For hash distributed tables two tables are accepted as colocated only if
2713 * they have the same colocationId. Otherwise they may have same minimum and
2714 * maximum values for each shard interval, yet hash function may result with
2715 * different values for the same value. int vs bigint can be given as an
2716 * example.
2717 */
2718 if (IsCitusTableTypeCacheEntry(firstTableCache, HASH_DISTRIBUTED) ||
2719 IsCitusTableTypeCacheEntry(secondTableCache, HASH_DISTRIBUTED))
2720 {
2721 return false;
2722 }
2723
2724
2725 /*
2726 * Don't compare unequal types
2727 */
2728 Oid collation = firstTableCache->partitionColumn->varcollid;
2729 if (firstTableCache->partitionColumn->vartype !=
2730 secondTableCache->partitionColumn->vartype ||
2731 collation != secondTableCache->partitionColumn->varcollid)
2732 {
2733 return false;
2734 }
2735
2736
2737 /*
2738 * If not known to be colocated check if the remaining shards are
2739 * anyway. Do so by comparing the shard interval arrays that are sorted on
2740 * interval minimum values. Then it compares every shard interval in order
2741 * and if any pair of shard intervals are not equal or they are not located
2742 * in the same node it returns false.
2743 */
2744 for (uint32 intervalIndex = 0; intervalIndex < firstListShardCount; intervalIndex++)
2745 {
2746 ShardInterval *firstInterval = sortedFirstIntervalArray[intervalIndex];
2747 ShardInterval *secondInterval = sortedSecondIntervalArray[intervalIndex];
2748
2749 bool shardIntervalsEqual = ShardIntervalsEqual(comparisonFunction,
2750 collation,
2751 firstInterval,
2752 secondInterval);
2753 if (!shardIntervalsEqual || !CoPlacedShardIntervals(firstInterval,
2754 secondInterval))
2755 {
2756 return false;
2757 }
2758 }
2759
2760 return true;
2761 }
2762
2763
2764 /*
2765 * CoPlacedShardIntervals checks whether the given intervals located in the same nodes.
2766 */
2767 static bool
CoPlacedShardIntervals(ShardInterval * firstInterval,ShardInterval * secondInterval)2768 CoPlacedShardIntervals(ShardInterval *firstInterval, ShardInterval *secondInterval)
2769 {
2770 List *firstShardPlacementList = ShardPlacementListWithoutOrphanedPlacements(
2771 firstInterval->shardId);
2772 List *secondShardPlacementList = ShardPlacementListWithoutOrphanedPlacements(
2773 secondInterval->shardId);
2774 ListCell *firstShardPlacementCell = NULL;
2775 ListCell *secondShardPlacementCell = NULL;
2776
2777 /* Shards must have same number of placements */
2778 if (list_length(firstShardPlacementList) != list_length(secondShardPlacementList))
2779 {
2780 return false;
2781 }
2782
2783 firstShardPlacementList = SortList(firstShardPlacementList, CompareShardPlacements);
2784 secondShardPlacementList = SortList(secondShardPlacementList, CompareShardPlacements);
2785
2786 forboth(firstShardPlacementCell, firstShardPlacementList, secondShardPlacementCell,
2787 secondShardPlacementList)
2788 {
2789 ShardPlacement *firstShardPlacement = (ShardPlacement *) lfirst(
2790 firstShardPlacementCell);
2791 ShardPlacement *secondShardPlacement = (ShardPlacement *) lfirst(
2792 secondShardPlacementCell);
2793
2794 if (firstShardPlacement->nodeId != secondShardPlacement->nodeId)
2795 {
2796 return false;
2797 }
2798 }
2799
2800 return true;
2801 }
2802
2803
2804 /*
2805 * ShardIntervalsEqual checks if given shard intervals have equal min/max values.
2806 */
2807 static bool
ShardIntervalsEqual(FmgrInfo * comparisonFunction,Oid collation,ShardInterval * firstInterval,ShardInterval * secondInterval)2808 ShardIntervalsEqual(FmgrInfo *comparisonFunction, Oid collation,
2809 ShardInterval *firstInterval, ShardInterval *secondInterval)
2810 {
2811 bool shardIntervalsEqual = false;
2812
2813 Datum firstMin = firstInterval->minValue;
2814 Datum firstMax = firstInterval->maxValue;
2815 Datum secondMin = secondInterval->minValue;
2816 Datum secondMax = secondInterval->maxValue;
2817
2818 if (firstInterval->minValueExists && firstInterval->maxValueExists &&
2819 secondInterval->minValueExists && secondInterval->maxValueExists)
2820 {
2821 Datum minDatum = FunctionCall2Coll(comparisonFunction, collation, firstMin,
2822 secondMin);
2823 Datum maxDatum = FunctionCall2Coll(comparisonFunction, collation, firstMax,
2824 secondMax);
2825 int firstComparison = DatumGetInt32(minDatum);
2826 int secondComparison = DatumGetInt32(maxDatum);
2827
2828 if (firstComparison == 0 && secondComparison == 0)
2829 {
2830 shardIntervalsEqual = true;
2831 }
2832 }
2833
2834 return shardIntervalsEqual;
2835 }
2836
2837
2838 /*
2839 * SqlTaskList creates a list of SQL tasks to execute the given job. For this,
2840 * the function walks over each range table in the job's range table list, gets
2841 * each range table's table fragments, and prunes unneeded table fragments. The
2842 * function then joins table fragments from different range tables, and creates
2843 * all fragment combinations. For each created combination, the function builds
2844 * a SQL task, and appends this task to a task list.
2845 */
2846 static List *
SqlTaskList(Job * job)2847 SqlTaskList(Job *job)
2848 {
2849 List *sqlTaskList = NIL;
2850 uint32 taskIdIndex = 1; /* 0 is reserved for invalid taskId */
2851 uint64 jobId = job->jobId;
2852 bool anchorRangeTableBasedAssignment = false;
2853 uint32 anchorRangeTableId = 0;
2854
2855 Query *jobQuery = job->jobQuery;
2856 List *rangeTableList = jobQuery->rtable;
2857 List *whereClauseList = (List *) jobQuery->jointree->quals;
2858 List *dependentJobList = job->dependentJobList;
2859
2860 /*
2861 * If we don't depend on a hash partition, then we determine the largest
2862 * table around which we build our queries. This reduces data fetching.
2863 */
2864 bool dependsOnHashPartitionJob = DependsOnHashPartitionJob(job);
2865 if (!dependsOnHashPartitionJob)
2866 {
2867 anchorRangeTableBasedAssignment = true;
2868 anchorRangeTableId = AnchorRangeTableId(rangeTableList);
2869
2870 Assert(anchorRangeTableId != 0);
2871 Assert(anchorRangeTableId <= list_length(rangeTableList));
2872 }
2873
2874 /* adjust our column old attributes for partition pruning to work */
2875 AdjustColumnOldAttributes(whereClauseList);
2876 AdjustColumnOldAttributes(jobQuery->targetList);
2877
2878 /*
2879 * Ands are made implicit during shard pruning, as predicate comparison and
2880 * refutation depend on it being so. We need to make them explicit again so
2881 * that the query string is generated as (...) AND (...) as opposed to
2882 * (...), (...).
2883 */
2884 Node *whereClauseTree = (Node *) make_ands_explicit(
2885 (List *) jobQuery->jointree->quals);
2886 jobQuery->jointree->quals = whereClauseTree;
2887
2888 /*
2889 * For each range table, we first get a list of their shards or merge tasks.
2890 * We also apply partition pruning based on the selection criteria. If all
2891 * range table fragments are pruned away, we return an empty task list.
2892 */
2893 List *rangeTableFragmentsList = RangeTableFragmentsList(rangeTableList,
2894 whereClauseList,
2895 dependentJobList);
2896 if (rangeTableFragmentsList == NIL)
2897 {
2898 return NIL;
2899 }
2900
2901 /*
2902 * We then generate fragment combinations according to how range tables join
2903 * with each other (and apply join pruning). Each fragment combination then
2904 * represents one SQL task's dependencies.
2905 */
2906 List *fragmentCombinationList = FragmentCombinationList(rangeTableFragmentsList,
2907 jobQuery, dependentJobList);
2908
2909 ListCell *fragmentCombinationCell = NULL;
2910 foreach(fragmentCombinationCell, fragmentCombinationList)
2911 {
2912 List *fragmentCombination = (List *) lfirst(fragmentCombinationCell);
2913
2914 /* create tasks to fetch fragments required for the sql task */
2915 List *dataFetchTaskList = DataFetchTaskList(jobId, taskIdIndex,
2916 fragmentCombination);
2917 int32 dataFetchTaskCount = list_length(dataFetchTaskList);
2918 taskIdIndex += dataFetchTaskCount;
2919
2920 /* update range table entries with fragment aliases (in place) */
2921 Query *taskQuery = copyObject(jobQuery);
2922 List *fragmentRangeTableList = taskQuery->rtable;
2923 UpdateRangeTableAlias(fragmentRangeTableList, fragmentCombination);
2924
2925 /* transform the updated task query to a SQL query string */
2926 StringInfo sqlQueryString = makeStringInfo();
2927 pg_get_query_def(taskQuery, sqlQueryString);
2928
2929 Task *sqlTask = CreateBasicTask(jobId, taskIdIndex, READ_TASK,
2930 sqlQueryString->data);
2931 sqlTask->dependentTaskList = dataFetchTaskList;
2932 sqlTask->relationShardList = BuildRelationShardList(fragmentRangeTableList,
2933 fragmentCombination);
2934
2935 /* log the query string we generated */
2936 ereport(DEBUG4, (errmsg("generated sql query for task %d", sqlTask->taskId),
2937 errdetail("query string: \"%s\"",
2938 ApplyLogRedaction(sqlQueryString->data))));
2939
2940 sqlTask->anchorShardId = INVALID_SHARD_ID;
2941 if (anchorRangeTableBasedAssignment)
2942 {
2943 sqlTask->anchorShardId = AnchorShardId(fragmentCombination,
2944 anchorRangeTableId);
2945 }
2946
2947 taskIdIndex++;
2948 sqlTaskList = lappend(sqlTaskList, sqlTask);
2949 }
2950
2951 return sqlTaskList;
2952 }
2953
2954
2955 /*
2956 * RelabelTypeToCollateExpr converts RelabelType's into CollationExpr's.
2957 * With that, we will be able to pushdown COLLATE's.
2958 */
2959 CollateExpr *
RelabelTypeToCollateExpr(RelabelType * relabelType)2960 RelabelTypeToCollateExpr(RelabelType *relabelType)
2961 {
2962 Assert(OidIsValid(relabelType->resultcollid));
2963
2964 CollateExpr *collateExpr = makeNode(CollateExpr);
2965 collateExpr->arg = relabelType->arg;
2966 collateExpr->collOid = relabelType->resultcollid;
2967 collateExpr->location = relabelType->location;
2968
2969 return collateExpr;
2970 }
2971
2972
2973 /*
2974 * DependsOnHashPartitionJob checks if the given job depends on a hash
2975 * partitioning job.
2976 */
2977 static bool
DependsOnHashPartitionJob(Job * job)2978 DependsOnHashPartitionJob(Job *job)
2979 {
2980 bool dependsOnHashPartitionJob = false;
2981 List *dependentJobList = job->dependentJobList;
2982
2983 uint32 dependentJobCount = (uint32) list_length(dependentJobList);
2984 if (dependentJobCount > 0)
2985 {
2986 Job *dependentJob = (Job *) linitial(dependentJobList);
2987 if (CitusIsA(dependentJob, MapMergeJob))
2988 {
2989 MapMergeJob *mapMergeJob = (MapMergeJob *) dependentJob;
2990 if (mapMergeJob->partitionType == DUAL_HASH_PARTITION_TYPE)
2991 {
2992 dependsOnHashPartitionJob = true;
2993 }
2994 }
2995 }
2996
2997 return dependsOnHashPartitionJob;
2998 }
2999
3000
3001 /*
3002 * AnchorRangeTableId determines the table around which we build our queries,
3003 * and returns this table's range table id. We refer to this table as the anchor
3004 * table, and make sure that the anchor table's shards are moved or cached only
3005 * when absolutely necessary.
3006 */
3007 static uint32
AnchorRangeTableId(List * rangeTableList)3008 AnchorRangeTableId(List *rangeTableList)
3009 {
3010 uint32 anchorRangeTableId = 0;
3011 uint64 maxTableSize = 0;
3012
3013 /*
3014 * We first filter anything but ordinary tables. Then, we pick the table(s)
3015 * with the most number of shards as our anchor table. If multiple tables
3016 * have the most number of shards, we have a draw.
3017 */
3018 List *baseTableIdList = BaseRangeTableIdList(rangeTableList);
3019 List *anchorTableIdList = AnchorRangeTableIdList(rangeTableList, baseTableIdList);
3020 ListCell *anchorTableIdCell = NULL;
3021
3022 int anchorTableIdCount = list_length(anchorTableIdList);
3023 Assert(anchorTableIdCount > 0);
3024
3025 if (anchorTableIdCount == 1)
3026 {
3027 anchorRangeTableId = (uint32) linitial_int(anchorTableIdList);
3028 return anchorRangeTableId;
3029 }
3030
3031 /*
3032 * If more than one table has the most number of shards, we break the draw
3033 * by comparing table sizes and picking the table with the largest size.
3034 */
3035 foreach(anchorTableIdCell, anchorTableIdList)
3036 {
3037 uint32 anchorTableId = (uint32) lfirst_int(anchorTableIdCell);
3038 RangeTblEntry *tableEntry = rt_fetch(anchorTableId, rangeTableList);
3039 uint64 tableSize = 0;
3040
3041 List *shardList = LoadShardList(tableEntry->relid);
3042 ListCell *shardCell = NULL;
3043
3044 foreach(shardCell, shardList)
3045 {
3046 uint64 *shardIdPointer = (uint64 *) lfirst(shardCell);
3047 uint64 shardId = (*shardIdPointer);
3048 uint64 shardSize = ShardLength(shardId);
3049
3050 tableSize += shardSize;
3051 }
3052
3053 if (tableSize > maxTableSize)
3054 {
3055 maxTableSize = tableSize;
3056 anchorRangeTableId = anchorTableId;
3057 }
3058 }
3059
3060 if (anchorRangeTableId == 0)
3061 {
3062 /* all tables have the same shard count and size 0, pick the first */
3063 anchorRangeTableId = (uint32) linitial_int(anchorTableIdList);
3064 }
3065
3066 return anchorRangeTableId;
3067 }
3068
3069
3070 /*
3071 * BaseRangeTableIdList walks over range tables in the given range table list,
3072 * finds range tables that correspond to base (non-repartitioned) tables, and
3073 * returns these range tables' identifiers in a new list.
3074 */
3075 static List *
BaseRangeTableIdList(List * rangeTableList)3076 BaseRangeTableIdList(List *rangeTableList)
3077 {
3078 List *baseRangeTableIdList = NIL;
3079 uint32 rangeTableId = 1;
3080
3081 ListCell *rangeTableCell = NULL;
3082 foreach(rangeTableCell, rangeTableList)
3083 {
3084 RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
3085 if (GetRangeTblKind(rangeTableEntry) == CITUS_RTE_RELATION)
3086 {
3087 baseRangeTableIdList = lappend_int(baseRangeTableIdList, rangeTableId);
3088 }
3089
3090 rangeTableId++;
3091 }
3092
3093 return baseRangeTableIdList;
3094 }
3095
3096
3097 /*
3098 * AnchorRangeTableIdList finds ordinary table(s) with the most number of shards
3099 * and returns the corresponding range table id(s) in a list.
3100 */
3101 static List *
AnchorRangeTableIdList(List * rangeTableList,List * baseRangeTableIdList)3102 AnchorRangeTableIdList(List *rangeTableList, List *baseRangeTableIdList)
3103 {
3104 List *anchorTableIdList = NIL;
3105 uint32 maxShardCount = 0;
3106 ListCell *baseRangeTableIdCell = NULL;
3107
3108 uint32 baseRangeTableCount = list_length(baseRangeTableIdList);
3109 if (baseRangeTableCount == 1)
3110 {
3111 return baseRangeTableIdList;
3112 }
3113
3114 foreach(baseRangeTableIdCell, baseRangeTableIdList)
3115 {
3116 uint32 baseRangeTableId = (uint32) lfirst_int(baseRangeTableIdCell);
3117 RangeTblEntry *tableEntry = rt_fetch(baseRangeTableId, rangeTableList);
3118 List *shardList = LoadShardList(tableEntry->relid);
3119
3120 uint32 shardCount = (uint32) list_length(shardList);
3121 if (shardCount > maxShardCount)
3122 {
3123 anchorTableIdList = list_make1_int(baseRangeTableId);
3124 maxShardCount = shardCount;
3125 }
3126 else if (shardCount == maxShardCount)
3127 {
3128 anchorTableIdList = lappend_int(anchorTableIdList, baseRangeTableId);
3129 }
3130 }
3131
3132 return anchorTableIdList;
3133 }
3134
3135
3136 /*
3137 * AdjustColumnOldAttributes adjust the old tableId (varnosyn) and old columnId
3138 * (varattnosyn), and sets them equal to the new values. We need this adjustment
3139 * for partition pruning where we compare these columns with partition columns
3140 * loaded from system catalogs. Since columns loaded from system catalogs always
3141 * have the same old and new values, we also need to adjust column values here.
3142 */
3143 static void
AdjustColumnOldAttributes(List * expressionList)3144 AdjustColumnOldAttributes(List *expressionList)
3145 {
3146 List *columnList = pull_var_clause_default((Node *) expressionList);
3147 ListCell *columnCell = NULL;
3148
3149 foreach(columnCell, columnList)
3150 {
3151 Var *column = (Var *) lfirst(columnCell);
3152 column->varnosyn = column->varno;
3153 column->varattnosyn = column->varattno;
3154 }
3155 }
3156
3157
3158 /*
3159 * RangeTableFragmentsList walks over range tables in the given range table list
3160 * and for each table, the function creates a list of its fragments. A fragment
3161 * in this list represents either a regular shard or a merge task. Once a list
3162 * for each range table is constructed, the function applies partition pruning
3163 * using the given where clause list. Then, the function appends the fragment
3164 * list for each range table to a list of lists, and returns this list of lists.
3165 */
3166 static List *
RangeTableFragmentsList(List * rangeTableList,List * whereClauseList,List * dependentJobList)3167 RangeTableFragmentsList(List *rangeTableList, List *whereClauseList,
3168 List *dependentJobList)
3169 {
3170 List *rangeTableFragmentsList = NIL;
3171 uint32 rangeTableIndex = 0;
3172 const uint32 fragmentSize = sizeof(RangeTableFragment);
3173
3174 ListCell *rangeTableCell = NULL;
3175 foreach(rangeTableCell, rangeTableList)
3176 {
3177 uint32 tableId = rangeTableIndex + 1; /* tableId starts from 1 */
3178
3179 RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);
3180 CitusRTEKind rangeTableKind = GetRangeTblKind(rangeTableEntry);
3181
3182 if (rangeTableKind == CITUS_RTE_RELATION)
3183 {
3184 Oid relationId = rangeTableEntry->relid;
3185 ListCell *shardIntervalCell = NULL;
3186 List *shardFragmentList = NIL;
3187 List *prunedShardIntervalList = PruneShards(relationId, tableId,
3188 whereClauseList, NULL);
3189
3190 /*
3191 * If we prune all shards for one table, query results will be empty.
3192 * We can therefore return NIL for the task list here.
3193 */
3194 if (prunedShardIntervalList == NIL)
3195 {
3196 return NIL;
3197 }
3198
3199 foreach(shardIntervalCell, prunedShardIntervalList)
3200 {
3201 ShardInterval *shardInterval =
3202 (ShardInterval *) lfirst(shardIntervalCell);
3203
3204 RangeTableFragment *shardFragment = palloc0(fragmentSize);
3205 shardFragment->fragmentReference = shardInterval;
3206 shardFragment->fragmentType = CITUS_RTE_RELATION;
3207 shardFragment->rangeTableId = tableId;
3208
3209 shardFragmentList = lappend(shardFragmentList, shardFragment);
3210 }
3211
3212 rangeTableFragmentsList = lappend(rangeTableFragmentsList,
3213 shardFragmentList);
3214 }
3215 else if (rangeTableKind == CITUS_RTE_REMOTE_QUERY)
3216 {
3217 List *mergeTaskFragmentList = NIL;
3218 ListCell *mergeTaskCell = NULL;
3219
3220 Job *dependentJob = JobForRangeTable(dependentJobList, rangeTableEntry);
3221 Assert(CitusIsA(dependentJob, MapMergeJob));
3222
3223 MapMergeJob *dependentMapMergeJob = (MapMergeJob *) dependentJob;
3224 List *mergeTaskList = dependentMapMergeJob->mergeTaskList;
3225
3226 /* if there are no tasks for the dependent job, just return NIL */
3227 if (mergeTaskList == NIL)
3228 {
3229 return NIL;
3230 }
3231
3232 foreach(mergeTaskCell, mergeTaskList)
3233 {
3234 Task *mergeTask = (Task *) lfirst(mergeTaskCell);
3235
3236 RangeTableFragment *mergeTaskFragment = palloc0(fragmentSize);
3237 mergeTaskFragment->fragmentReference = mergeTask;
3238 mergeTaskFragment->fragmentType = CITUS_RTE_REMOTE_QUERY;
3239 mergeTaskFragment->rangeTableId = tableId;
3240
3241 mergeTaskFragmentList = lappend(mergeTaskFragmentList, mergeTaskFragment);
3242 }
3243
3244 rangeTableFragmentsList = lappend(rangeTableFragmentsList,
3245 mergeTaskFragmentList);
3246 }
3247
3248 rangeTableIndex++;
3249 }
3250
3251 return rangeTableFragmentsList;
3252 }
3253
3254
3255 /*
3256 * BuildBaseConstraint builds and returns a base constraint. This constraint
3257 * implements an expression in the form of (column <= max && column >= min),
3258 * where column is the partition key, and min and max values represent a shard's
3259 * min and max values. These shard values are filled in after the constraint is
3260 * built.
3261 */
3262 Node *
BuildBaseConstraint(Var * column)3263 BuildBaseConstraint(Var *column)
3264 {
3265 /* Build these expressions with only one argument for now */
3266 OpExpr *lessThanExpr = MakeOpExpression(column, BTLessEqualStrategyNumber);
3267 OpExpr *greaterThanExpr = MakeOpExpression(column, BTGreaterEqualStrategyNumber);
3268
3269 /* Build base constaint as an and of two qual conditions */
3270 Node *baseConstraint = make_and_qual((Node *) lessThanExpr, (Node *) greaterThanExpr);
3271
3272 return baseConstraint;
3273 }
3274
3275
3276 /*
3277 * MakeOpExpression builds an operator expression node. This operator expression
3278 * implements the operator clause as defined by the variable and the strategy
3279 * number.
3280 */
3281 OpExpr *
MakeOpExpression(Var * variable,int16 strategyNumber)3282 MakeOpExpression(Var *variable, int16 strategyNumber)
3283 {
3284 Oid typeId = variable->vartype;
3285 Oid typeModId = variable->vartypmod;
3286 Oid collationId = variable->varcollid;
3287
3288 Oid accessMethodId = BTREE_AM_OID;
3289
3290 OperatorCacheEntry *operatorCacheEntry = LookupOperatorByType(typeId, accessMethodId,
3291 strategyNumber);
3292
3293 Oid operatorId = operatorCacheEntry->operatorId;
3294 Oid operatorClassInputType = operatorCacheEntry->operatorClassInputType;
3295 char typeType = operatorCacheEntry->typeType;
3296
3297 /*
3298 * Relabel variable if input type of default operator class is not equal to
3299 * the variable type. Note that we don't relabel the variable if the default
3300 * operator class variable type is a pseudo-type.
3301 */
3302 if (operatorClassInputType != typeId && typeType != TYPTYPE_PSEUDO)
3303 {
3304 variable = (Var *) makeRelabelType((Expr *) variable, operatorClassInputType,
3305 -1, collationId, COERCE_IMPLICIT_CAST);
3306 }
3307
3308 Const *constantValue = makeNullConst(operatorClassInputType, typeModId, collationId);
3309
3310 /* Now make the expression with the given variable and a null constant */
3311 OpExpr *expression = (OpExpr *) make_opclause(operatorId,
3312 InvalidOid, /* no result type yet */
3313 false, /* no return set */
3314 (Expr *) variable,
3315 (Expr *) constantValue,
3316 InvalidOid, collationId);
3317
3318 /* Set implementing function id and result type */
3319 expression->opfuncid = get_opcode(operatorId);
3320 expression->opresulttype = get_func_rettype(expression->opfuncid);
3321
3322 return expression;
3323 }
3324
3325
3326 /*
3327 * LookupOperatorByType is a wrapper around GetOperatorByType(),
3328 * operatorClassInputType() and get_typtype() functions that uses a cache to avoid
3329 * multiple lookups of operators and its related fields within a single session by
3330 * their types, access methods and strategy numbers.
3331 * LookupOperatorByType function errors out if it cannot find corresponding
3332 * default operator class with the given parameters on the system catalogs.
3333 */
3334 static OperatorCacheEntry *
LookupOperatorByType(Oid typeId,Oid accessMethodId,int16 strategyNumber)3335 LookupOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber)
3336 {
3337 OperatorCacheEntry *matchingCacheEntry = NULL;
3338 ListCell *cacheEntryCell = NULL;
3339
3340 /* search the cache */
3341 foreach(cacheEntryCell, OperatorCache)
3342 {
3343 OperatorCacheEntry *cacheEntry = lfirst(cacheEntryCell);
3344
3345 if ((cacheEntry->typeId == typeId) &&
3346 (cacheEntry->accessMethodId == accessMethodId) &&
3347 (cacheEntry->strategyNumber == strategyNumber))
3348 {
3349 matchingCacheEntry = cacheEntry;
3350 break;
3351 }
3352 }
3353
3354 /* if not found in the cache, call GetOperatorByType and put the result in cache */
3355 if (matchingCacheEntry == NULL)
3356 {
3357 Oid operatorClassId = GetDefaultOpClass(typeId, accessMethodId);
3358
3359 if (operatorClassId == InvalidOid)
3360 {
3361 /* if operatorId is invalid, error out */
3362 ereport(ERROR, (errmsg("cannot find default operator class for type:%d,"
3363 " access method: %d", typeId, accessMethodId)));
3364 }
3365
3366 /* fill the other fields to the cache */
3367 Oid operatorId = GetOperatorByType(typeId, accessMethodId, strategyNumber);
3368 Oid operatorClassInputType = get_opclass_input_type(operatorClassId);
3369 char typeType = get_typtype(operatorClassInputType);
3370
3371 /* make sure we've initialized CacheMemoryContext */
3372 if (CacheMemoryContext == NULL)
3373 {
3374 CreateCacheMemoryContext();
3375 }
3376
3377 MemoryContext oldContext = MemoryContextSwitchTo(CacheMemoryContext);
3378
3379 matchingCacheEntry = palloc0(sizeof(OperatorCacheEntry));
3380 matchingCacheEntry->typeId = typeId;
3381 matchingCacheEntry->accessMethodId = accessMethodId;
3382 matchingCacheEntry->strategyNumber = strategyNumber;
3383 matchingCacheEntry->operatorId = operatorId;
3384 matchingCacheEntry->operatorClassInputType = operatorClassInputType;
3385 matchingCacheEntry->typeType = typeType;
3386
3387 OperatorCache = lappend(OperatorCache, matchingCacheEntry);
3388
3389 MemoryContextSwitchTo(oldContext);
3390 }
3391
3392 return matchingCacheEntry;
3393 }
3394
3395
3396 /*
3397 * GetOperatorByType returns the operator oid for the given type, access method,
3398 * and strategy number.
3399 */
3400 static Oid
GetOperatorByType(Oid typeId,Oid accessMethodId,int16 strategyNumber)3401 GetOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber)
3402 {
3403 /* Get default operator class from pg_opclass */
3404 Oid operatorClassId = GetDefaultOpClass(typeId, accessMethodId);
3405
3406 Oid operatorFamily = get_opclass_family(operatorClassId);
3407 Oid operatorClassInputType = get_opclass_input_type(operatorClassId);
3408
3409 /* Lookup for the operator with the desired input type in the family */
3410 Oid operatorId = get_opfamily_member(operatorFamily, operatorClassInputType,
3411 operatorClassInputType, strategyNumber);
3412 return operatorId;
3413 }
3414
3415
3416 /*
3417 * BinaryOpExpression checks that a given expression is a binary operator. If
3418 * this is the case it returns true and sets leftOperand and rightOperand to
3419 * the left and right hand side of the operator. left/rightOperand will be
3420 * stripped of implicit coercions by strip_implicit_coercions.
3421 */
3422 bool
BinaryOpExpression(Expr * clause,Node ** leftOperand,Node ** rightOperand)3423 BinaryOpExpression(Expr *clause, Node **leftOperand, Node **rightOperand)
3424 {
3425 if (!is_opclause(clause) || list_length(((OpExpr *) clause)->args) != 2)
3426 {
3427 if (leftOperand != NULL)
3428 {
3429 *leftOperand = NULL;
3430 }
3431 if (rightOperand != NULL)
3432 {
3433 *leftOperand = NULL;
3434 }
3435 return false;
3436 }
3437 if (leftOperand != NULL)
3438 {
3439 *leftOperand = get_leftop(clause);
3440 Assert(*leftOperand != NULL);
3441 *leftOperand = strip_implicit_coercions(*leftOperand);
3442 }
3443 if (rightOperand != NULL)
3444 {
3445 *rightOperand = get_rightop(clause);
3446 Assert(*rightOperand != NULL);
3447 *rightOperand = strip_implicit_coercions(*rightOperand);
3448 }
3449 return true;
3450 }
3451
3452
3453 /*
3454 * SimpleOpExpression checks that given expression is a simple operator
3455 * expression. A simple operator expression is a binary operator expression with
3456 * operands of a var and a non-null constant.
3457 */
3458 bool
SimpleOpExpression(Expr * clause)3459 SimpleOpExpression(Expr *clause)
3460 {
3461 Const *constantClause = NULL;
3462
3463 Node *leftOperand;
3464 Node *rightOperand;
3465 if (!BinaryOpExpression(clause, &leftOperand, &rightOperand))
3466 {
3467 return false;
3468 }
3469
3470 if (IsA(rightOperand, Const) && IsA(leftOperand, Var))
3471 {
3472 constantClause = (Const *) rightOperand;
3473 }
3474 else if (IsA(leftOperand, Const) && IsA(rightOperand, Var))
3475 {
3476 constantClause = (Const *) leftOperand;
3477 }
3478 else
3479 {
3480 return false;
3481 }
3482
3483 if (constantClause->constisnull)
3484 {
3485 return false;
3486 }
3487
3488 return true;
3489 }
3490
3491
3492 /*
3493 * MakeInt4Column creates a column of int4 type with invalid table id and max
3494 * attribute number.
3495 */
3496 Var *
MakeInt4Column()3497 MakeInt4Column()
3498 {
3499 Index tableId = 0;
3500 AttrNumber columnAttributeNumber = RESERVED_HASHED_COLUMN_ID;
3501 Oid columnType = INT4OID;
3502 int32 columnTypeMod = -1;
3503 Oid columnCollationOid = InvalidOid;
3504 Index columnLevelSup = 0;
3505
3506 Var *int4Column = makeVar(tableId, columnAttributeNumber, columnType,
3507 columnTypeMod, columnCollationOid, columnLevelSup);
3508 return int4Column;
3509 }
3510
3511
3512 /* Updates the base constraint with the given min/max values. */
3513 void
UpdateConstraint(Node * baseConstraint,ShardInterval * shardInterval)3514 UpdateConstraint(Node *baseConstraint, ShardInterval *shardInterval)
3515 {
3516 BoolExpr *andExpr = (BoolExpr *) baseConstraint;
3517 Node *lessThanExpr = (Node *) linitial(andExpr->args);
3518 Node *greaterThanExpr = (Node *) lsecond(andExpr->args);
3519
3520 Node *minNode = get_rightop((Expr *) greaterThanExpr); /* right op */
3521 Node *maxNode = get_rightop((Expr *) lessThanExpr); /* right op */
3522
3523 Assert(shardInterval != NULL);
3524 Assert(shardInterval->minValueExists);
3525 Assert(shardInterval->maxValueExists);
3526 Assert(minNode != NULL);
3527 Assert(maxNode != NULL);
3528 Assert(IsA(minNode, Const));
3529 Assert(IsA(maxNode, Const));
3530
3531 Const *minConstant = (Const *) minNode;
3532 Const *maxConstant = (Const *) maxNode;
3533
3534 minConstant->constvalue = datumCopy(shardInterval->minValue,
3535 shardInterval->valueByVal,
3536 shardInterval->valueTypeLen);
3537 maxConstant->constvalue = datumCopy(shardInterval->maxValue,
3538 shardInterval->valueByVal,
3539 shardInterval->valueTypeLen);
3540
3541 minConstant->constisnull = false;
3542 maxConstant->constisnull = false;
3543 }
3544
3545
3546 /*
3547 * FragmentCombinationList first builds an ordered sequence of range tables that
3548 * join together. The function then iteratively adds fragments from each joined
3549 * range table, and forms fragment combinations (lists) that cover all tables.
3550 * While doing so, the function also performs join pruning to remove unnecessary
3551 * fragment pairs. Last, the function adds each fragment combination (list) to a
3552 * list, and returns this list.
3553 */
3554 static List *
FragmentCombinationList(List * rangeTableFragmentsList,Query * jobQuery,List * dependentJobList)3555 FragmentCombinationList(List *rangeTableFragmentsList, Query *jobQuery,
3556 List *dependentJobList)
3557 {
3558 List *fragmentCombinationList = NIL;
3559 List *fragmentCombinationQueue = NIL;
3560 List *emptyList = NIL;
3561
3562 /* find a sequence that joins the range tables in the list */
3563 JoinSequenceNode *joinSequenceArray = JoinSequenceArray(rangeTableFragmentsList,
3564 jobQuery,
3565 dependentJobList);
3566
3567 /*
3568 * We use breadth-first search with pruning to create fragment combinations.
3569 * For this, we first queue the root node (an empty combination), and then
3570 * start traversing our search space.
3571 */
3572 fragmentCombinationQueue = lappend(fragmentCombinationQueue, emptyList);
3573 while (fragmentCombinationQueue != NIL)
3574 {
3575 ListCell *tableFragmentCell = NULL;
3576 int32 joiningTableSequenceIndex = -1;
3577
3578 /* pop first element from the fragment queue */
3579 List *fragmentCombination = linitial(fragmentCombinationQueue);
3580 fragmentCombinationQueue = list_delete_first(fragmentCombinationQueue);
3581
3582 /*
3583 * If this combination covered all range tables in a join sequence, add
3584 * this combination to our result set.
3585 */
3586 int32 joinSequenceIndex = list_length(fragmentCombination);
3587 int32 rangeTableCount = list_length(rangeTableFragmentsList);
3588 if (joinSequenceIndex == rangeTableCount)
3589 {
3590 fragmentCombinationList = lappend(fragmentCombinationList,
3591 fragmentCombination);
3592 continue;
3593 }
3594
3595 /* find the next range table to add to our search space */
3596 uint32 tableId = joinSequenceArray[joinSequenceIndex].rangeTableId;
3597 List *tableFragments = FindRangeTableFragmentsList(rangeTableFragmentsList,
3598 tableId);
3599
3600 /* resolve sequence index for the previous range table we join against */
3601 int32 joiningTableId = joinSequenceArray[joinSequenceIndex].joiningRangeTableId;
3602 if (joiningTableId != NON_PRUNABLE_JOIN)
3603 {
3604 for (int32 sequenceIndex = 0; sequenceIndex < rangeTableCount;
3605 sequenceIndex++)
3606 {
3607 JoinSequenceNode *joinSequenceNode = &joinSequenceArray[sequenceIndex];
3608 if (joinSequenceNode->rangeTableId == joiningTableId)
3609 {
3610 joiningTableSequenceIndex = sequenceIndex;
3611 break;
3612 }
3613 }
3614
3615 Assert(joiningTableSequenceIndex != -1);
3616 }
3617
3618 /*
3619 * We walk over each range table fragment, and check if we can prune out
3620 * this fragment joining with the existing fragment combination. If we
3621 * can't prune away, we create a new fragment combination and add it to
3622 * our search space.
3623 */
3624 foreach(tableFragmentCell, tableFragments)
3625 {
3626 RangeTableFragment *tableFragment = lfirst(tableFragmentCell);
3627 bool joinPrunable = false;
3628
3629 if (joiningTableId != NON_PRUNABLE_JOIN)
3630 {
3631 RangeTableFragment *joiningTableFragment =
3632 list_nth(fragmentCombination, joiningTableSequenceIndex);
3633
3634 joinPrunable = JoinPrunable(joiningTableFragment, tableFragment);
3635 }
3636
3637 /* if join can't be pruned, extend fragment combination and search */
3638 if (!joinPrunable)
3639 {
3640 List *newFragmentCombination = list_copy(fragmentCombination);
3641 newFragmentCombination = lappend(newFragmentCombination, tableFragment);
3642
3643 fragmentCombinationQueue = lappend(fragmentCombinationQueue,
3644 newFragmentCombination);
3645 }
3646 }
3647 }
3648
3649 return fragmentCombinationList;
3650 }
3651
3652
3653 /*
3654 * NodeIsRangeTblRefReferenceTable checks if the node is a RangeTblRef that
3655 * points to a reference table in the rangeTableList.
3656 */
3657 static bool
NodeIsRangeTblRefReferenceTable(Node * node,List * rangeTableList)3658 NodeIsRangeTblRefReferenceTable(Node *node, List *rangeTableList)
3659 {
3660 if (!IsA(node, RangeTblRef))
3661 {
3662 return false;
3663 }
3664 RangeTblRef *tableRef = castNode(RangeTblRef, node);
3665 RangeTblEntry *rangeTableEntry = rt_fetch(tableRef->rtindex, rangeTableList);
3666 CitusRTEKind rangeTableType = GetRangeTblKind(rangeTableEntry);
3667 if (rangeTableType != CITUS_RTE_RELATION)
3668 {
3669 return false;
3670 }
3671 return IsCitusTableType(rangeTableEntry->relid, REFERENCE_TABLE);
3672 }
3673
3674
3675 /*
3676 * FetchEqualityAttrNumsForRTE fetches the attribute numbers from quals
3677 * which have an equality operator
3678 */
3679 List *
FetchEqualityAttrNumsForRTE(Node * node)3680 FetchEqualityAttrNumsForRTE(Node *node)
3681 {
3682 if (node == NULL)
3683 {
3684 return NIL;
3685 }
3686 if (IsA(node, List))
3687 {
3688 return FetchEqualityAttrNumsForList((List *) node);
3689 }
3690 else if (IsA(node, OpExpr))
3691 {
3692 return FetchEqualityAttrNumsForRTEOpExpr((OpExpr *) node);
3693 }
3694 else if (IsA(node, BoolExpr))
3695 {
3696 return FetchEqualityAttrNumsForRTEBoolExpr((BoolExpr *) node);
3697 }
3698 return NIL;
3699 }
3700
3701
3702 /*
3703 * FetchEqualityAttrNumsForList fetches the attribute numbers of expression
3704 * of the form "= constant" from the given node list.
3705 */
3706 static List *
FetchEqualityAttrNumsForList(List * nodeList)3707 FetchEqualityAttrNumsForList(List *nodeList)
3708 {
3709 List *attributeNums = NIL;
3710 Node *node = NULL;
3711 bool hasAtLeastOneEquality = false;
3712 foreach_ptr(node, nodeList)
3713 {
3714 List *fetchedEqualityAttrNums =
3715 FetchEqualityAttrNumsForRTE(node);
3716 hasAtLeastOneEquality |= list_length(fetchedEqualityAttrNums) > 0;
3717 attributeNums = list_concat(attributeNums, fetchedEqualityAttrNums);
3718 }
3719
3720 /*
3721 * the given list is in the form of AND'ed expressions
3722 * hence if we have one equality then it is enough.
3723 * E.g: dist.a = 5 AND dist.a > 10
3724 */
3725 if (hasAtLeastOneEquality)
3726 {
3727 return attributeNums;
3728 }
3729 return NIL;
3730 }
3731
3732
3733 /*
3734 * FetchEqualityAttrNumsForRTEOpExpr fetches the attribute numbers of expression
3735 * of the form "= constant" from the given opExpr.
3736 */
3737 static List *
FetchEqualityAttrNumsForRTEOpExpr(OpExpr * opExpr)3738 FetchEqualityAttrNumsForRTEOpExpr(OpExpr *opExpr)
3739 {
3740 if (!OperatorImplementsEquality(opExpr->opno))
3741 {
3742 return NIL;
3743 }
3744
3745 List *attributeNums = NIL;
3746 Var *var = NULL;
3747 if (VarConstOpExprClause(opExpr, &var, NULL))
3748 {
3749 attributeNums = lappend_int(attributeNums, var->varattno);
3750 }
3751 return attributeNums;
3752 }
3753
3754
3755 /*
3756 * FetchEqualityAttrNumsForRTEBoolExpr fetches the attribute numbers of expression
3757 * of the form "= constant" from the given boolExpr.
3758 */
3759 static List *
FetchEqualityAttrNumsForRTEBoolExpr(BoolExpr * boolExpr)3760 FetchEqualityAttrNumsForRTEBoolExpr(BoolExpr *boolExpr)
3761 {
3762 if (boolExpr->boolop != AND_EXPR && boolExpr->boolop != OR_EXPR)
3763 {
3764 return NIL;
3765 }
3766
3767 List *attributeNums = NIL;
3768 bool hasEquality = true;
3769 Node *arg = NULL;
3770 foreach_ptr(arg, boolExpr->args)
3771 {
3772 List *attributeNumsInSubExpression = FetchEqualityAttrNumsForRTE(arg);
3773 if (boolExpr->boolop == AND_EXPR)
3774 {
3775 hasEquality |= list_length(attributeNumsInSubExpression) > 0;
3776 }
3777 else if (boolExpr->boolop == OR_EXPR)
3778 {
3779 hasEquality &= list_length(attributeNumsInSubExpression) > 0;
3780 }
3781 attributeNums = list_concat(attributeNums, attributeNumsInSubExpression);
3782 }
3783 if (hasEquality)
3784 {
3785 return attributeNums;
3786 }
3787 return NIL;
3788 }
3789
3790
3791 /*
3792 * JoinSequenceArray walks over the join nodes in the job query and constructs a join
3793 * sequence containing an entry for each joined table. The function then returns an
3794 * array of join sequence nodes, in which each node contains the id of a table in the
3795 * range table list and the id of a preceding table with which it is joined, if any.
3796 */
3797 static JoinSequenceNode *
JoinSequenceArray(List * rangeTableFragmentsList,Query * jobQuery,List * dependentJobList)3798 JoinSequenceArray(List *rangeTableFragmentsList, Query *jobQuery, List *dependentJobList)
3799 {
3800 List *rangeTableList = jobQuery->rtable;
3801 uint32 rangeTableCount = (uint32) list_length(rangeTableList);
3802 uint32 sequenceNodeSize = sizeof(JoinSequenceNode);
3803 uint32 joinedTableCount = 0;
3804 ListCell *joinExprCell = NULL;
3805 uint32 firstRangeTableId = 1;
3806 JoinSequenceNode *joinSequenceArray = palloc0(rangeTableCount * sequenceNodeSize);
3807
3808 List *joinExprList = JoinExprList(jobQuery->jointree);
3809
3810 /* pick first range table as starting table for the join sequence */
3811 if (list_length(joinExprList) > 0)
3812 {
3813 JoinExpr *firstExpr = (JoinExpr *) linitial(joinExprList);
3814 RangeTblRef *leftTableRef = (RangeTblRef *) firstExpr->larg;
3815 firstRangeTableId = leftTableRef->rtindex;
3816 }
3817 else
3818 {
3819 /* when there are no joins, the join sequence contains a node for the table */
3820 firstRangeTableId = 1;
3821 }
3822
3823 joinSequenceArray[joinedTableCount].rangeTableId = firstRangeTableId;
3824 joinSequenceArray[joinedTableCount].joiningRangeTableId = NON_PRUNABLE_JOIN;
3825 joinedTableCount++;
3826
3827 foreach(joinExprCell, joinExprList)
3828 {
3829 JoinExpr *joinExpr = (JoinExpr *) lfirst(joinExprCell);
3830 RangeTblRef *rightTableRef = castNode(RangeTblRef, joinExpr->rarg);
3831 uint32 nextRangeTableId = rightTableRef->rtindex;
3832 Index existingRangeTableId = 0;
3833 bool applyJoinPruning = false;
3834
3835 List *nextJoinClauseList = make_ands_implicit((Expr *) joinExpr->quals);
3836 bool leftIsReferenceTable = NodeIsRangeTblRefReferenceTable(joinExpr->larg,
3837 rangeTableList);
3838 bool rightIsReferenceTable = NodeIsRangeTblRefReferenceTable(joinExpr->rarg,
3839 rangeTableList);
3840 bool isReferenceJoin = IsSupportedReferenceJoin(joinExpr->jointype,
3841 leftIsReferenceTable,
3842 rightIsReferenceTable);
3843
3844 /*
3845 * If next join clause list is empty, the user tried a cartesian product
3846 * between tables. We don't support this functionality for non
3847 * reference joins, and error out.
3848 */
3849 if (nextJoinClauseList == NIL && !isReferenceJoin)
3850 {
3851 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3852 errmsg("cannot perform distributed planning on this query"),
3853 errdetail("Cartesian products are currently unsupported")));
3854 }
3855
3856 /*
3857 * We now determine if we can apply join pruning between existing range
3858 * tables and this new one.
3859 */
3860 Node *nextJoinClause = NULL;
3861 foreach_ptr(nextJoinClause, nextJoinClauseList)
3862 {
3863 if (!NodeIsEqualsOpExpr(nextJoinClause))
3864 {
3865 continue;
3866 }
3867
3868 OpExpr *nextJoinClauseOpExpr = castNode(OpExpr, nextJoinClause);
3869
3870 if (!IsJoinClause((Node *) nextJoinClauseOpExpr))
3871 {
3872 continue;
3873 }
3874
3875 Var *leftColumn = LeftColumnOrNULL(nextJoinClauseOpExpr);
3876 Var *rightColumn = RightColumnOrNULL(nextJoinClauseOpExpr);
3877 if (leftColumn == NULL || rightColumn == NULL)
3878 {
3879 continue;
3880 }
3881
3882 Index leftRangeTableId = leftColumn->varno;
3883 Index rightRangeTableId = rightColumn->varno;
3884
3885 /*
3886 * We have a table from the existing join list joining with the next
3887 * table. First resolve the existing table's range table id.
3888 */
3889 if (leftRangeTableId == nextRangeTableId)
3890 {
3891 existingRangeTableId = rightRangeTableId;
3892 }
3893 else
3894 {
3895 existingRangeTableId = leftRangeTableId;
3896 }
3897
3898 /*
3899 * Then, we check if we can apply join pruning between the existing
3900 * range table and this new one. For this, columns need to have the
3901 * same type and be the partition column for their respective tables.
3902 */
3903 if (leftColumn->vartype != rightColumn->vartype)
3904 {
3905 continue;
3906 }
3907
3908 bool leftPartitioned = PartitionedOnColumn(leftColumn, rangeTableList,
3909 dependentJobList);
3910 bool rightPartitioned = PartitionedOnColumn(rightColumn, rangeTableList,
3911 dependentJobList);
3912 if (leftPartitioned && rightPartitioned)
3913 {
3914 /* make sure this join clause references only simple columns */
3915 CheckJoinBetweenColumns(nextJoinClauseOpExpr);
3916
3917 applyJoinPruning = true;
3918 break;
3919 }
3920 }
3921
3922 /* set next joining range table's info in the join sequence */
3923 JoinSequenceNode *nextJoinSequenceNode = &joinSequenceArray[joinedTableCount];
3924 if (applyJoinPruning)
3925 {
3926 nextJoinSequenceNode->rangeTableId = nextRangeTableId;
3927 nextJoinSequenceNode->joiningRangeTableId = (int32) existingRangeTableId;
3928 }
3929 else
3930 {
3931 nextJoinSequenceNode->rangeTableId = nextRangeTableId;
3932 nextJoinSequenceNode->joiningRangeTableId = NON_PRUNABLE_JOIN;
3933 }
3934
3935 joinedTableCount++;
3936 }
3937
3938 return joinSequenceArray;
3939 }
3940
3941
3942 /*
3943 * PartitionedOnColumn finds the given column's range table entry, and checks if
3944 * that range table is partitioned on the given column. Note that since reference
3945 * tables do not have partition columns, the function returns false when the distributed
3946 * relation is a reference table.
3947 */
3948 static bool
PartitionedOnColumn(Var * column,List * rangeTableList,List * dependentJobList)3949 PartitionedOnColumn(Var *column, List *rangeTableList, List *dependentJobList)
3950 {
3951 bool partitionedOnColumn = false;
3952 Index rangeTableId = column->varno;
3953 RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList);
3954
3955 CitusRTEKind rangeTableType = GetRangeTblKind(rangeTableEntry);
3956 if (rangeTableType == CITUS_RTE_RELATION)
3957 {
3958 Oid relationId = rangeTableEntry->relid;
3959 Var *partitionColumn = PartitionColumn(relationId, rangeTableId);
3960
3961 /* non-distributed tables do not have partition columns */
3962 if (IsCitusTableType(relationId, CITUS_TABLE_WITH_NO_DIST_KEY))
3963 {
3964 return false;
3965 }
3966
3967 if (partitionColumn->varattno == column->varattno)
3968 {
3969 partitionedOnColumn = true;
3970 }
3971 }
3972 else if (rangeTableType == CITUS_RTE_REMOTE_QUERY)
3973 {
3974 Job *job = JobForRangeTable(dependentJobList, rangeTableEntry);
3975 MapMergeJob *mapMergeJob = (MapMergeJob *) job;
3976
3977 /*
3978 * The column's current attribute number is it's location in the target
3979 * list for the table represented by the remote query. We retrieve this
3980 * value from the target list to compare against the partition column
3981 * as stored in the job.
3982 */
3983 List *targetEntryList = job->jobQuery->targetList;
3984 int32 columnIndex = column->varattno - 1;
3985 Assert(columnIndex >= 0);
3986 Assert(columnIndex < list_length(targetEntryList));
3987
3988 TargetEntry *targetEntry = (TargetEntry *) list_nth(targetEntryList, columnIndex);
3989 Var *remoteRelationColumn = (Var *) targetEntry->expr;
3990 Assert(IsA(remoteRelationColumn, Var));
3991
3992 /* retrieve the partition column for the job */
3993 Var *partitionColumn = mapMergeJob->partitionColumn;
3994 if (partitionColumn->varattno == remoteRelationColumn->varattno)
3995 {
3996 partitionedOnColumn = true;
3997 }
3998 }
3999
4000 return partitionedOnColumn;
4001 }
4002
4003
4004 /* Checks that the join clause references only simple columns. */
4005 static void
CheckJoinBetweenColumns(OpExpr * joinClause)4006 CheckJoinBetweenColumns(OpExpr *joinClause)
4007 {
4008 List *argumentList = joinClause->args;
4009 Node *leftArgument = (Node *) linitial(argumentList);
4010 Node *rightArgument = (Node *) lsecond(argumentList);
4011 Node *strippedLeftArgument = strip_implicit_coercions(leftArgument);
4012 Node *strippedRightArgument = strip_implicit_coercions(rightArgument);
4013
4014 NodeTag leftArgumentType = nodeTag(strippedLeftArgument);
4015 NodeTag rightArgumentType = nodeTag(strippedRightArgument);
4016
4017 if (leftArgumentType != T_Var || rightArgumentType != T_Var)
4018 {
4019 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4020 errmsg("cannot perform local joins that involve expressions"),
4021 errdetail("local joins can be performed between columns only")));
4022 }
4023 }
4024
4025
4026 /*
4027 * FindRangeTableFragmentsList walks over the given list of range table fragments
4028 * and, returns the one with the given table id.
4029 */
4030 static List *
FindRangeTableFragmentsList(List * rangeTableFragmentsList,int tableId)4031 FindRangeTableFragmentsList(List *rangeTableFragmentsList, int tableId)
4032 {
4033 List *foundTableFragments = NIL;
4034 ListCell *rangeTableFragmentsCell = NULL;
4035
4036 foreach(rangeTableFragmentsCell, rangeTableFragmentsList)
4037 {
4038 List *tableFragments = (List *) lfirst(rangeTableFragmentsCell);
4039 if (tableFragments != NIL)
4040 {
4041 RangeTableFragment *tableFragment =
4042 (RangeTableFragment *) linitial(tableFragments);
4043 if (tableFragment->rangeTableId == tableId)
4044 {
4045 foundTableFragments = tableFragments;
4046 break;
4047 }
4048 }
4049 }
4050
4051 return foundTableFragments;
4052 }
4053
4054
4055 /*
4056 * JoinPrunable checks if a join between the given left and right fragments can
4057 * be pruned away, without performing the actual join. To do this, the function
4058 * checks if we have a hash repartition join. If we do, the function determines
4059 * pruning based on partitionIds. Else if we have a merge repartition join, the
4060 * function checks if the two fragments have disjoint intervals.
4061 */
4062 static bool
JoinPrunable(RangeTableFragment * leftFragment,RangeTableFragment * rightFragment)4063 JoinPrunable(RangeTableFragment *leftFragment, RangeTableFragment *rightFragment)
4064 {
4065 /*
4066 * If both range tables are remote queries, we then have a hash repartition
4067 * join. In that case, we can just prune away this join if left and right
4068 * hand side fragments have the same partitionId.
4069 */
4070 if (leftFragment->fragmentType == CITUS_RTE_REMOTE_QUERY &&
4071 rightFragment->fragmentType == CITUS_RTE_REMOTE_QUERY)
4072 {
4073 Task *leftMergeTask = (Task *) leftFragment->fragmentReference;
4074 Task *rightMergeTask = (Task *) rightFragment->fragmentReference;
4075
4076
4077 if (leftMergeTask->partitionId != rightMergeTask->partitionId)
4078 {
4079 ereport(DEBUG2, (errmsg("join prunable for task partitionId %u and %u",
4080 leftMergeTask->partitionId,
4081 rightMergeTask->partitionId)));
4082 return true;
4083 }
4084 else
4085 {
4086 return false;
4087 }
4088 }
4089
4090
4091 /*
4092 * We have a single (re)partition join. We now get shard intervals for both
4093 * fragments, and then check if these intervals overlap.
4094 */
4095 ShardInterval *leftFragmentInterval = FragmentInterval(leftFragment);
4096 ShardInterval *rightFragmentInterval = FragmentInterval(rightFragment);
4097
4098 bool overlap = ShardIntervalsOverlap(leftFragmentInterval, rightFragmentInterval);
4099 if (!overlap)
4100 {
4101 if (IsLoggableLevel(DEBUG2))
4102 {
4103 StringInfo leftString = FragmentIntervalString(leftFragmentInterval);
4104 StringInfo rightString = FragmentIntervalString(rightFragmentInterval);
4105
4106 ereport(DEBUG2, (errmsg("join prunable for intervals %s and %s",
4107 leftString->data, rightString->data)));
4108 }
4109
4110 return true;
4111 }
4112
4113 return false;
4114 }
4115
4116
4117 /*
4118 * FragmentInterval takes the given fragment, and determines the range of data
4119 * covered by this fragment. The function then returns this range (interval).
4120 */
4121 static ShardInterval *
FragmentInterval(RangeTableFragment * fragment)4122 FragmentInterval(RangeTableFragment *fragment)
4123 {
4124 ShardInterval *fragmentInterval = NULL;
4125 if (fragment->fragmentType == CITUS_RTE_RELATION)
4126 {
4127 Assert(CitusIsA(fragment->fragmentReference, ShardInterval));
4128 fragmentInterval = (ShardInterval *) fragment->fragmentReference;
4129 }
4130 else if (fragment->fragmentType == CITUS_RTE_REMOTE_QUERY)
4131 {
4132 Assert(CitusIsA(fragment->fragmentReference, Task));
4133
4134 Task *mergeTask = (Task *) fragment->fragmentReference;
4135 fragmentInterval = mergeTask->shardInterval;
4136 }
4137
4138 return fragmentInterval;
4139 }
4140
4141
4142 /* Checks if the given shard intervals have overlapping ranges. */
4143 bool
ShardIntervalsOverlap(ShardInterval * firstInterval,ShardInterval * secondInterval)4144 ShardIntervalsOverlap(ShardInterval *firstInterval, ShardInterval *secondInterval)
4145 {
4146 CitusTableCacheEntry *intervalRelation =
4147 GetCitusTableCacheEntry(firstInterval->relationId);
4148
4149 Assert(IsCitusTableTypeCacheEntry(intervalRelation, DISTRIBUTED_TABLE));
4150
4151 if (!(firstInterval->minValueExists && firstInterval->maxValueExists &&
4152 secondInterval->minValueExists && secondInterval->maxValueExists))
4153 {
4154 return true;
4155 }
4156
4157 Datum firstMin = firstInterval->minValue;
4158 Datum firstMax = firstInterval->maxValue;
4159 Datum secondMin = secondInterval->minValue;
4160 Datum secondMax = secondInterval->maxValue;
4161
4162 FmgrInfo *comparisonFunction = intervalRelation->shardIntervalCompareFunction;
4163 Oid collation = intervalRelation->partitionColumn->varcollid;
4164
4165 return ShardIntervalsOverlapWithParams(firstMin, firstMax, secondMin, secondMax,
4166 comparisonFunction, collation);
4167 }
4168
4169
4170 /*
4171 * ShardIntervalsOverlapWithParams is a helper function which compares the input
4172 * shard min/max values, and returns true if the shards overlap.
4173 * The caller is responsible to ensure the input shard min/max values are not NULL.
4174 */
4175 bool
ShardIntervalsOverlapWithParams(Datum firstMin,Datum firstMax,Datum secondMin,Datum secondMax,FmgrInfo * comparisonFunction,Oid collation)4176 ShardIntervalsOverlapWithParams(Datum firstMin, Datum firstMax, Datum secondMin,
4177 Datum secondMax, FmgrInfo *comparisonFunction,
4178 Oid collation)
4179 {
4180 /*
4181 * We need to have min/max values for both intervals first. Then, we assume
4182 * two intervals i1 = [min1, max1] and i2 = [min2, max2] do not overlap if
4183 * (max1 < min2) or (max2 < min1). For details, please see the explanation
4184 * on overlapping intervals at http://www.rgrjr.com/emacs/overlap.html.
4185 */
4186 Datum firstDatum = FunctionCall2Coll(comparisonFunction, collation, firstMax,
4187 secondMin);
4188 Datum secondDatum = FunctionCall2Coll(comparisonFunction, collation, secondMax,
4189 firstMin);
4190 int firstComparison = DatumGetInt32(firstDatum);
4191 int secondComparison = DatumGetInt32(secondDatum);
4192
4193 if (firstComparison < 0 || secondComparison < 0)
4194 {
4195 return false;
4196 }
4197
4198 return true;
4199 }
4200
4201
4202 /*
4203 * FragmentIntervalString takes the given fragment interval, and converts this
4204 * interval into its string representation for use in debug messages.
4205 */
4206 static StringInfo
FragmentIntervalString(ShardInterval * fragmentInterval)4207 FragmentIntervalString(ShardInterval *fragmentInterval)
4208 {
4209 Oid typeId = fragmentInterval->valueTypeId;
4210 Oid outputFunctionId = InvalidOid;
4211 bool typeVariableLength = false;
4212
4213 Assert(fragmentInterval->minValueExists);
4214 Assert(fragmentInterval->maxValueExists);
4215
4216 FmgrInfo *outputFunction = (FmgrInfo *) palloc0(sizeof(FmgrInfo));
4217 getTypeOutputInfo(typeId, &outputFunctionId, &typeVariableLength);
4218 fmgr_info(outputFunctionId, outputFunction);
4219
4220 char *minValueString = OutputFunctionCall(outputFunction, fragmentInterval->minValue);
4221 char *maxValueString = OutputFunctionCall(outputFunction, fragmentInterval->maxValue);
4222
4223 StringInfo fragmentIntervalString = makeStringInfo();
4224 appendStringInfo(fragmentIntervalString, "[%s,%s]", minValueString, maxValueString);
4225
4226 return fragmentIntervalString;
4227 }
4228
4229
4230 /*
4231 * DataFetchTaskList builds a merge fetch task for every remote query result
4232 * in the given fragment list, appends these merge fetch tasks into a list,
4233 * and returns this list.
4234 */
4235 static List *
DataFetchTaskList(uint64 jobId,uint32 taskIdIndex,List * fragmentList)4236 DataFetchTaskList(uint64 jobId, uint32 taskIdIndex, List *fragmentList)
4237 {
4238 List *dataFetchTaskList = NIL;
4239 ListCell *fragmentCell = NULL;
4240
4241 foreach(fragmentCell, fragmentList)
4242 {
4243 RangeTableFragment *fragment = (RangeTableFragment *) lfirst(fragmentCell);
4244 if (fragment->fragmentType == CITUS_RTE_REMOTE_QUERY)
4245 {
4246 Task *mergeTask = (Task *) fragment->fragmentReference;
4247 char *undefinedQueryString = NULL;
4248
4249 /* create merge fetch task and have it depend on the merge task */
4250 Task *mergeFetchTask = CreateBasicTask(jobId, taskIdIndex, MERGE_FETCH_TASK,
4251 undefinedQueryString);
4252 mergeFetchTask->dependentTaskList = list_make1(mergeTask);
4253
4254 dataFetchTaskList = lappend(dataFetchTaskList, mergeFetchTask);
4255 taskIdIndex++;
4256 }
4257 }
4258
4259 return dataFetchTaskList;
4260 }
4261
4262
4263 /* Helper function to return a datum array's external string representation. */
4264 static StringInfo
DatumArrayString(Datum * datumArray,uint32 datumCount,Oid datumTypeId)4265 DatumArrayString(Datum *datumArray, uint32 datumCount, Oid datumTypeId)
4266 {
4267 int16 typeLength = 0;
4268 bool typeByValue = false;
4269 char typeAlignment = 0;
4270
4271 /* construct the array object from the given array */
4272 get_typlenbyvalalign(datumTypeId, &typeLength, &typeByValue, &typeAlignment);
4273 ArrayType *arrayObject = construct_array(datumArray, datumCount, datumTypeId,
4274 typeLength, typeByValue, typeAlignment);
4275 Datum arrayObjectDatum = PointerGetDatum(arrayObject);
4276
4277 /* convert the array object to its string representation */
4278 FmgrInfo *arrayOutFunction = (FmgrInfo *) palloc0(sizeof(FmgrInfo));
4279 fmgr_info(F_ARRAY_OUT, arrayOutFunction);
4280
4281 Datum arrayStringDatum = FunctionCall1(arrayOutFunction, arrayObjectDatum);
4282 char *arrayString = DatumGetCString(arrayStringDatum);
4283
4284 StringInfo arrayStringInfo = makeStringInfo();
4285 appendStringInfo(arrayStringInfo, "%s", arrayString);
4286
4287 return arrayStringInfo;
4288 }
4289
4290
4291 /*
4292 * CreateBasicTask creates a task, initializes fields that are common to each task,
4293 * and returns the created task.
4294 */
4295 Task *
CreateBasicTask(uint64 jobId,uint32 taskId,TaskType taskType,char * queryString)4296 CreateBasicTask(uint64 jobId, uint32 taskId, TaskType taskType, char *queryString)
4297 {
4298 Task *task = CitusMakeNode(Task);
4299 task->jobId = jobId;
4300 task->taskId = taskId;
4301 task->taskType = taskType;
4302 task->replicationModel = REPLICATION_MODEL_INVALID;
4303 SetTaskQueryString(task, queryString);
4304
4305 return task;
4306 }
4307
4308
4309 /*
4310 * BuildRelationShardList builds a list of RelationShard pairs for a task.
4311 * This represents the mapping of range table entries to shard IDs for a
4312 * task for the purposes of locking, deparsing, and connection management.
4313 */
4314 static List *
BuildRelationShardList(List * rangeTableList,List * fragmentList)4315 BuildRelationShardList(List *rangeTableList, List *fragmentList)
4316 {
4317 List *relationShardList = NIL;
4318 ListCell *fragmentCell = NULL;
4319
4320 foreach(fragmentCell, fragmentList)
4321 {
4322 RangeTableFragment *fragment = (RangeTableFragment *) lfirst(fragmentCell);
4323 Index rangeTableId = fragment->rangeTableId;
4324 RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList);
4325
4326 CitusRTEKind fragmentType = fragment->fragmentType;
4327 if (fragmentType == CITUS_RTE_RELATION)
4328 {
4329 ShardInterval *shardInterval = (ShardInterval *) fragment->fragmentReference;
4330 RelationShard *relationShard = CitusMakeNode(RelationShard);
4331
4332 relationShard->relationId = rangeTableEntry->relid;
4333 relationShard->shardId = shardInterval->shardId;
4334
4335 relationShardList = lappend(relationShardList, relationShard);
4336 }
4337 }
4338
4339 return relationShardList;
4340 }
4341
4342
4343 /*
4344 * UpdateRangeTableAlias walks over each fragment in the given fragment list,
4345 * and creates an alias that represents the fragment name to be used in the
4346 * query. The function then updates the corresponding range table entry with
4347 * this alias.
4348 */
4349 static void
UpdateRangeTableAlias(List * rangeTableList,List * fragmentList)4350 UpdateRangeTableAlias(List *rangeTableList, List *fragmentList)
4351 {
4352 ListCell *fragmentCell = NULL;
4353 foreach(fragmentCell, fragmentList)
4354 {
4355 RangeTableFragment *fragment = (RangeTableFragment *) lfirst(fragmentCell);
4356 Index rangeTableId = fragment->rangeTableId;
4357 RangeTblEntry *rangeTableEntry = rt_fetch(rangeTableId, rangeTableList);
4358
4359 Alias *fragmentAlias = FragmentAlias(rangeTableEntry, fragment);
4360 rangeTableEntry->alias = fragmentAlias;
4361 }
4362 }
4363
4364
4365 /*
4366 * FragmentAlias creates an alias structure that captures the table fragment's
4367 * name on the worker node. Each fragment represents either a regular shard, or
4368 * a merge task.
4369 */
4370 static Alias *
FragmentAlias(RangeTblEntry * rangeTableEntry,RangeTableFragment * fragment)4371 FragmentAlias(RangeTblEntry *rangeTableEntry, RangeTableFragment *fragment)
4372 {
4373 char *aliasName = NULL;
4374 char *schemaName = NULL;
4375 char *fragmentName = NULL;
4376
4377 CitusRTEKind fragmentType = fragment->fragmentType;
4378 if (fragmentType == CITUS_RTE_RELATION)
4379 {
4380 ShardInterval *shardInterval = (ShardInterval *) fragment->fragmentReference;
4381 uint64 shardId = shardInterval->shardId;
4382
4383 Oid relationId = rangeTableEntry->relid;
4384 char *relationName = get_rel_name(relationId);
4385
4386 Oid schemaId = get_rel_namespace(relationId);
4387 schemaName = get_namespace_name(schemaId);
4388
4389 aliasName = relationName;
4390
4391 /*
4392 * Set shard name in alias to <relation_name>_<shard_id>.
4393 */
4394 fragmentName = pstrdup(relationName);
4395 AppendShardIdToName(&fragmentName, shardId);
4396 }
4397 else if (fragmentType == CITUS_RTE_REMOTE_QUERY)
4398 {
4399 Task *mergeTask = (Task *) fragment->fragmentReference;
4400 uint64 jobId = mergeTask->jobId;
4401 uint32 taskId = mergeTask->taskId;
4402
4403 StringInfo jobSchemaName = JobSchemaName(jobId);
4404 StringInfo taskTableName = TaskTableName(taskId);
4405
4406 StringInfo aliasNameString = makeStringInfo();
4407 appendStringInfo(aliasNameString, "%s.%s",
4408 jobSchemaName->data, taskTableName->data);
4409
4410 aliasName = aliasNameString->data;
4411 fragmentName = taskTableName->data;
4412 schemaName = jobSchemaName->data;
4413 }
4414
4415 /*
4416 * We need to set the aliasname to relation name, as pg_get_query_def() uses
4417 * the relation name to disambiguate column names from different tables.
4418 */
4419 Alias *alias = rangeTableEntry->alias;
4420 if (alias == NULL)
4421 {
4422 alias = makeNode(Alias);
4423 alias->aliasname = aliasName;
4424 }
4425
4426 ModifyRangeTblExtraData(rangeTableEntry, CITUS_RTE_SHARD,
4427 schemaName, fragmentName, NIL);
4428
4429 return alias;
4430 }
4431
4432
4433 /*
4434 * AnchorShardId walks over each fragment in the given fragment list, finds the
4435 * fragment that corresponds to the given anchor range tableId, and returns this
4436 * fragment's shard identifier. Note that the given tableId must correspond to a
4437 * base relation.
4438 */
4439 static uint64
AnchorShardId(List * fragmentList,uint32 anchorRangeTableId)4440 AnchorShardId(List *fragmentList, uint32 anchorRangeTableId)
4441 {
4442 uint64 anchorShardId = INVALID_SHARD_ID;
4443 ListCell *fragmentCell = NULL;
4444
4445 foreach(fragmentCell, fragmentList)
4446 {
4447 RangeTableFragment *fragment = (RangeTableFragment *) lfirst(fragmentCell);
4448 if (fragment->rangeTableId == anchorRangeTableId)
4449 {
4450 Assert(fragment->fragmentType == CITUS_RTE_RELATION);
4451 Assert(CitusIsA(fragment->fragmentReference, ShardInterval));
4452
4453 ShardInterval *shardInterval = (ShardInterval *) fragment->fragmentReference;
4454 anchorShardId = shardInterval->shardId;
4455 break;
4456 }
4457 }
4458
4459 Assert(anchorShardId != INVALID_SHARD_ID);
4460 return anchorShardId;
4461 }
4462
4463
4464 /*
4465 * PruneSqlTaskDependencies iterates over each sql task from the given sql task
4466 * list, and prunes away merge-fetch tasks, as the task assignment algorithm
4467 * ensures co-location of these tasks.
4468 */
4469 static List *
PruneSqlTaskDependencies(List * sqlTaskList)4470 PruneSqlTaskDependencies(List *sqlTaskList)
4471 {
4472 ListCell *sqlTaskCell = NULL;
4473 foreach(sqlTaskCell, sqlTaskList)
4474 {
4475 Task *sqlTask = (Task *) lfirst(sqlTaskCell);
4476 List *dependentTaskList = sqlTask->dependentTaskList;
4477 List *prunedDependendTaskList = NIL;
4478
4479 ListCell *dependentTaskCell = NULL;
4480 foreach(dependentTaskCell, dependentTaskList)
4481 {
4482 Task *dataFetchTask = (Task *) lfirst(dependentTaskCell);
4483
4484 /*
4485 * If we have a merge fetch task, our task assignment algorithm makes
4486 * sure that the sql task is colocated with the anchor shard / merge
4487 * task. We can therefore prune out this data fetch task.
4488 */
4489 if (dataFetchTask->taskType == MERGE_FETCH_TASK)
4490 {
4491 List *mergeFetchDependencyList = dataFetchTask->dependentTaskList;
4492 Assert(list_length(mergeFetchDependencyList) == 1);
4493
4494 Task *mergeTaskReference = (Task *) linitial(mergeFetchDependencyList);
4495 prunedDependendTaskList = lappend(prunedDependendTaskList,
4496 mergeTaskReference);
4497
4498 ereport(DEBUG2, (errmsg("pruning merge fetch taskId %d",
4499 dataFetchTask->taskId),
4500 errdetail("Creating dependency on merge taskId %d",
4501 mergeTaskReference->taskId)));
4502 }
4503 }
4504
4505 sqlTask->dependentTaskList = prunedDependendTaskList;
4506 }
4507
4508 return sqlTaskList;
4509 }
4510
4511
4512 /*
4513 * MapTaskList creates a list of map tasks for the given MapMerge job. For this,
4514 * the function walks over each filter task (sql task) in the given filter task
4515 * list, and wraps this task with a map function call. The map function call
4516 * repartitions the filter task's output according to MapMerge job's parameters.
4517 */
4518 static List *
MapTaskList(MapMergeJob * mapMergeJob,List * filterTaskList)4519 MapTaskList(MapMergeJob *mapMergeJob, List *filterTaskList)
4520 {
4521 List *mapTaskList = NIL;
4522 Query *filterQuery = mapMergeJob->job.jobQuery;
4523 ListCell *filterTaskCell = NULL;
4524 Var *partitionColumn = mapMergeJob->partitionColumn;
4525
4526 uint32 partitionColumnResNo = 0;
4527 List *groupClauseList = filterQuery->groupClause;
4528 if (groupClauseList != NIL)
4529 {
4530 List *targetEntryList = filterQuery->targetList;
4531 List *groupTargetEntryList = GroupTargetEntryList(groupClauseList,
4532 targetEntryList);
4533 TargetEntry *groupByTargetEntry = (TargetEntry *) linitial(groupTargetEntryList);
4534
4535 partitionColumnResNo = groupByTargetEntry->resno;
4536 }
4537 else
4538 {
4539 partitionColumnResNo = PartitionColumnIndex(partitionColumn,
4540 filterQuery->targetList);
4541 }
4542
4543 foreach(filterTaskCell, filterTaskList)
4544 {
4545 Task *filterTask = (Task *) lfirst(filterTaskCell);
4546 StringInfo mapQueryString = CreateMapQueryString(mapMergeJob, filterTask,
4547 partitionColumnResNo);
4548
4549 /* convert filter query task into map task */
4550 Task *mapTask = filterTask;
4551 SetTaskQueryString(mapTask, mapQueryString->data);
4552 mapTask->taskType = MAP_TASK;
4553
4554 mapTaskList = lappend(mapTaskList, mapTask);
4555 }
4556
4557 return mapTaskList;
4558 }
4559
4560
4561 /*
4562 * PartitionColumnIndex finds the index of the given target var.
4563 */
4564 static int
PartitionColumnIndex(Var * targetVar,List * targetList)4565 PartitionColumnIndex(Var *targetVar, List *targetList)
4566 {
4567 TargetEntry *targetEntry = NULL;
4568 int resNo = 1;
4569 foreach_ptr(targetEntry, targetList)
4570 {
4571 if (IsA(targetEntry->expr, Var))
4572 {
4573 Var *candidateVar = (Var *) targetEntry->expr;
4574 if (candidateVar->varattno == targetVar->varattno &&
4575 candidateVar->varno == targetVar->varno)
4576 {
4577 return resNo;
4578 }
4579 resNo++;
4580 }
4581 }
4582
4583 ereport(ERROR, (errmsg("unexpected state: %d varno %d varattno couldn't be found",
4584 targetVar->varno, targetVar->varattno)));
4585 return resNo;
4586 }
4587
4588
4589 /*
4590 * CreateMapQueryString creates and returns the map query string for the given filterTask.
4591 */
4592 static StringInfo
CreateMapQueryString(MapMergeJob * mapMergeJob,Task * filterTask,uint32 partitionColumnIndex)4593 CreateMapQueryString(MapMergeJob *mapMergeJob, Task *filterTask,
4594 uint32 partitionColumnIndex)
4595 {
4596 uint64 jobId = filterTask->jobId;
4597 uint32 taskId = filterTask->taskId;
4598
4599 /* wrap repartition query string around filter query string */
4600 StringInfo mapQueryString = makeStringInfo();
4601 char *filterQueryString = TaskQueryString(filterTask);
4602 char *filterQueryEscapedText = quote_literal_cstr(filterQueryString);
4603 PartitionType partitionType = mapMergeJob->partitionType;
4604
4605 Var *partitionColumn = mapMergeJob->partitionColumn;
4606 Oid partitionColumnType = partitionColumn->vartype;
4607 char *partitionColumnTypeFullName = format_type_be_qualified(partitionColumnType);
4608 int32 partitionColumnTypeMod = partitionColumn->vartypmod;
4609
4610 ShardInterval **intervalArray = mapMergeJob->sortedShardIntervalArray;
4611 uint32 intervalCount = mapMergeJob->partitionCount;
4612
4613 if (partitionType == DUAL_HASH_PARTITION_TYPE)
4614 {
4615 partitionColumnType = INT4OID;
4616 partitionColumnTypeMod = get_typmodin(INT4OID);
4617 intervalArray = GenerateSyntheticShardIntervalArray(intervalCount);
4618 }
4619 else if (partitionType == SINGLE_HASH_PARTITION_TYPE)
4620 {
4621 partitionColumnType = INT4OID;
4622 partitionColumnTypeMod = get_typmodin(INT4OID);
4623 }
4624
4625 ArrayType *splitPointObject = SplitPointObject(intervalArray, intervalCount);
4626 StringInfo splitPointString = ArrayObjectToString(splitPointObject,
4627 partitionColumnType,
4628 partitionColumnTypeMod);
4629
4630 char *partitionCommand = NULL;
4631 if (partitionType == RANGE_PARTITION_TYPE)
4632 {
4633 partitionCommand = RANGE_PARTITION_COMMAND;
4634 }
4635 else
4636 {
4637 partitionCommand = HASH_PARTITION_COMMAND;
4638 }
4639
4640 char *partitionColumnIndextText = ConvertIntToString(partitionColumnIndex);
4641 appendStringInfo(mapQueryString, partitionCommand, jobId, taskId,
4642 filterQueryEscapedText, partitionColumnIndextText,
4643 partitionColumnTypeFullName, splitPointString->data);
4644 return mapQueryString;
4645 }
4646
4647
4648 /*
4649 * GenerateSyntheticShardIntervalArray returns a shard interval pointer array
4650 * which has a uniform hash distribution for the given input partitionCount.
4651 *
4652 * The function only fills the min/max values of shard the intervals. Thus, should
4653 * not be used for general purpose operations.
4654 */
4655 ShardInterval **
GenerateSyntheticShardIntervalArray(int partitionCount)4656 GenerateSyntheticShardIntervalArray(int partitionCount)
4657 {
4658 ShardInterval **shardIntervalArray = palloc0(partitionCount *
4659 sizeof(ShardInterval *));
4660 uint64 hashTokenIncrement = HASH_TOKEN_COUNT / partitionCount;
4661
4662 for (int shardIndex = 0; shardIndex < partitionCount; ++shardIndex)
4663 {
4664 ShardInterval *shardInterval = CitusMakeNode(ShardInterval);
4665
4666 /* calculate the split of the hash space */
4667 int32 shardMinHashToken = PG_INT32_MIN + (shardIndex * hashTokenIncrement);
4668 int32 shardMaxHashToken = shardMinHashToken + (hashTokenIncrement - 1);
4669
4670 shardInterval->relationId = InvalidOid;
4671 shardInterval->minValueExists = true;
4672 shardInterval->minValue = Int32GetDatum(shardMinHashToken);
4673
4674 shardInterval->maxValueExists = true;
4675 shardInterval->maxValue = Int32GetDatum(shardMaxHashToken);
4676
4677 shardInterval->shardId = INVALID_SHARD_ID;
4678 shardInterval->valueTypeId = INT4OID;
4679
4680 shardIntervalArray[shardIndex] = shardInterval;
4681 }
4682
4683 return shardIntervalArray;
4684 }
4685
4686
4687 /*
4688 * Determine RowModifyLevel required for given query
4689 */
4690 RowModifyLevel
RowModifyLevelForQuery(Query * query)4691 RowModifyLevelForQuery(Query *query)
4692 {
4693 CmdType commandType = query->commandType;
4694
4695 if (commandType == CMD_SELECT)
4696 {
4697 if (query->hasModifyingCTE)
4698 {
4699 /* skip checking for INSERT as those CTEs are recursively planned */
4700 CommonTableExpr *cte = NULL;
4701 foreach_ptr(cte, query->cteList)
4702 {
4703 Query *cteQuery = (Query *) cte->ctequery;
4704
4705 if (cteQuery->commandType == CMD_UPDATE ||
4706 cteQuery->commandType == CMD_DELETE)
4707 {
4708 return ROW_MODIFY_NONCOMMUTATIVE;
4709 }
4710 }
4711 }
4712
4713 return ROW_MODIFY_READONLY;
4714 }
4715
4716 if (commandType == CMD_INSERT)
4717 {
4718 if (query->onConflict == NULL)
4719 {
4720 return ROW_MODIFY_COMMUTATIVE;
4721 }
4722 else
4723 {
4724 return ROW_MODIFY_NONCOMMUTATIVE;
4725 }
4726 }
4727
4728 if (commandType == CMD_UPDATE ||
4729 commandType == CMD_DELETE)
4730 {
4731 return ROW_MODIFY_NONCOMMUTATIVE;
4732 }
4733
4734 return ROW_MODIFY_NONE;
4735 }
4736
4737
4738 /*
4739 * ArrayObjectToString converts an SQL object to its string representation.
4740 */
4741 StringInfo
ArrayObjectToString(ArrayType * arrayObject,Oid columnType,int32 columnTypeMod)4742 ArrayObjectToString(ArrayType *arrayObject, Oid columnType, int32 columnTypeMod)
4743 {
4744 Datum arrayDatum = PointerGetDatum(arrayObject);
4745 Oid outputFunctionId = InvalidOid;
4746 bool typeVariableLength = false;
4747
4748 Oid arrayOutType = get_array_type(columnType);
4749 if (arrayOutType == InvalidOid)
4750 {
4751 char *columnTypeName = format_type_be(columnType);
4752 ereport(ERROR, (errmsg("cannot range repartition table on column type %s",
4753 columnTypeName)));
4754 }
4755
4756 FmgrInfo *arrayOutFunction = (FmgrInfo *) palloc0(sizeof(FmgrInfo));
4757 getTypeOutputInfo(arrayOutType, &outputFunctionId, &typeVariableLength);
4758 fmgr_info(outputFunctionId, arrayOutFunction);
4759
4760 char *arrayOutputText = OutputFunctionCall(arrayOutFunction, arrayDatum);
4761 char *arrayOutputEscapedText = quote_literal_cstr(arrayOutputText);
4762
4763 /* add an explicit cast to array's string representation */
4764 char *arrayOutTypeName = format_type_with_typemod(arrayOutType, columnTypeMod);
4765
4766 StringInfo arrayString = makeStringInfo();
4767 appendStringInfo(arrayString, "%s::%s",
4768 arrayOutputEscapedText, arrayOutTypeName);
4769
4770 return arrayString;
4771 }
4772
4773
4774 /*
4775 * MergeTaskList creates a list of merge tasks for the given MapMerge job. While
4776 * doing this, the function also establishes dependencies between each merge
4777 * task and its downstream map task dependencies by creating "map fetch" tasks.
4778 */
4779 static List *
MergeTaskList(MapMergeJob * mapMergeJob,List * mapTaskList,uint32 taskIdIndex)4780 MergeTaskList(MapMergeJob *mapMergeJob, List *mapTaskList, uint32 taskIdIndex)
4781 {
4782 List *mergeTaskList = NIL;
4783 uint64 jobId = mapMergeJob->job.jobId;
4784 uint32 partitionCount = mapMergeJob->partitionCount;
4785
4786 /* build column name and column type arrays (table schema) */
4787 Query *filterQuery = mapMergeJob->job.jobQuery;
4788 List *targetEntryList = filterQuery->targetList;
4789
4790 /* if all map tasks were pruned away, return NIL for merge tasks */
4791 if (mapTaskList == NIL)
4792 {
4793 return NIL;
4794 }
4795
4796 /*
4797 * XXX: We currently ignore the 0th partition bucket that range partitioning
4798 * generates. This bucket holds all values less than the minimum value or
4799 * NULLs, both of which we can currently ignore. However, when we support
4800 * range re-partitioned OUTER joins, we will need these rows for the
4801 * relation whose rows are retained in the OUTER join.
4802 */
4803 uint32 initialPartitionId = 0;
4804 if (mapMergeJob->partitionType == RANGE_PARTITION_TYPE)
4805 {
4806 initialPartitionId = 1;
4807 partitionCount = partitionCount + 1;
4808 }
4809 else if (mapMergeJob->partitionType == SINGLE_HASH_PARTITION_TYPE)
4810 {
4811 initialPartitionId = 0;
4812 }
4813
4814 /* build merge tasks and their associated "map output fetch" tasks */
4815 for (uint32 partitionId = initialPartitionId; partitionId < partitionCount;
4816 partitionId++)
4817 {
4818 Task *mergeTask = NULL;
4819 List *mapOutputFetchTaskList = NIL;
4820 ListCell *mapTaskCell = NULL;
4821 uint32 mergeTaskId = taskIdIndex;
4822
4823 Query *reduceQuery = mapMergeJob->reduceQuery;
4824 if (reduceQuery == NULL)
4825 {
4826 uint32 columnCount = (uint32) list_length(targetEntryList);
4827 StringInfo columnNames = ColumnNameArrayString(columnCount, jobId);
4828 StringInfo columnTypes = ColumnTypeArrayString(targetEntryList);
4829
4830 StringInfo mergeQueryString = makeStringInfo();
4831 appendStringInfo(mergeQueryString, MERGE_FILES_INTO_TABLE_COMMAND,
4832 jobId, taskIdIndex, columnNames->data, columnTypes->data);
4833
4834 /* create merge task */
4835 mergeTask = CreateBasicTask(jobId, mergeTaskId, MERGE_TASK,
4836 mergeQueryString->data);
4837 }
4838 mergeTask->partitionId = partitionId;
4839 taskIdIndex++;
4840
4841 /* create tasks to fetch map outputs to this merge task */
4842 foreach(mapTaskCell, mapTaskList)
4843 {
4844 Task *mapTask = (Task *) lfirst(mapTaskCell);
4845
4846 /* find the node name/port for map task's execution */
4847 List *mapTaskPlacementList = mapTask->taskPlacementList;
4848
4849 ShardPlacement *mapTaskPlacement = linitial(mapTaskPlacementList);
4850 char *mapTaskNodeName = mapTaskPlacement->nodeName;
4851 uint32 mapTaskNodePort = mapTaskPlacement->nodePort;
4852
4853 /*
4854 * We will use the first node even if replication factor is greater than 1
4855 * When replication factor is greater than 1 and there
4856 * is a connection problem to the node that has done the map task, we will get
4857 * an error in fetch task execution.
4858 */
4859 StringInfo mapFetchQueryString = makeStringInfo();
4860 appendStringInfo(mapFetchQueryString, MAP_OUTPUT_FETCH_COMMAND,
4861 mapTask->jobId, mapTask->taskId, partitionId,
4862 mergeTaskId, /* fetch results to merge task */
4863 mapTaskNodeName, mapTaskNodePort);
4864
4865 Task *mapOutputFetchTask = CreateBasicTask(jobId, taskIdIndex,
4866 MAP_OUTPUT_FETCH_TASK,
4867 mapFetchQueryString->data);
4868 mapOutputFetchTask->partitionId = partitionId;
4869 mapOutputFetchTask->upstreamTaskId = mergeTaskId;
4870 mapOutputFetchTask->dependentTaskList = list_make1(mapTask);
4871 taskIdIndex++;
4872
4873 mapOutputFetchTaskList = lappend(mapOutputFetchTaskList, mapOutputFetchTask);
4874 }
4875
4876 /* merge task depends on completion of fetch tasks */
4877 mergeTask->dependentTaskList = mapOutputFetchTaskList;
4878
4879 /* if single repartitioned, each merge task represents an interval */
4880 if (mapMergeJob->partitionType == RANGE_PARTITION_TYPE)
4881 {
4882 int32 mergeTaskIntervalId = partitionId - 1;
4883 ShardInterval **mergeTaskIntervals = mapMergeJob->sortedShardIntervalArray;
4884 Assert(mergeTaskIntervalId >= 0);
4885
4886 mergeTask->shardInterval = mergeTaskIntervals[mergeTaskIntervalId];
4887 }
4888 else if (mapMergeJob->partitionType == SINGLE_HASH_PARTITION_TYPE)
4889 {
4890 int32 mergeTaskIntervalId = partitionId;
4891 ShardInterval **mergeTaskIntervals = mapMergeJob->sortedShardIntervalArray;
4892 Assert(mergeTaskIntervalId >= 0);
4893
4894 mergeTask->shardInterval = mergeTaskIntervals[mergeTaskIntervalId];
4895 }
4896
4897 mergeTaskList = lappend(mergeTaskList, mergeTask);
4898 }
4899
4900 return mergeTaskList;
4901 }
4902
4903
4904 /*
4905 * ColumnNameArrayString creates a list of column names for a merged table, and
4906 * outputs this list of column names in their (array) string representation.
4907 */
4908 static StringInfo
ColumnNameArrayString(uint32 columnCount,uint64 generatingJobId)4909 ColumnNameArrayString(uint32 columnCount, uint64 generatingJobId)
4910 {
4911 Datum *columnNameArray = palloc0(columnCount * sizeof(Datum));
4912 uint32 columnNameIndex = 0;
4913
4914 /* build list of intermediate column names, generated by given jobId */
4915 List *columnNameList = DerivedColumnNameList(columnCount, generatingJobId);
4916
4917 ListCell *columnNameCell = NULL;
4918 foreach(columnNameCell, columnNameList)
4919 {
4920 Value *columnNameValue = (Value *) lfirst(columnNameCell);
4921 char *columnNameString = strVal(columnNameValue);
4922 Datum columnName = CStringGetDatum(columnNameString);
4923
4924 columnNameArray[columnNameIndex] = columnName;
4925 columnNameIndex++;
4926 }
4927
4928 StringInfo columnNameArrayString = DatumArrayString(columnNameArray, columnCount,
4929 CSTRINGOID);
4930
4931 return columnNameArrayString;
4932 }
4933
4934
4935 /*
4936 * ColumnTypeArrayString resolves a list of column types for a merged table, and
4937 * outputs this list of column types in their (array) string representation.
4938 */
4939 static StringInfo
ColumnTypeArrayString(List * targetEntryList)4940 ColumnTypeArrayString(List *targetEntryList)
4941 {
4942 ListCell *targetEntryCell = NULL;
4943
4944 uint32 columnCount = (uint32) list_length(targetEntryList);
4945 Datum *columnTypeArray = palloc0(columnCount * sizeof(Datum));
4946 uint32 columnTypeIndex = 0;
4947
4948 foreach(targetEntryCell, targetEntryList)
4949 {
4950 TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
4951 Node *columnExpression = (Node *) targetEntry->expr;
4952 Oid columnTypeId = exprType(columnExpression);
4953 int32 columnTypeMod = exprTypmod(columnExpression);
4954
4955 char *columnTypeName = format_type_with_typemod(columnTypeId, columnTypeMod);
4956 Datum columnType = CStringGetDatum(columnTypeName);
4957
4958 columnTypeArray[columnTypeIndex] = columnType;
4959 columnTypeIndex++;
4960 }
4961
4962 StringInfo columnTypeArrayString = DatumArrayString(columnTypeArray, columnCount,
4963 CSTRINGOID);
4964
4965 return columnTypeArrayString;
4966 }
4967
4968
4969 /*
4970 * AssignTaskList assigns locations to given tasks based on dependencies between
4971 * tasks and configured task assignment policies. The function also handles the
4972 * case where multiple SQL tasks depend on the same merge task, and makes sure
4973 * that this group of multiple SQL tasks and the merge task are assigned to the
4974 * same location.
4975 */
4976 static List *
AssignTaskList(List * sqlTaskList)4977 AssignTaskList(List *sqlTaskList)
4978 {
4979 List *assignedSqlTaskList = NIL;
4980 bool hasAnchorShardId = false;
4981 ListCell *sqlTaskCell = NULL;
4982 List *primarySqlTaskList = NIL;
4983 ListCell *primarySqlTaskCell = NULL;
4984 ListCell *constrainedSqlTaskCell = NULL;
4985
4986 /* no tasks to assign */
4987 if (sqlTaskList == NIL)
4988 {
4989 return NIL;
4990 }
4991
4992 Task *firstSqlTask = (Task *) linitial(sqlTaskList);
4993 if (firstSqlTask->anchorShardId != INVALID_SHARD_ID)
4994 {
4995 hasAnchorShardId = true;
4996 }
4997
4998 /*
4999 * If these SQL tasks don't depend on any merge tasks, we can assign each
5000 * one independently of the other. We therefore go ahead and assign these
5001 * SQL tasks using the "anchor shard based" assignment algorithms.
5002 */
5003 bool hasMergeTaskDependencies = HasMergeTaskDependencies(sqlTaskList);
5004 if (!hasMergeTaskDependencies)
5005 {
5006 Assert(hasAnchorShardId);
5007
5008 assignedSqlTaskList = AssignAnchorShardTaskList(sqlTaskList);
5009
5010 return assignedSqlTaskList;
5011 }
5012
5013 /*
5014 * SQL tasks can depend on merge tasks in one of two ways: (1) each SQL task
5015 * depends on merge task(s) that no other SQL task depends upon, (2) several
5016 * SQL tasks depend on the same merge task(s) and all need to be assigned to
5017 * the same worker node. To handle the second case, we first pick a primary
5018 * SQL task among those that depend on the same merge task, and assign it.
5019 */
5020 foreach(sqlTaskCell, sqlTaskList)
5021 {
5022 Task *sqlTask = (Task *) lfirst(sqlTaskCell);
5023 List *mergeTaskList = FindDependentMergeTaskList(sqlTask);
5024
5025 Task *firstMergeTask = (Task *) linitial(mergeTaskList);
5026 if (!firstMergeTask->assignmentConstrained)
5027 {
5028 firstMergeTask->assignmentConstrained = true;
5029
5030 primarySqlTaskList = lappend(primarySqlTaskList, sqlTask);
5031 }
5032 }
5033
5034 if (hasAnchorShardId)
5035 {
5036 primarySqlTaskList = AssignAnchorShardTaskList(primarySqlTaskList);
5037 }
5038 else
5039 {
5040 primarySqlTaskList = AssignDualHashTaskList(primarySqlTaskList);
5041 }
5042
5043 /* propagate SQL task assignments to the merge tasks we depend upon */
5044 foreach(primarySqlTaskCell, primarySqlTaskList)
5045 {
5046 Task *sqlTask = (Task *) lfirst(primarySqlTaskCell);
5047 List *mergeTaskList = FindDependentMergeTaskList(sqlTask);
5048
5049 ListCell *mergeTaskCell = NULL;
5050 foreach(mergeTaskCell, mergeTaskList)
5051 {
5052 Task *mergeTask = (Task *) lfirst(mergeTaskCell);
5053 Assert(mergeTask->taskPlacementList == NIL);
5054
5055 mergeTask->taskPlacementList = list_copy(sqlTask->taskPlacementList);
5056 }
5057
5058 assignedSqlTaskList = lappend(assignedSqlTaskList, sqlTask);
5059 }
5060
5061 /*
5062 * If we had a set of SQL tasks depending on the same merge task, we only
5063 * assigned one SQL task from that set. We call the assigned SQL task the
5064 * primary, and note that the remaining SQL tasks are constrained by the
5065 * primary's task assignment. We propagate the primary's task assignment in
5066 * each set to the remaining (constrained) tasks.
5067 */
5068 List *constrainedSqlTaskList = TaskListDifference(sqlTaskList, primarySqlTaskList);
5069
5070 foreach(constrainedSqlTaskCell, constrainedSqlTaskList)
5071 {
5072 Task *sqlTask = (Task *) lfirst(constrainedSqlTaskCell);
5073 List *mergeTaskList = FindDependentMergeTaskList(sqlTask);
5074 List *mergeTaskPlacementList = NIL;
5075
5076 ListCell *mergeTaskCell = NULL;
5077 foreach(mergeTaskCell, mergeTaskList)
5078 {
5079 Task *mergeTask = (Task *) lfirst(mergeTaskCell);
5080
5081 /*
5082 * If we have more than one merge task, both of them should have the
5083 * same task placement list.
5084 */
5085 mergeTaskPlacementList = mergeTask->taskPlacementList;
5086 Assert(mergeTaskPlacementList != NIL);
5087
5088 ereport(DEBUG3, (errmsg("propagating assignment from merge task %d "
5089 "to constrained sql task %d",
5090 mergeTask->taskId, sqlTask->taskId)));
5091 }
5092
5093 sqlTask->taskPlacementList = list_copy(mergeTaskPlacementList);
5094
5095 assignedSqlTaskList = lappend(assignedSqlTaskList, sqlTask);
5096 }
5097
5098 return assignedSqlTaskList;
5099 }
5100
5101
5102 /*
5103 * HasMergeTaskDependencies checks if sql tasks in the given sql task list have
5104 * any dependencies on merge tasks. If they do, the function returns true.
5105 */
5106 static bool
HasMergeTaskDependencies(List * sqlTaskList)5107 HasMergeTaskDependencies(List *sqlTaskList)
5108 {
5109 bool hasMergeTaskDependencies = false;
5110 Task *sqlTask = (Task *) linitial(sqlTaskList);
5111 List *dependentTaskList = sqlTask->dependentTaskList;
5112
5113 ListCell *dependentTaskCell = NULL;
5114 foreach(dependentTaskCell, dependentTaskList)
5115 {
5116 Task *dependentTask = (Task *) lfirst(dependentTaskCell);
5117 if (dependentTask->taskType == MERGE_TASK)
5118 {
5119 hasMergeTaskDependencies = true;
5120 break;
5121 }
5122 }
5123
5124 return hasMergeTaskDependencies;
5125 }
5126
5127
5128 /* Return true if two tasks are equal, false otherwise. */
5129 bool
TasksEqual(const Task * a,const Task * b)5130 TasksEqual(const Task *a, const Task *b)
5131 {
5132 Assert(CitusIsA(a, Task));
5133 Assert(CitusIsA(b, Task));
5134
5135 if (a->taskType != b->taskType)
5136 {
5137 return false;
5138 }
5139 if (a->jobId != b->jobId)
5140 {
5141 return false;
5142 }
5143 if (a->taskId != b->taskId)
5144 {
5145 return false;
5146 }
5147
5148 return true;
5149 }
5150
5151
5152 /* Is the passed in Task a member of the list. */
5153 bool
TaskListMember(const List * taskList,const Task * task)5154 TaskListMember(const List *taskList, const Task *task)
5155 {
5156 const ListCell *taskCell = NULL;
5157
5158 foreach(taskCell, taskList)
5159 {
5160 if (TasksEqual((Task *) lfirst(taskCell), task))
5161 {
5162 return true;
5163 }
5164 }
5165
5166 return false;
5167 }
5168
5169
5170 /*
5171 * TaskListDifference returns a list that contains all the tasks in taskList1
5172 * that are not in taskList2. The returned list is freshly allocated via
5173 * palloc(), but the cells themselves point to the same objects as the cells
5174 * of the input lists.
5175 */
5176 List *
TaskListDifference(const List * list1,const List * list2)5177 TaskListDifference(const List *list1, const List *list2)
5178 {
5179 const ListCell *taskCell = NULL;
5180 List *resultList = NIL;
5181
5182 if (list2 == NIL)
5183 {
5184 return list_copy(list1);
5185 }
5186
5187 foreach(taskCell, list1)
5188 {
5189 if (!TaskListMember(list2, lfirst(taskCell)))
5190 {
5191 resultList = lappend(resultList, lfirst(taskCell));
5192 }
5193 }
5194
5195 return resultList;
5196 }
5197
5198
5199 /*
5200 * AssignAnchorShardTaskList assigns locations to the given tasks based on the
5201 * configured task assignment policy. The distributed executor later sends these
5202 * tasks to their assigned locations for remote execution.
5203 */
5204 List *
AssignAnchorShardTaskList(List * taskList)5205 AssignAnchorShardTaskList(List *taskList)
5206 {
5207 List *assignedTaskList = NIL;
5208
5209 /* choose task assignment policy based on config value */
5210 if (TaskAssignmentPolicy == TASK_ASSIGNMENT_GREEDY)
5211 {
5212 assignedTaskList = GreedyAssignTaskList(taskList);
5213 }
5214 else if (TaskAssignmentPolicy == TASK_ASSIGNMENT_FIRST_REPLICA)
5215 {
5216 assignedTaskList = FirstReplicaAssignTaskList(taskList);
5217 }
5218 else if (TaskAssignmentPolicy == TASK_ASSIGNMENT_ROUND_ROBIN)
5219 {
5220 assignedTaskList = RoundRobinAssignTaskList(taskList);
5221 }
5222
5223 Assert(assignedTaskList != NIL);
5224 return assignedTaskList;
5225 }
5226
5227
5228 /*
5229 * GreedyAssignTaskList uses a greedy algorithm similar to Hadoop's, and assigns
5230 * locations to the given tasks. The ideal assignment algorithm balances three
5231 * properties: (a) determinism, (b) even load distribution, and (c) consistency
5232 * across similar task lists. To maintain these properties, the algorithm sorts
5233 * all its input lists.
5234 */
5235 static List *
GreedyAssignTaskList(List * taskList)5236 GreedyAssignTaskList(List *taskList)
5237 {
5238 List *assignedTaskList = NIL;
5239 uint32 assignedTaskCount = 0;
5240 uint32 taskCount = list_length(taskList);
5241
5242 /* get the worker node list and sort the list */
5243 List *workerNodeList = ActiveReadableNodeList();
5244 workerNodeList = SortList(workerNodeList, CompareWorkerNodes);
5245
5246 /*
5247 * We first sort tasks by their anchor shard id. We then walk over each task
5248 * in the sorted list, get the task's anchor shard id, and look up the shard
5249 * placements (locations) for this shard id. Next, we sort the placements by
5250 * their insertion time, and append them to a new list.
5251 */
5252 taskList = SortList(taskList, CompareTasksByShardId);
5253 List *activeShardPlacementLists = ActiveShardPlacementLists(taskList);
5254
5255 while (assignedTaskCount < taskCount)
5256 {
5257 ListCell *workerNodeCell = NULL;
5258 uint32 loopStartTaskCount = assignedTaskCount;
5259
5260 /* walk over each node and check if we can assign a task to it */
5261 foreach(workerNodeCell, workerNodeList)
5262 {
5263 WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
5264
5265 Task *assignedTask = GreedyAssignTask(workerNode, taskList,
5266 activeShardPlacementLists);
5267 if (assignedTask != NULL)
5268 {
5269 assignedTaskList = lappend(assignedTaskList, assignedTask);
5270 assignedTaskCount++;
5271 }
5272 }
5273
5274 /* if we could not assign any new tasks, avoid looping forever */
5275 if (assignedTaskCount == loopStartTaskCount)
5276 {
5277 uint32 remainingTaskCount = taskCount - assignedTaskCount;
5278 ereport(ERROR, (errmsg("failed to assign %u task(s) to worker nodes",
5279 remainingTaskCount)));
5280 }
5281 }
5282
5283 return assignedTaskList;
5284 }
5285
5286
5287 /*
5288 * GreedyAssignTask tries to assign a task to the given worker node. To do this,
5289 * the function walks over tasks' anchor shard ids, and finds the first set of
5290 * nodes the shards were replicated to. If any of these replica nodes and the
5291 * given worker node match, the corresponding task is assigned to that node. If
5292 * not, the function goes on to search the second set of replicas and so forth.
5293 *
5294 * Note that this function has side-effects; when the function assigns a new
5295 * task, it overwrites the corresponding task list pointer.
5296 */
5297 static Task *
GreedyAssignTask(WorkerNode * workerNode,List * taskList,List * activeShardPlacementLists)5298 GreedyAssignTask(WorkerNode *workerNode, List *taskList, List *activeShardPlacementLists)
5299 {
5300 Task *assignedTask = NULL;
5301 List *taskPlacementList = NIL;
5302 ShardPlacement *primaryPlacement = NULL;
5303 uint32 rotatePlacementListBy = 0;
5304 uint32 replicaIndex = 0;
5305 uint32 replicaCount = ShardReplicationFactor;
5306 const char *workerName = workerNode->workerName;
5307 const uint32 workerPort = workerNode->workerPort;
5308
5309 while ((assignedTask == NULL) && (replicaIndex < replicaCount))
5310 {
5311 /* walk over all tasks and try to assign one */
5312 ListCell *taskCell = NULL;
5313 ListCell *placementListCell = NULL;
5314
5315 forboth(taskCell, taskList, placementListCell, activeShardPlacementLists)
5316 {
5317 Task *task = (Task *) lfirst(taskCell);
5318 List *placementList = (List *) lfirst(placementListCell);
5319
5320 /* check if we already assigned this task */
5321 if (task == NULL)
5322 {
5323 continue;
5324 }
5325
5326 /* check if we have enough replicas */
5327 uint32 placementCount = list_length(placementList);
5328 if (placementCount <= replicaIndex)
5329 {
5330 continue;
5331 }
5332
5333 ShardPlacement *placement = (ShardPlacement *) list_nth(placementList,
5334 replicaIndex);
5335 if ((strncmp(placement->nodeName, workerName, WORKER_LENGTH) == 0) &&
5336 (placement->nodePort == workerPort))
5337 {
5338 /* we found a task to assign to the given worker node */
5339 assignedTask = task;
5340 taskPlacementList = placementList;
5341 rotatePlacementListBy = replicaIndex;
5342
5343 /* overwrite task list to signal that this task is assigned */
5344 SetListCellPtr(taskCell, NULL);
5345 break;
5346 }
5347 }
5348
5349 /* go over the next set of shard replica placements */
5350 replicaIndex++;
5351 }
5352
5353 /* if we found a task placement list, rotate and assign task placements */
5354 if (assignedTask != NULL)
5355 {
5356 taskPlacementList = LeftRotateList(taskPlacementList, rotatePlacementListBy);
5357 assignedTask->taskPlacementList = taskPlacementList;
5358
5359 primaryPlacement = (ShardPlacement *) linitial(assignedTask->taskPlacementList);
5360 ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", assignedTask->taskId,
5361 primaryPlacement->nodeName,
5362 primaryPlacement->nodePort)));
5363 }
5364
5365 return assignedTask;
5366 }
5367
5368
5369 /*
5370 * FirstReplicaAssignTaskList assigns locations to the given tasks simply by
5371 * looking at placements for a given shard. A particular task's assignments are
5372 * then ordered by the insertion order of the relevant placements rows. In other
5373 * words, a task for a specific shard is simply assigned to the first replica
5374 * for that shard. This algorithm is extremely simple and intended for use when
5375 * a customer has placed shards carefully and wants strong guarantees about
5376 * which shards will be used by what nodes (i.e. for stronger memory residency
5377 * guarantees).
5378 */
5379 List *
FirstReplicaAssignTaskList(List * taskList)5380 FirstReplicaAssignTaskList(List *taskList)
5381 {
5382 /* No additional reordering need take place for this algorithm */
5383 ReorderFunction reorderFunction = NULL;
5384
5385 taskList = ReorderAndAssignTaskList(taskList, reorderFunction);
5386
5387 return taskList;
5388 }
5389
5390
5391 /*
5392 * RoundRobinAssignTaskList uses a round-robin algorithm to assign locations to
5393 * the given tasks. An ideal round-robin implementation requires keeping shared
5394 * state for task assignments; and we instead approximate our implementation by
5395 * relying on the sequentially increasing jobId. For each task, we mod its jobId
5396 * by the number of active shard placements, and ensure that we rotate between
5397 * these placements across subsequent queries.
5398 */
5399 List *
RoundRobinAssignTaskList(List * taskList)5400 RoundRobinAssignTaskList(List *taskList)
5401 {
5402 taskList = ReorderAndAssignTaskList(taskList, RoundRobinReorder);
5403
5404 return taskList;
5405 }
5406
5407
5408 /*
5409 * RoundRobinReorder implements the core of the round-robin assignment policy.
5410 * It takes a placement list and rotates a copy of it based on the latest stable
5411 * transaction id provided by PostgreSQL.
5412 *
5413 * We prefer to use transactionId as the seed for the rotation to use the replicas
5414 * in the same worker node within the same transaction. This becomes more important
5415 * when we're reading from (the same or multiple) reference tables within a
5416 * transaction. With this approach, we can prevent reads to expand the worker nodes
5417 * that participate in a distributed transaction.
5418 *
5419 * Note that we prefer PostgreSQL's transactionId over distributed transactionId that
5420 * Citus generates since the distributed transactionId is generated during the execution
5421 * where as task-assignment happens duing the planning.
5422 */
5423 List *
RoundRobinReorder(List * placementList)5424 RoundRobinReorder(List *placementList)
5425 {
5426 TransactionId transactionId = GetMyProcLocalTransactionId();
5427 uint32 activePlacementCount = list_length(placementList);
5428 uint32 roundRobinIndex = (transactionId % activePlacementCount);
5429
5430 placementList = LeftRotateList(placementList, roundRobinIndex);
5431
5432 return placementList;
5433 }
5434
5435
5436 /*
5437 * ReorderAndAssignTaskList finds the placements for a task based on its anchor
5438 * shard id and then sorts them by insertion time. If reorderFunction is given,
5439 * it is used to reorder the placements list in a custom fashion (for instance,
5440 * by rotation or shuffling). Returns the task list with placements assigned.
5441 */
5442 static List *
ReorderAndAssignTaskList(List * taskList,ReorderFunction reorderFunction)5443 ReorderAndAssignTaskList(List *taskList, ReorderFunction reorderFunction)
5444 {
5445 List *assignedTaskList = NIL;
5446 ListCell *taskCell = NULL;
5447 ListCell *placementListCell = NULL;
5448 uint32 unAssignedTaskCount = 0;
5449
5450 if (taskList == NIL)
5451 {
5452 return NIL;
5453 }
5454
5455 /*
5456 * We first sort tasks by their anchor shard id. We then sort placements for
5457 * each anchor shard by the placement's insertion time. Note that we sort
5458 * these lists just to make our policy more deterministic.
5459 */
5460 taskList = SortList(taskList, CompareTasksByShardId);
5461 List *activeShardPlacementLists = ActiveShardPlacementLists(taskList);
5462
5463 forboth(taskCell, taskList, placementListCell, activeShardPlacementLists)
5464 {
5465 Task *task = (Task *) lfirst(taskCell);
5466 List *placementList = (List *) lfirst(placementListCell);
5467
5468 /* inactive placements are already filtered out */
5469 uint32 activePlacementCount = list_length(placementList);
5470 if (activePlacementCount > 0)
5471 {
5472 if (reorderFunction != NULL)
5473 {
5474 placementList = reorderFunction(placementList);
5475 }
5476 task->taskPlacementList = placementList;
5477
5478 ShardPlacement *primaryPlacement = (ShardPlacement *) linitial(
5479 task->taskPlacementList);
5480 ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", task->taskId,
5481 primaryPlacement->nodeName,
5482 primaryPlacement->nodePort)));
5483
5484 assignedTaskList = lappend(assignedTaskList, task);
5485 }
5486 else
5487 {
5488 unAssignedTaskCount++;
5489 }
5490 }
5491
5492 /* if we have unassigned tasks, error out */
5493 if (unAssignedTaskCount > 0)
5494 {
5495 ereport(ERROR, (errmsg("failed to assign %u task(s) to worker nodes",
5496 unAssignedTaskCount)));
5497 }
5498
5499 return assignedTaskList;
5500 }
5501
5502
5503 /* Helper function to compare two tasks by their anchor shardId. */
5504 static int
CompareTasksByShardId(const void * leftElement,const void * rightElement)5505 CompareTasksByShardId(const void *leftElement, const void *rightElement)
5506 {
5507 const Task *leftTask = *((const Task **) leftElement);
5508 const Task *rightTask = *((const Task **) rightElement);
5509
5510 uint64 leftShardId = leftTask->anchorShardId;
5511 uint64 rightShardId = rightTask->anchorShardId;
5512
5513 /* we compare 64-bit integers, instead of casting their difference to int */
5514 if (leftShardId > rightShardId)
5515 {
5516 return 1;
5517 }
5518 else if (leftShardId < rightShardId)
5519 {
5520 return -1;
5521 }
5522 else
5523 {
5524 return 0;
5525 }
5526 }
5527
5528
5529 /*
5530 * ActiveShardPlacementLists finds the active shard placement list for each task in
5531 * the given task list, sorts each shard placement list by shard creation time,
5532 * and adds the sorted placement list into a new list of lists. The function also
5533 * ensures a one-to-one mapping between each placement list in the new list of
5534 * lists and each task in the given task list.
5535 */
5536 static List *
ActiveShardPlacementLists(List * taskList)5537 ActiveShardPlacementLists(List *taskList)
5538 {
5539 List *shardPlacementLists = NIL;
5540 ListCell *taskCell = NULL;
5541
5542 foreach(taskCell, taskList)
5543 {
5544 Task *task = (Task *) lfirst(taskCell);
5545 uint64 anchorShardId = task->anchorShardId;
5546 List *shardPlacementList = ActiveShardPlacementList(anchorShardId);
5547
5548 /* filter out shard placements that reside in inactive nodes */
5549 List *activeShardPlacementList = ActivePlacementList(shardPlacementList);
5550 if (activeShardPlacementList == NIL)
5551 {
5552 ereport(ERROR,
5553 (errmsg("no active placements were found for shard " UINT64_FORMAT,
5554 anchorShardId)));
5555 }
5556
5557 /* sort shard placements by their creation time */
5558 activeShardPlacementList = SortList(activeShardPlacementList,
5559 CompareShardPlacements);
5560 shardPlacementLists = lappend(shardPlacementLists, activeShardPlacementList);
5561 }
5562
5563 return shardPlacementLists;
5564 }
5565
5566
5567 /*
5568 * CompareShardPlacements compares two shard placements by their tuple oid; this
5569 * oid reflects the tuple's insertion order into pg_dist_placement.
5570 */
5571 int
CompareShardPlacements(const void * leftElement,const void * rightElement)5572 CompareShardPlacements(const void *leftElement, const void *rightElement)
5573 {
5574 const ShardPlacement *leftPlacement = *((const ShardPlacement **) leftElement);
5575 const ShardPlacement *rightPlacement = *((const ShardPlacement **) rightElement);
5576
5577 uint64 leftPlacementId = leftPlacement->placementId;
5578 uint64 rightPlacementId = rightPlacement->placementId;
5579
5580 if (leftPlacementId < rightPlacementId)
5581 {
5582 return -1;
5583 }
5584 else if (leftPlacementId > rightPlacementId)
5585 {
5586 return 1;
5587 }
5588 else
5589 {
5590 return 0;
5591 }
5592 }
5593
5594
5595 /*
5596 * ActivePlacementList walks over shard placements in the given list, and finds
5597 * the corresponding worker node for each placement. The function then checks if
5598 * that worker node is active, and if it is, appends the placement to a new list.
5599 * The function last returns the new placement list.
5600 */
5601 static List *
ActivePlacementList(List * placementList)5602 ActivePlacementList(List *placementList)
5603 {
5604 List *activePlacementList = NIL;
5605 ListCell *placementCell = NULL;
5606
5607 foreach(placementCell, placementList)
5608 {
5609 ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell);
5610
5611 /* check if the worker node for this shard placement is active */
5612 WorkerNode *workerNode = FindWorkerNode(placement->nodeName, placement->nodePort);
5613 if (workerNode != NULL && workerNode->isActive)
5614 {
5615 activePlacementList = lappend(activePlacementList, placement);
5616 }
5617 }
5618
5619 return activePlacementList;
5620 }
5621
5622
5623 /*
5624 * LeftRotateList returns a copy of the given list that has been cyclically
5625 * shifted to the left by the given rotation count. For this, the function
5626 * repeatedly moves the list's first element to the end of the list, and
5627 * then returns the newly rotated list.
5628 */
5629 static List *
LeftRotateList(List * list,uint32 rotateCount)5630 LeftRotateList(List *list, uint32 rotateCount)
5631 {
5632 List *rotatedList = list_copy(list);
5633
5634 for (uint32 rotateIndex = 0; rotateIndex < rotateCount; rotateIndex++)
5635 {
5636 void *firstElement = linitial(rotatedList);
5637
5638 rotatedList = list_delete_first(rotatedList);
5639 rotatedList = lappend(rotatedList, firstElement);
5640 }
5641
5642 return rotatedList;
5643 }
5644
5645
5646 /*
5647 * FindDependentMergeTaskList walks over the given task's dependent task list,
5648 * finds the merge tasks in the list, and returns those found tasks in a new
5649 * list.
5650 */
5651 static List *
FindDependentMergeTaskList(Task * sqlTask)5652 FindDependentMergeTaskList(Task *sqlTask)
5653 {
5654 List *dependentMergeTaskList = NIL;
5655 List *dependentTaskList = sqlTask->dependentTaskList;
5656
5657 ListCell *dependentTaskCell = NULL;
5658 foreach(dependentTaskCell, dependentTaskList)
5659 {
5660 Task *dependentTask = (Task *) lfirst(dependentTaskCell);
5661 if (dependentTask->taskType == MERGE_TASK)
5662 {
5663 dependentMergeTaskList = lappend(dependentMergeTaskList, dependentTask);
5664 }
5665 }
5666
5667 return dependentMergeTaskList;
5668 }
5669
5670
5671 /*
5672 * AssignDualHashTaskList uses a round-robin algorithm to assign locations to
5673 * tasks; these tasks don't have any anchor shards and instead operate on (hash
5674 * repartitioned) merged tables.
5675 */
5676 static List *
AssignDualHashTaskList(List * taskList)5677 AssignDualHashTaskList(List *taskList)
5678 {
5679 List *assignedTaskList = NIL;
5680 ListCell *taskCell = NULL;
5681 Task *firstTask = (Task *) linitial(taskList);
5682 uint64 jobId = firstTask->jobId;
5683 uint32 assignedTaskIndex = 0;
5684
5685 /*
5686 * We start assigning tasks at an index determined by the jobId. This way,
5687 * if subsequent jobs have a small number of tasks, we won't allocate the
5688 * tasks to the same worker repeatedly.
5689 */
5690 List *workerNodeList = ActiveReadableNodeList();
5691 uint32 workerNodeCount = (uint32) list_length(workerNodeList);
5692 uint32 beginningNodeIndex = jobId % workerNodeCount;
5693
5694 /* sort worker node list and task list for deterministic results */
5695 workerNodeList = SortList(workerNodeList, CompareWorkerNodes);
5696 taskList = SortList(taskList, CompareTasksByTaskId);
5697
5698 foreach(taskCell, taskList)
5699 {
5700 Task *task = (Task *) lfirst(taskCell);
5701 List *taskPlacementList = NIL;
5702
5703 for (uint32 replicaIndex = 0; replicaIndex < ShardReplicationFactor;
5704 replicaIndex++)
5705 {
5706 uint32 assignmentOffset = beginningNodeIndex + assignedTaskIndex +
5707 replicaIndex;
5708 uint32 assignmentIndex = assignmentOffset % workerNodeCount;
5709 WorkerNode *workerNode = list_nth(workerNodeList, assignmentIndex);
5710
5711 ShardPlacement *taskPlacement = CitusMakeNode(ShardPlacement);
5712 SetPlacementNodeMetadata(taskPlacement, workerNode);
5713
5714 taskPlacementList = lappend(taskPlacementList, taskPlacement);
5715 }
5716
5717 task->taskPlacementList = taskPlacementList;
5718
5719 ShardPlacement *primaryPlacement = (ShardPlacement *) linitial(
5720 task->taskPlacementList);
5721 ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", task->taskId,
5722 primaryPlacement->nodeName,
5723 primaryPlacement->nodePort)));
5724
5725 assignedTaskList = lappend(assignedTaskList, task);
5726 assignedTaskIndex++;
5727 }
5728
5729 return assignedTaskList;
5730 }
5731
5732
5733 /*
5734 * SetPlacementNodeMetadata sets nodename, nodeport, nodeid and groupid for the placement.
5735 */
5736 void
SetPlacementNodeMetadata(ShardPlacement * placement,WorkerNode * workerNode)5737 SetPlacementNodeMetadata(ShardPlacement *placement, WorkerNode *workerNode)
5738 {
5739 placement->nodeName = pstrdup(workerNode->workerName);
5740 placement->nodePort = workerNode->workerPort;
5741 placement->nodeId = workerNode->nodeId;
5742 placement->groupId = workerNode->groupId;
5743 }
5744
5745
5746 /*
5747 * CompareTasksByTaskId is a helper function to compare two tasks by their taskId.
5748 */
5749 int
CompareTasksByTaskId(const void * leftElement,const void * rightElement)5750 CompareTasksByTaskId(const void *leftElement, const void *rightElement)
5751 {
5752 const Task *leftTask = *((const Task **) leftElement);
5753 const Task *rightTask = *((const Task **) rightElement);
5754
5755 uint32 leftTaskId = leftTask->taskId;
5756 uint32 rightTaskId = rightTask->taskId;
5757
5758 int taskIdDiff = leftTaskId - rightTaskId;
5759 return taskIdDiff;
5760 }
5761
5762
5763 /*
5764 * AssignDataFetchDependencies walks over tasks in the given sql or merge task
5765 * list. The function then propagates worker node assignments from each sql or
5766 * merge task to the task's data fetch dependencies.
5767 */
5768 static void
AssignDataFetchDependencies(List * taskList)5769 AssignDataFetchDependencies(List *taskList)
5770 {
5771 ListCell *taskCell = NULL;
5772 foreach(taskCell, taskList)
5773 {
5774 Task *task = (Task *) lfirst(taskCell);
5775 List *dependentTaskList = task->dependentTaskList;
5776 ListCell *dependentTaskCell = NULL;
5777
5778 Assert(task->taskPlacementList != NIL);
5779 Assert(task->taskType == READ_TASK || task->taskType == MERGE_TASK);
5780
5781 foreach(dependentTaskCell, dependentTaskList)
5782 {
5783 Task *dependentTask = (Task *) lfirst(dependentTaskCell);
5784 if (dependentTask->taskType == MAP_OUTPUT_FETCH_TASK)
5785 {
5786 dependentTask->taskPlacementList = task->taskPlacementList;
5787 }
5788 }
5789 }
5790 }
5791
5792
5793 /*
5794 * TaskListHighestTaskId walks over tasks in the given task list, finds the task
5795 * that has the largest taskId, and returns that taskId.
5796 *
5797 * Note: This function assumes that the dependent taskId's are set before the
5798 * taskId's for the given task list.
5799 */
5800 static uint32
TaskListHighestTaskId(List * taskList)5801 TaskListHighestTaskId(List *taskList)
5802 {
5803 uint32 highestTaskId = 0;
5804 ListCell *taskCell = NULL;
5805
5806 foreach(taskCell, taskList)
5807 {
5808 Task *task = (Task *) lfirst(taskCell);
5809 if (task->taskId > highestTaskId)
5810 {
5811 highestTaskId = task->taskId;
5812 }
5813 }
5814
5815 return highestTaskId;
5816 }
5817