
/*-------------------------------------------------------------------------
 *
 * multi_router_planner.c
 *
 * This file contains functions to plan multiple shard queries without any
 * aggregation step, including distributed table modifications.
 *
 * Copyright (c) Citus Data, Inc.
 *
 *-------------------------------------------------------------------------
 */

#include "postgres.h"

#include "distributed/pg_version_constants.h"

#include <stddef.h>

#include "access/stratnum.h"
#include "access/xact.h"
#include "catalog/pg_opfamily.h"
#include "catalog/pg_type.h"
#include "distributed/colocation_utils.h"
#include "distributed/citus_clauses.h"
#include "distributed/citus_nodes.h"
#include "distributed/citus_nodefuncs.h"
#include "distributed/deparse_shard_query.h"
#include "distributed/distribution_column.h"
#include "distributed/errormessage.h"
#include "distributed/log_utils.h"
#include "distributed/insert_select_planner.h"
#include "distributed/intermediate_result_pruning.h"
#include "distributed/metadata_utility.h"
#include "distributed/coordinator_protocol.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_executor.h"
#include "distributed/multi_join_order.h"
#include "distributed/multi_logical_planner.h"
#include "distributed/multi_logical_optimizer.h"
#include "distributed/multi_partitioning_utils.h"
#include "distributed/multi_physical_planner.h"
#include "distributed/multi_router_planner.h"
#include "distributed/multi_server_executor.h"
#include "distributed/listutils.h"
#include "distributed/citus_ruleutils.h"
#include "distributed/query_pushdown_planning.h"
#include "distributed/query_utils.h"
#include "distributed/reference_table_utils.h"
#include "distributed/relation_restriction_equivalence.h"
#include "distributed/relay_utility.h"
#include "distributed/recursive_planning.h"
#include "distributed/resource_lock.h"
#include "distributed/shardinterval_utils.h"
#include "distributed/shard_pruning.h"
#include "executor/execdesc.h"
#include "lib/stringinfo.h"
#include "nodes/makefuncs.h"
#include "nodes/nodeFuncs.h"
#include "nodes/nodes.h"
#include "nodes/parsenodes.h"
#include "nodes/pg_list.h"
#include "nodes/primnodes.h"
#include "optimizer/clauses.h"
#include "optimizer/joininfo.h"
#include "optimizer/pathnode.h"
#include "optimizer/paths.h"
#include "optimizer/optimizer.h"
#include "optimizer/restrictinfo.h"
#include "parser/parsetree.h"
#include "parser/parse_oper.h"
#include "postmaster/postmaster.h"
#include "storage/lock.h"
#include "utils/builtins.h"
#include "utils/elog.h"
#include "utils/errcodes.h"
#include "utils/lsyscache.h"
#include "utils/rel.h"
#include "utils/typcache.h"

#include "catalog/pg_proc.h"
#include "optimizer/planmain.h"

/* intermediate value for INSERT processing */
typedef struct InsertValues
{
	Expr *partitionValueExpr; /* partition value provided in INSERT row */
	List *rowValues;          /* full values list of INSERT row, possibly NIL */
	int64 shardId;            /* target shard for this row, possibly invalid */
	Index listIndex;          /* index to make our sorting stable */
} InsertValues;


/*
 * A ModifyRoute encapsulates the information needed to route modifications
 * to the appropriate shard. For a single-shard modification, only one route
 * is needed, but in the case of e.g. a multi-row INSERT, lists of these values
 * will help divide the rows by their destination shards, permitting later
 * shard-and-row-specific extension of the original SQL.
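 *
 * For example (values illustrative), a multi-row INSERT such as
 *   INSERT INTO orders VALUES (1, ...), (2, ...), (7, ...)
 * might produce two ModifyRoutes: one grouping the rows whose keys hash to
 * one shard, say 1 and 7, and another for the row with key 2.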
 */
typedef struct ModifyRoute
{
	int64 shardId;        /* identifier of target shard */
	List *rowValuesLists; /* for multi-row INSERTs, list of rows to be inserted */
} ModifyRoute;


typedef struct WalkerState
{
	bool containsVar;
	bool varArgument;
	bool badCoalesce;
} WalkerState;

bool EnableRouterExecution = true;


/* planner functions forward declarations */
static void CreateSingleTaskRouterSelectPlan(DistributedPlan *distributedPlan,
											 Query *originalQuery,
											 Query *query,
											 PlannerRestrictionContext *
											 plannerRestrictionContext);
static Oid ResultRelationOidForQuery(Query *query);
static bool IsTidColumn(Node *node);
static DeferredErrorMessage * ModifyPartialQuerySupported(Query *queryTree, bool
														  multiShardQuery,
														  Oid *distributedTableId);
static bool NodeIsFieldStore(Node *node);
static DeferredErrorMessage * MultiShardUpdateDeleteSupported(Query *originalQuery,
															  PlannerRestrictionContext *
															  plannerRestrictionContext);
static DeferredErrorMessage * SingleShardUpdateDeleteSupported(Query *originalQuery,
															   PlannerRestrictionContext *
															   plannerRestrictionContext);
static bool HasDangerousJoinUsing(List *rtableList, Node *jtnode);
static bool MasterIrreducibleExpression(Node *expression, bool *varArgument,
										bool *badCoalesce);
static bool MasterIrreducibleExpressionWalker(Node *expression, WalkerState *state);
static bool MasterIrreducibleExpressionFunctionChecker(Oid func_id, void *context);
static bool TargetEntryChangesValue(TargetEntry *targetEntry, Var *column,
									FromExpr *joinTree);
static Job * RouterInsertJob(Query *originalQuery);
static void ErrorIfNoShardsExist(CitusTableCacheEntry *cacheEntry);
static DeferredErrorMessage * DeferErrorIfModifyView(Query *queryTree);
static Job * CreateJob(Query *query);
static Task * CreateTask(TaskType taskType);
static Job * RouterJob(Query *originalQuery,
					   PlannerRestrictionContext *plannerRestrictionContext,
					   DeferredErrorMessage **planningError);
static bool RelationPrunesToMultipleShards(List *relationShardList);
static void NormalizeMultiRowInsertTargetList(Query *query);
static void AppendNextDummyColReference(Alias *expandedReferenceNames);
static Value * MakeDummyColumnString(int dummyColumnId);
static List * BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError);
static List * GroupInsertValuesByShardId(List *insertValuesList);
static List * ExtractInsertValuesList(Query *query, Var *partitionColumn);
static DeferredErrorMessage * DeferErrorIfUnsupportedRouterPlannableSelectQuery(
	Query *query);
static DeferredErrorMessage * ErrorIfQueryHasUnroutableModifyingCTE(Query *queryTree);
#if PG_VERSION_NUM >= PG_VERSION_14
static DeferredErrorMessage * ErrorIfQueryHasCTEWithSearchClause(Query *queryTree);
static bool ContainsSearchClauseWalker(Node *node);
#endif
static bool SelectsFromDistributedTable(List *rangeTableList, Query *query);
static ShardPlacement * CreateDummyPlacement(bool hasLocalRelation);
static ShardPlacement * CreateLocalDummyPlacement();
static int CompareInsertValuesByShardId(const void *leftElement,
										const void *rightElement);
static List * SingleShardTaskList(Query *query, uint64 jobId,
								  List *relationShardList, List *placementList,
								  uint64 shardId, bool parametersInQueryResolved,
								  bool isLocalTableModification);
static bool RowLocksOnRelations(Node *node, List **rtiLockList);
static void ReorderTaskPlacementsByTaskAssignmentPolicy(Job *job,
														TaskAssignmentPolicyType
														taskAssignmentPolicy,
														List *placementList);
static bool ModifiesLocalTableWithRemoteCitusLocalTable(List *rangeTableList);
static DeferredErrorMessage * DeferErrorIfUnsupportedLocalTableJoin(List *rangeTableList);
static bool IsLocallyAccessibleCitusLocalTable(Oid relationId);


/*
 * CreateRouterPlan attempts to create a router executor plan for the given
 * SELECT statement. ->planningError is set if planning fails.
 */
DistributedPlan *
CreateRouterPlan(Query *originalQuery, Query *query,
				 PlannerRestrictionContext *plannerRestrictionContext)
{
	DistributedPlan *distributedPlan = CitusMakeNode(DistributedPlan);

	distributedPlan->planningError = DeferErrorIfUnsupportedRouterPlannableSelectQuery(
		query);

	if (distributedPlan->planningError == NULL)
	{
		CreateSingleTaskRouterSelectPlan(distributedPlan, originalQuery, query,
										 plannerRestrictionContext);
	}

	distributedPlan->fastPathRouterPlan =
		plannerRestrictionContext->fastPathRestrictionContext->fastPathRouterQuery;

	return distributedPlan;
}


/*
 * CreateModifyPlan attempts to create a plan for the given modification
 * statement. If planning fails ->planningError is set to a description of
 * the failure.
 */
DistributedPlan *
CreateModifyPlan(Query *originalQuery, Query *query,
				 PlannerRestrictionContext *plannerRestrictionContext)
{
	Job *job = NULL;
	DistributedPlan *distributedPlan = CitusMakeNode(DistributedPlan);
	bool multiShardQuery = false;

	Assert(originalQuery->commandType != CMD_SELECT);

	distributedPlan->modLevel = RowModifyLevelForQuery(query);

	distributedPlan->planningError = ModifyQuerySupported(query, originalQuery,
														  multiShardQuery,
														  plannerRestrictionContext);

	if (distributedPlan->planningError != NULL)
	{
		return distributedPlan;
	}

	if (UpdateOrDeleteQuery(query))
	{
		job = RouterJob(originalQuery, plannerRestrictionContext,
						&distributedPlan->planningError);
	}
	else
	{
		job = RouterInsertJob(originalQuery);
	}

	if (distributedPlan->planningError != NULL)
	{
		return distributedPlan;
	}

	ereport(DEBUG2, (errmsg("Creating router plan")));

	distributedPlan->workerJob = job;
	distributedPlan->combineQuery = NULL;
	distributedPlan->expectResults = originalQuery->returningList != NIL;
	distributedPlan->targetRelationId = ResultRelationOidForQuery(query);

	distributedPlan->fastPathRouterPlan =
		plannerRestrictionContext->fastPathRestrictionContext->fastPathRouterQuery;


	return distributedPlan;
}


/*
 * CreateSingleTaskRouterSelectPlan creates a physical plan for the given
 * SELECT query. The returned plan is a router task that returns query results
 * from a single worker. If the query is not router plannable, the returned
 * plan's planningError describes the problem.
 */
static void
CreateSingleTaskRouterSelectPlan(DistributedPlan *distributedPlan, Query *originalQuery,
								 Query *query,
								 PlannerRestrictionContext *plannerRestrictionContext)
{
	Assert(query->commandType == CMD_SELECT);

	distributedPlan->modLevel = RowModifyLevelForQuery(query);

	Job *job = RouterJob(originalQuery, plannerRestrictionContext,
						 &distributedPlan->planningError);

	if (distributedPlan->planningError != NULL)
	{
		/* query cannot be handled by this planner */
		return;
	}

	ereport(DEBUG2, (errmsg("Creating router plan")));

	distributedPlan->workerJob = job;
	distributedPlan->combineQuery = NULL;
	distributedPlan->expectResults = true;
}


/*
 * ShardIntervalOpExpressions returns a list containing a single constraint
 * that combines two OpExprs restricting the partition column to the shard
 * interval's range, i.e. (partitionColumn >= shardMinValue) AND
 * (partitionColumn <= shardMaxValue).
 *
 * The function uses hashed columns generated by MakeInt4Column() for hash
 * partitioned tables in place of partition columns.
 *
 * The function returns NIL if the shard interval does not belong to a hash,
 * range or append distributed table.
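 *
 * For example (shard bounds illustrative), for a hash-distributed table a
 * shard covering the hash range [-2147483648, -1073741825] yields
 * (hashedColumn >= -2147483648 AND hashedColumn <= -1073741825).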
 *
 * NB: If you update this, also look at PrunableExpressionsWalker().
 */
List *
ShardIntervalOpExpressions(ShardInterval *shardInterval, Index rteIndex)
{
	Oid relationId = shardInterval->relationId;
	Var *partitionColumn = NULL;

	if (IsCitusTableType(relationId, HASH_DISTRIBUTED))
	{
		partitionColumn = MakeInt4Column();
	}
	else if (IsCitusTableType(relationId, RANGE_DISTRIBUTED) || IsCitusTableType(
				 relationId, APPEND_DISTRIBUTED))
	{
		Assert(rteIndex > 0);
		partitionColumn = PartitionColumn(relationId, rteIndex);
	}
	else
	{
		/* do not add any shard range interval for reference tables */
		return NIL;
	}

	/* build the base expression for constraint */
	Node *baseConstraint = BuildBaseConstraint(partitionColumn);

	/* walk over shard list and check if shards can be pruned */
	if (shardInterval->minValueExists && shardInterval->maxValueExists)
	{
		UpdateConstraint(baseConstraint, shardInterval);
	}

	return list_make1(baseConstraint);
}


/*
 * AddPartitionKeyNotNullFilterToSelect adds the following filter to a subquery:
 *
 *    partitionColumn IS NOT NULL
 *
 * The function expects and asserts that the subquery's target list contains a
 * partition column value. Thus, this function should never be called with
 * reference tables.
 */
void
AddPartitionKeyNotNullFilterToSelect(Query *subquery)
{
	List *targetList = subquery->targetList;
	ListCell *targetEntryCell = NULL;
	Var *targetPartitionColumnVar = NULL;

	/* iterate through the target entries */
	foreach(targetEntryCell, targetList)
	{
		TargetEntry *targetEntry = lfirst(targetEntryCell);

		bool skipOuterVars = true;
		if (IsPartitionColumn(targetEntry->expr, subquery, skipOuterVars) &&
			IsA(targetEntry->expr, Var))
		{
			targetPartitionColumnVar = (Var *) targetEntry->expr;
			break;
		}
	}

	/* we should have found the target partition column */
	Assert(targetPartitionColumnVar != NULL);

	/* create expression for partition_column IS NOT NULL */
	NullTest *nullTest = makeNode(NullTest);
	nullTest->nulltesttype = IS_NOT_NULL;
	nullTest->arg = (Expr *) targetPartitionColumnVar;
	nullTest->argisrow = false;

	/* finally add the quals */
	if (subquery->jointree->quals == NULL)
	{
		subquery->jointree->quals = (Node *) nullTest;
	}
	else
	{
		subquery->jointree->quals = make_and_qual(subquery->jointree->quals,
												  (Node *) nullTest);
	}
}


/*
 * ExtractSelectRangeTableEntry returns the range table entry of the subquery.
 * Note that the function expects and asserts that the input query is an
 * INSERT...SELECT query.
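 *
 * For example (illustrative), given INSERT INTO target SELECT * FROM source,
 * this returns the subquery range table entry that wraps the SELECT.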
 */
RangeTblEntry *
ExtractSelectRangeTableEntry(Query *query)
{
	Assert(InsertSelectIntoCitusTable(query) || InsertSelectIntoLocalTable(query));

	/*
	 * Since we already asserted InsertSelectIntoCitusTable() it is safe to access
	 * both lists
	 */
	List *fromList = query->jointree->fromlist;
	RangeTblRef *reference = linitial(fromList);
	RangeTblEntry *subqueryRte = rt_fetch(reference->rtindex, query->rtable);

	return subqueryRte;
}


/*
 * ModifyQueryResultRelationId returns the result relation's Oid
 * for the given modification query.
 *
 * The function errors out if the input query is not a
 * modify query (e.g., INSERT, UPDATE or DELETE). So, this
 * function is not expected to be called on SELECT queries.
 */
Oid
ModifyQueryResultRelationId(Query *query)
{
	/* only modify queries have result relations */
	if (!IsModifyCommand(query))
	{
		ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
						errmsg("input query is not a modification query")));
	}

	RangeTblEntry *resultRte = ExtractResultRelationRTE(query);
	Assert(OidIsValid(resultRte->relid));

	return resultRte->relid;
}


/*
 * ResultRelationOidForQuery returns the OID of the relation that is modified
 * by the given query.
 */
static Oid
ResultRelationOidForQuery(Query *query)
{
	RangeTblEntry *resultRTE = rt_fetch(query->resultRelation, query->rtable);

	return resultRTE->relid;
}


/*
 * ExtractResultRelationRTE returns the table's resultRelation range table
 * entry. This returns NULL when there's no resultRelation, such as in a SELECT
 * query.
 */
RangeTblEntry *
ExtractResultRelationRTE(Query *query)
{
	if (query->resultRelation > 0)
	{
		return rt_fetch(query->resultRelation, query->rtable);
	}

	return NULL;
}


/*
 * ExtractResultRelationRTEOrError returns the table's resultRelation range table
 * entry and errors out if there's no result relation at all, e.g. like in a
 * SELECT query.
 *
 * This is a separate function (instead of using missingOk), so static analysis
 * reasons about NULL returns correctly.
 */
RangeTblEntry *
ExtractResultRelationRTEOrError(Query *query)
{
	RangeTblEntry *relation = ExtractResultRelationRTE(query);
	if (relation == NULL)
	{
		ereport(ERROR, (errmsg("no result relation could be found for the query"),
						errhint("is this a SELECT query?")));
	}

	return relation;
}


/*
 * IsTidColumn returns true if the given node is a Var of type TID.
 */
static bool
IsTidColumn(Node *node)
{
	if (IsA(node, Var))
	{
		Var *column = (Var *) node;
		if (column->vartype == TIDOID)
		{
			return true;
		}
	}

	return false;
}


/*
 * ModifyPartialQuerySupported implements a subset of what ModifyQuerySupported checks,
 * that subset being what's necessary to check modifying CTEs for.
 */
static DeferredErrorMessage *
ModifyPartialQuerySupported(Query *queryTree, bool multiShardQuery,
							Oid *distributedTableIdOutput)
{
	DeferredErrorMessage *deferredError = DeferErrorIfModifyView(queryTree);
	if (deferredError != NULL)
	{
		return deferredError;
	}
	CmdType commandType = queryTree->commandType;

	deferredError = DeferErrorIfUnsupportedLocalTableJoin(queryTree->rtable);
	if (deferredError != NULL)
	{
		return deferredError;
	}

	/*
	 * Reject subqueries which are in the SELECT or WHERE clause.
	 * Queries which include subqueries in the FROM clause are rejected below.
	 */
	if (queryTree->hasSubLinks == true)
	{
		/* we support subqueries for INSERTs only via INSERT INTO ... SELECT */
		if (!UpdateOrDeleteQuery(queryTree))
		{
			Assert(queryTree->commandType == CMD_INSERT);

			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "subqueries are not supported within INSERT queries",
								 NULL, "Try rewriting your queries with 'INSERT "
									   "INTO ... SELECT' syntax.");
		}
	}

	/* reject queries which include CommonTableExprs that aren't routable */
	if (queryTree->cteList != NIL)
	{
		ListCell *cteCell = NULL;

		/* CTEs are still not supported for INSERTs. */
		if (queryTree->commandType == CMD_INSERT)
		{
			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "Router planner doesn't support common table expressions with INSERT queries.",
								 NULL, NULL);
		}

		foreach(cteCell, queryTree->cteList)
		{
			CommonTableExpr *cte = (CommonTableExpr *) lfirst(cteCell);
			Query *cteQuery = (Query *) cte->ctequery;

			if (cteQuery->commandType != CMD_SELECT)
			{
				/* Modifying CTEs are still not supported for multi-shard queries. */
				if (multiShardQuery)
				{
					return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
										 "Router planner doesn't support non-select common table expressions with multi shard queries.",
										 NULL, NULL);
				}
				/* Modifying CTEs exclude both INSERT CTEs & INSERT queries. */
				else if (cteQuery->commandType == CMD_INSERT)
				{
					return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
										 "Router planner doesn't support INSERT common table expressions.",
										 NULL, NULL);
				}
			}

			if (cteQuery->hasForUpdate &&
				FindNodeMatchingCheckFunctionInRangeTableList(cteQuery->rtable,
															  IsReferenceTableRTE))
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "Router planner doesn't support SELECT FOR UPDATE"
									 " in common table expressions involving reference tables.",
									 NULL, NULL);
			}

			if (FindNodeMatchingCheckFunction((Node *) cteQuery, CitusIsVolatileFunction))
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "Router planner doesn't support VOLATILE functions"
									 " in common table expressions.",
									 NULL, NULL);
			}

			if (cteQuery->commandType == CMD_SELECT)
			{
				DeferredErrorMessage *cteError =
					DeferErrorIfUnsupportedRouterPlannableSelectQuery(cteQuery);
				if (cteError)
				{
					return cteError;
				}
			}
		}
	}


	Oid resultRelationId = ModifyQueryResultRelationId(queryTree);
	*distributedTableIdOutput = resultRelationId;
	uint32 rangeTableId = 1;

	Var *partitionColumn = NULL;
	if (IsCitusTable(resultRelationId))
	{
		partitionColumn = PartitionColumn(resultRelationId, rangeTableId);
	}
	commandType = queryTree->commandType;
	if (commandType == CMD_INSERT || commandType == CMD_UPDATE ||
		commandType == CMD_DELETE)
	{
		bool hasVarArgument = false; /* a STABLE function is passed a Var argument */
		bool hasBadCoalesce = false; /* CASE/COALESCE passed a mutable function */
		FromExpr *joinTree = queryTree->jointree;
		ListCell *targetEntryCell = NULL;

		foreach(targetEntryCell, queryTree->targetList)
		{
			TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);

			/* skip resjunk entries: UPDATE adds some for ctid, etc. */
			if (targetEntry->resjunk)
			{
				continue;
			}

			bool targetEntryPartitionColumn = false;
			AttrNumber targetColumnAttrNumber = InvalidAttrNumber;

			/* reference tables do not have a partition column */
			if (partitionColumn == NULL)
			{
				targetEntryPartitionColumn = false;
			}
			else
			{
				if (commandType == CMD_UPDATE)
				{
					/*
					 * Note that it is not possible to give an alias to
					 * UPDATE table SET ...
					 */
					if (targetEntry->resname)
					{
						targetColumnAttrNumber = get_attnum(resultRelationId,
															targetEntry->resname);
						if (targetColumnAttrNumber == partitionColumn->varattno)
						{
							targetEntryPartitionColumn = true;
						}
					}
				}
			}


			if (commandType == CMD_UPDATE &&
				FindNodeMatchingCheckFunction((Node *) targetEntry->expr,
											  CitusIsVolatileFunction))
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "functions used in UPDATE queries on distributed "
									 "tables must not be VOLATILE",
									 NULL, NULL);
			}

			if (commandType == CMD_UPDATE && targetEntryPartitionColumn &&
				TargetEntryChangesValue(targetEntry, partitionColumn,
										queryTree->jointree))
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "modifying the partition value of rows is not "
									 "allowed",
									 NULL, NULL);
			}

			if (commandType == CMD_UPDATE &&
				MasterIrreducibleExpression((Node *) targetEntry->expr,
											&hasVarArgument, &hasBadCoalesce))
			{
				Assert(hasVarArgument || hasBadCoalesce);
			}

			if (FindNodeMatchingCheckFunction((Node *) targetEntry->expr,
											  NodeIsFieldStore))
			{
				/* DELETE cannot do field indirection already */
				Assert(commandType == CMD_UPDATE || commandType == CMD_INSERT);
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "inserting or modifying composite type fields is not "
									 "supported", NULL,
									 "Use the column name to insert or update the composite "
									 "type as a single value");
			}
		}

		if (joinTree != NULL)
		{
			if (FindNodeMatchingCheckFunction((Node *) joinTree->quals,
											  CitusIsVolatileFunction))
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "functions used in the WHERE clause of modification "
									 "queries on distributed tables must not be VOLATILE",
									 NULL, NULL);
			}
			else if (MasterIrreducibleExpression(joinTree->quals, &hasVarArgument,
												 &hasBadCoalesce))
			{
				Assert(hasVarArgument || hasBadCoalesce);
			}
		}

		if (hasVarArgument)
		{
			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "STABLE functions used in UPDATE queries "
								 "cannot be called with column references",
								 NULL, NULL);
		}

		if (hasBadCoalesce)
		{
			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "non-IMMUTABLE functions are not allowed in CASE or "
								 "COALESCE statements",
								 NULL, NULL);
		}

		if (contain_mutable_functions((Node *) queryTree->returningList))
		{
			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "non-IMMUTABLE functions are not allowed in the "
								 "RETURNING clause",
								 NULL, NULL);
		}

		if (queryTree->jointree->quals != NULL &&
			nodeTag(queryTree->jointree->quals) == T_CurrentOfExpr)
		{
			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "cannot run DML queries with cursors", NULL,
								 NULL);
		}
	}

	deferredError = ErrorIfOnConflictNotSupported(queryTree);
	if (deferredError != NULL)
	{
		return deferredError;
	}


	/* set it for the caller to use when we don't return any errors */
	*distributedTableIdOutput = resultRelationId;

	return NULL;
}


/*
 * DeferErrorIfUnsupportedLocalTableJoin returns an error message
 * if there is an unsupported join in the given range table list.
 */
static DeferredErrorMessage *
DeferErrorIfUnsupportedLocalTableJoin(List *rangeTableList)
{
	if (ModifiesLocalTableWithRemoteCitusLocalTable(rangeTableList))
	{
		return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
							 "Modifying local tables with remote local tables is "
							 "not supported.",
							 NULL,
							 "Consider wrapping the remote local table in a CTE, or subquery");
	}
	return NULL;
}


/*
 * ModifiesLocalTableWithRemoteCitusLocalTable returns true if a local
 * table is modified with a remote citus local table. This can be the case
 * in an MX setup.
 */
static bool
ModifiesLocalTableWithRemoteCitusLocalTable(List *rangeTableList)
{
	bool containsLocalResultRelation = false;
	bool containsRemoteCitusLocalTable = false;

	RangeTblEntry *rangeTableEntry = NULL;
	foreach_ptr(rangeTableEntry, rangeTableList)
	{
		if (!IsRecursivelyPlannableRelation(rangeTableEntry))
		{
			continue;
		}
		if (IsCitusTableType(rangeTableEntry->relid, CITUS_LOCAL_TABLE))
		{
			if (!IsLocallyAccessibleCitusLocalTable(rangeTableEntry->relid))
			{
				containsRemoteCitusLocalTable = true;
			}
		}
		else if (!IsCitusTable(rangeTableEntry->relid))
		{
			containsLocalResultRelation = true;
		}
	}
	return containsLocalResultRelation && containsRemoteCitusLocalTable;
}


/*
 * IsLocallyAccessibleCitusLocalTable returns true if the given table
 * is a citus local table that can be accessed using local execution.
 */
static bool
IsLocallyAccessibleCitusLocalTable(Oid relationId)
{
	if (!IsCitusTableType(relationId, CITUS_LOCAL_TABLE))
	{
		return false;
	}

	List *shardIntervalList = LoadShardIntervalList(relationId);

	/*
	 * Citus local tables should always have exactly one shard, but we have
	 * this check for safety.
	 */
	if (list_length(shardIntervalList) != 1)
	{
		return false;
	}

	ShardInterval *shardInterval = linitial(shardIntervalList);
	uint64 shardId = shardInterval->shardId;
	ShardPlacement *localShardPlacement =
		ActiveShardPlacementOnGroup(GetLocalGroupId(), shardId);
	return localShardPlacement != NULL;
}


/*
 * NodeIsFieldStore returns true if the given Node is a FieldStore object.
 */
static bool
NodeIsFieldStore(Node *node)
{
	return node && IsA(node, FieldStore);
}


/*
 * ModifyQuerySupported returns NULL if the query only contains supported
 * features, otherwise it returns an error description.
 * Note that we need both the original query and the modified one because
 * different checks need different versions. In particular, we cannot
 * perform the ContainsReadIntermediateResultFunction check on the
 * rewritten query because it may have been replaced by a subplan,
 * while some of the checks for setting the partition column value rely
 * on the rewritten query.
 */
DeferredErrorMessage *
ModifyQuerySupported(Query *queryTree, Query *originalQuery, bool multiShardQuery,
					 PlannerRestrictionContext *plannerRestrictionContext)
{
	Oid distributedTableId = InvalidOid;
	DeferredErrorMessage *error = ModifyPartialQuerySupported(queryTree, multiShardQuery,
															  &distributedTableId);
	if (error)
	{
		return error;
	}

	List *rangeTableList = NIL;
	uint32 queryTableCount = 0;
	CmdType commandType = queryTree->commandType;
	bool fastPathRouterQuery =
		plannerRestrictionContext->fastPathRestrictionContext->fastPathRouterQuery;

	/*
	 * Here, we check if a recursively planned query tries to modify
	 * rows based on the ctid column. This is a bad idea because the ctid
	 * of the rows could change before the modification part of
	 * the query is executed.
	 *
	 * We can exclude fast path queries since they cannot have intermediate
	 * results by definition.
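	 *
	 * As an illustrative sketch, UPDATE t SET a = 1 WHERE ctid IN
	 * (SELECT ctid FROM ...), where the subquery was replaced by an
	 * intermediate result, could target rows whose ctids have already
	 * changed by the time the UPDATE runs.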
	 */
	if (!fastPathRouterQuery &&
		ContainsReadIntermediateResultFunction((Node *) originalQuery))
	{
		bool hasTidColumn = FindNodeMatchingCheckFunction(
			(Node *) originalQuery->jointree, IsTidColumn);

		if (hasTidColumn)
		{
			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "cannot perform distributed planning for the given "
								 "modification",
								 "Recursively planned distributed modifications "
								 "with ctid on where clause are not supported.",
								 NULL);
		}
	}

	/*
	 * Extract range table entries for queries that are not fast path. We can skip
	 * fast path queries because by definition they consist of a single RTE, which
	 * is a relation, so the following check doesn't apply to them.
	 */
	if (!fastPathRouterQuery)
	{
		ExtractRangeTableEntryWalker((Node *) originalQuery, &rangeTableList);
	}
	bool containsLocalTableDistributedTableJoin =
		ContainsLocalTableDistributedTableJoin(queryTree->rtable);

	RangeTblEntry *rangeTableEntry = NULL;
	foreach_ptr(rangeTableEntry, rangeTableList)
	{
		if (rangeTableEntry->rtekind == RTE_RELATION)
		{
			/* we do not expect to see a view in a modify query */
			if (rangeTableEntry->relkind == RELKIND_VIEW)
			{
				/*
				 * We already check whether the modify is run on a view in the
				 * DeferErrorIfModifyView function call. In addition, since Postgres
				 * replaces views in the FROM clause with subqueries, encountering
				 * a view should not be a problem here.
				 */
			}
			else if (rangeTableEntry->relkind == RELKIND_MATVIEW)
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "materialized views in modify queries are not supported",
									 NULL, NULL);
			}
			/* for other kinds of relations, check if it is distributed */
			else
			{
				if (IsRelationLocalTableOrMatView(rangeTableEntry->relid) &&
					containsLocalTableDistributedTableJoin)
				{
					StringInfo errorMessage = makeStringInfo();
					char *relationName = get_rel_name(rangeTableEntry->relid);
					if (IsCitusTable(rangeTableEntry->relid))
					{
						appendStringInfo(errorMessage,
										 "local table %s cannot be joined with these distributed tables",
										 relationName);
					}
					else
					{
						appendStringInfo(errorMessage, "relation %s is not distributed",
										 relationName);
					}
					return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
										 errorMessage->data, NULL, NULL);
				}
			}

			queryTableCount++;
		}
		else if (rangeTableEntry->rtekind == RTE_VALUES ||
				 rangeTableEntry->rtekind == RTE_RESULT)
		{
			/* do nothing, this type is supported */
		}
		else
		{
			char *rangeTableEntryErrorDetail = NULL;

			/*
			 * We support UPDATE and DELETE with subqueries and joins unless
			 * they are multi shard queries.
			 */
			if (UpdateOrDeleteQuery(queryTree))
			{
				continue;
			}

			/*
			 * Error out for rangeTableEntries that we do not support.
			 * We do not explicitly specify "in FROM clause" in the error detail
			 * for the features that we do not support at all (SUBQUERY, JOIN).
			 */
			if (rangeTableEntry->rtekind == RTE_SUBQUERY)
			{
				StringInfo errorHint = makeStringInfo();
				CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(
					distributedTableId);
				char *partitionKeyString = cacheEntry->partitionKeyString;
				char *partitionColumnName = ColumnToColumnName(distributedTableId,
															   partitionKeyString);

				appendStringInfo(errorHint, "Consider using an equality filter on "
											"partition column \"%s\" to target a single shard.",
								 partitionColumnName);

				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, "subqueries are not "
																	"supported in modifications across multiple shards",
									 errorHint->data, NULL);
			}
			else if (rangeTableEntry->rtekind == RTE_JOIN)
			{
				rangeTableEntryErrorDetail = "Joins are not supported in distributed"
											 " modifications.";
			}
			else if (rangeTableEntry->rtekind == RTE_FUNCTION)
			{
				rangeTableEntryErrorDetail = "Functions must not appear in the FROM"
											 " clause of distributed modifications.";
			}
			else if (rangeTableEntry->rtekind == RTE_CTE)
			{
				rangeTableEntryErrorDetail = "Common table expressions are not supported"
											 " in distributed modifications.";
			}
			else
			{
				rangeTableEntryErrorDetail = "Unrecognized range table entry.";
			}

			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "cannot perform distributed planning for the given "
								 "modifications",
								 rangeTableEntryErrorDetail,
								 NULL);
		}
	}

	if (commandType != CMD_INSERT)
	{
		DeferredErrorMessage *errorMessage = NULL;

		if (multiShardQuery)
		{
			errorMessage = MultiShardUpdateDeleteSupported(originalQuery,
														   plannerRestrictionContext);
		}
		else
		{
			errorMessage = SingleShardUpdateDeleteSupported(originalQuery,
															plannerRestrictionContext);
		}

		if (errorMessage != NULL)
		{
			return errorMessage;
		}
	}

#if PG_VERSION_NUM >= PG_VERSION_14
	DeferredErrorMessage *CTEWithSearchClauseError =
		ErrorIfQueryHasCTEWithSearchClause(originalQuery);
	if (CTEWithSearchClauseError != NULL)
	{
		return CTEWithSearchClauseError;
	}
#endif

	return NULL;
}


/*
 * Modify statements on simple updatable views are not supported yet.
 * Actually, we need the original query (the query before postgres
 * pg_rewrite_query) to detect whether the view sitting in the rtable is to
 * be updated or just to be used in the FROM clause.
 * Hence, tracing the postgres source code, we deduced that postgres
 * puts the relation to be modified in the first entry of the rtable.
 * If the first element of the range table list is a simple updatable
 * view and this view is not coming from the FROM clause (inFromCl = False),
 * then the update is run "on" that view.
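 *
 * As an illustrative example, UPDATE some_view SET a = 1 is caught here,
 * whereas a view that merely appears in the FROM clause of the modification
 * is not, since it then has inFromCl = true.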
 */
static DeferredErrorMessage *
DeferErrorIfModifyView(Query *queryTree)
{
	if (queryTree->rtable != NIL)
	{
		RangeTblEntry *firstRangeTableElement = (RangeTblEntry *) linitial(
			queryTree->rtable);

		if (firstRangeTableElement->rtekind == RTE_RELATION &&
			firstRangeTableElement->relkind == RELKIND_VIEW &&
			firstRangeTableElement->inFromCl == false)
		{
			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "cannot modify views when the query contains citus tables",
								 NULL,
								 NULL);
		}
	}

	return NULL;
}


/*
 * ErrorIfOnConflictNotSupported returns an error if an INSERT query has an
 * unsupported ON CONFLICT clause. In particular, changing the partition
 * column value or using volatile functions is not allowed.
 */
DeferredErrorMessage *
ErrorIfOnConflictNotSupported(Query *queryTree)
{
	uint32 rangeTableId = 1;
	ListCell *setTargetCell = NULL;
	bool specifiesPartitionValue = false;

	CmdType commandType = queryTree->commandType;
	if (commandType != CMD_INSERT || queryTree->onConflict == NULL)
	{
		return NULL;
	}

	Oid distributedTableId = ExtractFirstCitusTableId(queryTree);
	Var *partitionColumn = PartitionColumn(distributedTableId, rangeTableId);

	List *onConflictSet = queryTree->onConflict->onConflictSet;
	Node *arbiterWhere = queryTree->onConflict->arbiterWhere;
	Node *onConflictWhere = queryTree->onConflict->onConflictWhere;

	/*
	 * onConflictSet is expanded via expand_targetlist() by the standard planner.
	 * This ends up adding all the columns to the onConflictSet even if the user
	 * does not explicitly state the columns in the query.
	 *
	 * The following loop simply allows "DO UPDATE SET part_col = table.part_col"
	 * types of elements in the target list, which are added by expand_targetlist().
	 * Any other attempt to update the partition column value is forbidden.
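	 *
	 * For example (illustrative), ... DO UPDATE SET part_col = 5 is rejected
	 * below, while the implicit part_col = table.part_col entry added by
	 * expand_targetlist() is allowed.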
	 */
	foreach(setTargetCell, onConflictSet)
	{
		TargetEntry *setTargetEntry = (TargetEntry *) lfirst(setTargetCell);
		bool setTargetEntryPartitionColumn = false;

		/* reference tables do not have a partition column */
		if (partitionColumn == NULL)
		{
			setTargetEntryPartitionColumn = false;
		}
		else
		{
			Oid resultRelationId = ModifyQueryResultRelationId(queryTree);

			AttrNumber targetColumnAttrNumber = InvalidAttrNumber;
			if (setTargetEntry->resname)
			{
				targetColumnAttrNumber = get_attnum(resultRelationId,
													setTargetEntry->resname);
				if (targetColumnAttrNumber == partitionColumn->varattno)
				{
					setTargetEntryPartitionColumn = true;
				}
			}
		}
		if (setTargetEntryPartitionColumn)
		{
			Expr *setExpr = setTargetEntry->expr;
			if (IsA(setExpr, Var) &&
				((Var *) setExpr)->varattno == partitionColumn->varattno)
			{
				specifiesPartitionValue = false;
			}
			else
			{
				specifiesPartitionValue = true;
			}
		}
		else
		{
			/*
			 * Similarly, allow "DO UPDATE SET col_1 = table.col_1" types of
			 * target list elements. Note that the following check allows
			 * "DO UPDATE SET col_1 = table.col_2", which is not harmful.
			 */
			if (IsA(setTargetEntry->expr, Var))
			{
				continue;
			}
			else if (contain_mutable_functions((Node *) setTargetEntry->expr))
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "functions used in the DO UPDATE SET clause of "
									 "INSERTs on distributed tables must be marked "
									 "IMMUTABLE",
									 NULL, NULL);
			}
		}
	}

	/* error if either the arbiter or the ON CONFLICT WHERE contains a mutable function */
	if (contain_mutable_functions((Node *) arbiterWhere) ||
		contain_mutable_functions((Node *) onConflictWhere))
	{
		return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
							 "functions used in the WHERE clause of the "
							 "ON CONFLICT clause of INSERTs on distributed "
							 "tables must be marked IMMUTABLE",
							 NULL, NULL);
	}

	if (specifiesPartitionValue)
	{
		return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
							 "modifying the partition value of rows is not "
							 "allowed",
							 NULL, NULL);
	}

	return NULL;
}


/*
 * MultiShardUpdateDeleteSupported returns an error message if the update/delete
 * is not pushdownable, otherwise it returns NULL.
 */
static DeferredErrorMessage *
MultiShardUpdateDeleteSupported(Query *originalQuery,
								PlannerRestrictionContext *plannerRestrictionContext)
{
	DeferredErrorMessage *errorMessage = NULL;
	RangeTblEntry *resultRangeTable = ExtractResultRelationRTE(originalQuery);
	Oid resultRelationOid = resultRangeTable->relid;

	if (HasDangerousJoinUsing(originalQuery->rtable, (Node *) originalQuery->jointree))
	{
		errorMessage = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "a join with USING causes an internal naming conflict, use "
									 "ON instead",
									 NULL, NULL);
	}
	else if (FindNodeMatchingCheckFunction((Node *) originalQuery,
										   CitusIsVolatileFunction))
	{
		errorMessage = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "functions used in UPDATE queries on distributed "
									 "tables must not be VOLATILE",
									 NULL, NULL);
	}
	else if (IsCitusTableType(resultRelationOid, REFERENCE_TABLE))
	{
		errorMessage = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "only reference tables may be queried when targeting "
									 "a reference table with multi shard UPDATE/DELETE queries "
									 "with multiple tables",
									 NULL, NULL);
	}
	else
	{
		errorMessage = DeferErrorIfUnsupportedSubqueryPushdown(originalQuery,
															   plannerRestrictionContext);
	}

	return errorMessage;
}


/*
 * SingleShardUpdateDeleteSupported returns an error message if the update/delete
 * query is not routable, otherwise it returns NULL.
 */
static DeferredErrorMessage *
SingleShardUpdateDeleteSupported(Query *originalQuery,
								 PlannerRestrictionContext *plannerRestrictionContext)
{
	DeferredErrorMessage *errorMessage = NULL;

	/*
	 * We currently do not support volatile functions in update/delete statements
	 * because the function evaluation logic does not know how to distinguish
	 * volatile functions (that need to be evaluated per row) from stable functions
	 * (that need to be evaluated per query), and it is also not safe to push
	 * volatile functions down to replicated tables.
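	 *
	 * For example (illustrative), UPDATE t SET a = random() is rejected
	 * because random() would have to be evaluated once per row on the shard.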
	 */
	if (FindNodeMatchingCheckFunction((Node *) originalQuery,
									  CitusIsVolatileFunction))
	{
		errorMessage = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "functions used in UPDATE queries on distributed "
									 "tables must not be VOLATILE",
									 NULL, NULL);
	}

	return errorMessage;
}


/*
 * HasDangerousJoinUsing searches the join tree for an unnamed JOIN USING. See
 * the implementation of has_dangerous_join_using in ruleutils.
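 *
 * An illustrative example of the dangerous case: in
 *   SELECT * FROM (tbl_a FULL JOIN tbl_b USING (x)),
 * the merged column x is a COALESCE(tbl_a.x, tbl_b.x) rather than a plain
 * Var, so deparsing without a unique join alias could change the meaning.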
 */
static bool
HasDangerousJoinUsing(List *rtableList, Node *joinTreeNode)
{
	if (IsA(joinTreeNode, RangeTblRef))
	{
		/* nothing to do here */
	}
	else if (IsA(joinTreeNode, FromExpr))
	{
		FromExpr *fromExpr = (FromExpr *) joinTreeNode;
		ListCell *listCell;

		foreach(listCell, fromExpr->fromlist)
		{
			if (HasDangerousJoinUsing(rtableList, (Node *) lfirst(listCell)))
			{
				return true;
			}
		}
	}
	else if (IsA(joinTreeNode, JoinExpr))
	{
		JoinExpr *joinExpr = (JoinExpr *) joinTreeNode;

		/* Is it an unnamed JOIN with USING? */
		if (joinExpr->alias == NULL && joinExpr->usingClause)
		{
			/*
			 * Yes, so check each join alias var to see if any of them are not
			 * simple references to underlying columns. If so, we have a
			 * dangerous situation and must pick unique aliases.
			 */
			RangeTblEntry *joinRTE = rt_fetch(joinExpr->rtindex, rtableList);
			ListCell *listCell;

			foreach(listCell, joinRTE->joinaliasvars)
			{
				Var *aliasVar = (Var *) lfirst(listCell);

				if (aliasVar != NULL && !IsA(aliasVar, Var))
				{
					return true;
				}
			}
		}

		/* Nope, but inspect children */
		if (HasDangerousJoinUsing(rtableList, joinExpr->larg))
		{
			return true;
		}
		if (HasDangerousJoinUsing(rtableList, joinExpr->rarg))
		{
			return true;
		}
	}
	else
	{
		elog(ERROR, "unrecognized node type: %d",
			 (int) nodeTag(joinTreeNode));
	}
	return false;
}


/*
 * UpdateOrDeleteQuery returns true if the given query is an UPDATE or DELETE
 * command, and false otherwise.
 */
bool
UpdateOrDeleteQuery(Query *query)
{
	return query->commandType == CMD_UPDATE ||
		   query->commandType == CMD_DELETE;
}


/*
 * If the expression contains STABLE functions which accept any parameters
 * derived from a Var, this function returns true and sets varArgument.
 *
 * If the expression contains a CASE or COALESCE which invokes non-IMMUTABLE
 * functions, this function returns true and sets badCoalesce.
 *
 * Assumes the expression contains no VOLATILE functions.
 *
 * Vars are allowed, but only if they are passed solely to IMMUTABLE functions.
 *
 * We special-case CASE/COALESCE because those are evaluated lazily. We could
 * evaluate CASE/COALESCE expressions which don't reference Vars, or partially
 * evaluate some which do, but for now we just error out. That makes both the
 * code and user education easier.
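 *
 * For example (illustrative): SET col = concat(other_col, '!') passes a Var
 * to the STABLE function concat() and sets varArgument, whereas
 * SET col = COALESCE(other_col, now()) invokes the non-IMMUTABLE now()
 * inside a lazily evaluated COALESCE and sets badCoalesce.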
 */
static bool
MasterIrreducibleExpression(Node *expression, bool *varArgument, bool *badCoalesce)
{
	WalkerState data;
	data.containsVar = data.varArgument = data.badCoalesce = false;

	bool result = MasterIrreducibleExpressionWalker(expression, &data);

	*varArgument |= data.varArgument;
	*badCoalesce |= data.badCoalesce;
	return result;
}


static bool
MasterIrreducibleExpressionWalker(Node *expression, WalkerState *state)
{
	char volatileFlag = 0;
	WalkerState childState = { false, false, false };
	bool containsDisallowedFunction = false;
	bool hasVolatileFunction PG_USED_FOR_ASSERTS_ONLY = false;

	if (expression == NULL)
	{
		return false;
	}

	if (IsA(expression, CoalesceExpr))
	{
		CoalesceExpr *expr = (CoalesceExpr *) expression;

		if (contain_mutable_functions((Node *) (expr->args)))
		{
			state->badCoalesce = true;
			return true;
		}
		else
		{
			/*
			 * There's no need to recurse. Since there are no STABLE functions
			 * varArgument will never be set.
			 */
			return false;
		}
	}

	if (IsA(expression, CaseExpr))
	{
		if (contain_mutable_functions(expression))
		{
			state->badCoalesce = true;
			return true;
		}

		return false;
	}

	if (IsA(expression, Var))
	{
		state->containsVar = true;
		return false;
	}

	/*
	 * In order for statement replication to give us consistent results it's important
	 * that we either disallow or evaluate on the coordinator anything which has a
	 * volatility category above IMMUTABLE. Newer versions of postgres might add node
	 * types which should be checked in this function.
	 *
	 * Look through contain_mutable_functions_walker or future PG's equivalent for new
	 * node types before bumping this version number to fix compilation; e.g. for any
	 * PostgreSQL after 9.5, see check_functions_in_node. Review
	 * MasterIrreducibleExpressionFunctionChecker for any changes in volatility
	 * permissibility ordering.
	 *
	 * Once you've added them to this check, make sure you also evaluate them in the
	 * executor!
	 */

	hasVolatileFunction =
		check_functions_in_node(expression, MasterIrreducibleExpressionFunctionChecker,
								&volatileFlag);

	/* the caller should have already checked for this */
	Assert(!hasVolatileFunction);
	Assert(volatileFlag != PROVOLATILE_VOLATILE);

	if (volatileFlag == PROVOLATILE_STABLE)
	{
		containsDisallowedFunction =
			expression_tree_walker(expression,
								   MasterIrreducibleExpressionWalker,
								   &childState);

		if (childState.containsVar)
		{
			state->varArgument = true;
		}

		state->badCoalesce |= childState.badCoalesce;
		state->varArgument |= childState.varArgument;

		return (containsDisallowedFunction || childState.containsVar);
	}

	/* keep traversing */
	return expression_tree_walker(expression,
								  MasterIrreducibleExpressionWalker,
								  state);
}


/*
 * MasterIrreducibleExpressionFunctionChecker returns true if the provided
 * function oid corresponds to a volatile function. It also updates the provided
 * context if the current volatility flag is more permissive than the provided
 * one. It is only called from check_functions_in_node as the checker function.
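 *
 * The resulting ordering is VOLATILE > STABLE > IMMUTABLE: for example, once
 * any function in the expression has been found to be STABLE, the context
 * never drops back to IMMUTABLE.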
 */
static bool
MasterIrreducibleExpressionFunctionChecker(Oid func_id, void *context)
{
	char volatileFlag = func_volatile(func_id);
	char *volatileContext = (char *) context;

	if (volatileFlag == PROVOLATILE_VOLATILE || *volatileContext == PROVOLATILE_VOLATILE)
	{
		*volatileContext = PROVOLATILE_VOLATILE;
	}
	else if (volatileFlag == PROVOLATILE_STABLE || *volatileContext == PROVOLATILE_STABLE)
	{
		*volatileContext = PROVOLATILE_STABLE;
	}
	else
	{
		*volatileContext = PROVOLATILE_IMMUTABLE;
	}

	return (volatileFlag == PROVOLATILE_VOLATILE);
}


/*
 * TargetEntryChangesValue determines whether the given target entry may
 * change the value in a given column, given a join tree. The result is
 * true unless the expression refers directly to the column, or the
 * expression is a value that is implied by the qualifiers of the join
 * tree, or the target entry sets a different column.
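 *
 * For example (illustrative): SET a = a does not change the value of column
 * a, SET a = 5 ... WHERE a = 5 AND ... is implied by the qualifiers, while
 * SET a = a + 1 does change it.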
1559  */
1560 static bool
TargetEntryChangesValue(TargetEntry * targetEntry,Var * column,FromExpr * joinTree)1561 TargetEntryChangesValue(TargetEntry *targetEntry, Var *column, FromExpr *joinTree)
1562 {
1563 	bool isColumnValueChanged = true;
1564 	Expr *setExpr = targetEntry->expr;
1565 
1566 	if (IsA(setExpr, Var))
1567 	{
1568 		Var *newValue = (Var *) setExpr;
1569 		if (newValue->varattno == column->varattno)
1570 		{
1571 			/* target entry of the form SET col = table.col */
1572 			isColumnValueChanged = false;
1573 		}
1574 	}
1575 	else if (IsA(setExpr, Const))
1576 	{
1577 		Const *newValue = (Const *) setExpr;
1578 		List *restrictClauseList = WhereClauseList(joinTree);
1579 		OpExpr *equalityExpr = MakeOpExpression(column, BTEqualStrategyNumber);
1580 		Node *rightOp = get_rightop((Expr *) equalityExpr);
1581 
1582 		Assert(rightOp != NULL);
1583 		Assert(IsA(rightOp, Const));
1584 		Const *rightConst = (Const *) rightOp;
1585 
1586 		rightConst->constvalue = newValue->constvalue;
1587 		rightConst->constisnull = newValue->constisnull;
1588 		rightConst->constbyval = newValue->constbyval;
1589 
1590 		bool predicateIsImplied = predicate_implied_by(list_make1(equalityExpr),
1591 													   restrictClauseList, false);
1592 		if (predicateIsImplied)
1593 		{
1594 			/* target entry of the form SET col = <x> WHERE col = <x> AND ... */
1595 			isColumnValueChanged = false;
1596 		}
1597 	}
1598 
1599 	return isColumnValueChanged;
1600 }


/*
 * RouterInsertJob builds a Job to represent an insertion performed by the provided
 * query. For inserts we always defer shard pruning and generating the task list to
 * the executor.
 */
static Job *
RouterInsertJob(Query *originalQuery)
{
    Assert(originalQuery->commandType == CMD_INSERT);

    bool isMultiRowInsert = IsMultiRowInsert(originalQuery);
    if (isMultiRowInsert)
    {
        /* add default expressions to RTE_VALUES in multi-row INSERTs */
        NormalizeMultiRowInsertTargetList(originalQuery);
    }

    Job *job = CreateJob(originalQuery);
    job->requiresCoordinatorEvaluation = RequiresCoordinatorEvaluation(originalQuery);
    job->deferredPruning = true;
    job->partitionKeyValue = ExtractInsertPartitionKeyValue(originalQuery);

    return job;
}


/*
 * CreateJob returns a new Job for the given query.
 */
static Job *
CreateJob(Query *query)
{
    Job *job = CitusMakeNode(Job);
    job->jobId = UniqueJobId();
    job->jobQuery = query;
    job->taskList = NIL;
    job->dependentJobList = NIL;
    job->subqueryPushdown = false;
    job->requiresCoordinatorEvaluation = false;
    job->deferredPruning = false;

    return job;
}


/*
 * ErrorIfNoShardsExist throws an error if the given table has no shards.
 */
static void
ErrorIfNoShardsExist(CitusTableCacheEntry *cacheEntry)
{
    int shardCount = cacheEntry->shardIntervalArrayLength;
    if (shardCount == 0)
    {
        Oid distributedTableId = cacheEntry->relationId;
        char *relationName = get_rel_name(distributedTableId);

        ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
                        errmsg("could not find any shards"),
                        errdetail("No shards exist for distributed table \"%s\".",
                                  relationName),
                        errhint("Run master_create_worker_shards to create shards "
                                "and try again.")));
    }
}


/*
 * RouterInsertTaskList generates a list of tasks for performing an INSERT on
 * a distributed table via the router executor.
 */
List *
RouterInsertTaskList(Query *query, bool parametersInQueryResolved,
                     DeferredErrorMessage **planningError)
{
    List *insertTaskList = NIL;

    Oid distributedTableId = ExtractFirstCitusTableId(query);
    CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId);

    ErrorIfNoShardsExist(cacheEntry);

    Assert(query->commandType == CMD_INSERT);

    List *modifyRouteList = BuildRoutesForInsert(query, planningError);
    if (*planningError != NULL)
    {
        return NIL;
    }

    ModifyRoute *modifyRoute = NULL;
    foreach_ptr(modifyRoute, modifyRouteList)
    {
        Task *modifyTask = CreateTask(MODIFY_TASK);
        modifyTask->anchorShardId = modifyRoute->shardId;
        modifyTask->replicationModel = cacheEntry->replicationModel;
        modifyTask->rowValuesLists = modifyRoute->rowValuesLists;

        RelationShard *relationShard = CitusMakeNode(RelationShard);
        relationShard->shardId = modifyRoute->shardId;
        relationShard->relationId = distributedTableId;

        modifyTask->relationShardList = list_make1(relationShard);
        modifyTask->taskPlacementList = ActiveShardPlacementList(
            modifyRoute->shardId);
        modifyTask->parametersInQueryStringResolved = parametersInQueryResolved;

        insertTaskList = lappend(insertTaskList, modifyTask);
    }

    return insertTaskList;
}
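

/*
 * To make the route-to-task mapping concrete: a multi-row INSERT such as
 * (illustrative, assuming a hash-distributed table "events")
 *
 *     INSERT INTO events (id, payload) VALUES (1, 'a'), (2, 'b'), (57, 'c');
 *
 * may prune rows 1 and 57 to one shard and row 2 to another. In that case
 * BuildRoutesForInsert returns two ModifyRoutes, and the loop above emits two
 * MODIFY_TASKs, each carrying only the VALUES rows that belong to its shard.
 */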


/*
 * CreateTask returns a new Task with the given type.
 */
static Task *
CreateTask(TaskType taskType)
{
    Task *task = CitusMakeNode(Task);
    task->taskType = taskType;
    task->jobId = INVALID_JOB_ID;
    task->taskId = INVALID_TASK_ID;
    SetTaskQueryString(task, NULL);
    task->anchorShardId = INVALID_SHARD_ID;
    task->taskPlacementList = NIL;
    task->dependentTaskList = NIL;

    task->partitionId = 0;
    task->upstreamTaskId = INVALID_TASK_ID;
    task->shardInterval = NULL;
    task->assignmentConstrained = false;
    task->replicationModel = REPLICATION_MODEL_INVALID;
    task->relationRowLockList = NIL;

    task->modifyWithSubquery = false;
    task->partiallyLocalOrRemote = false;
    task->relationShardList = NIL;

    return task;
}


/*
 * ExtractFirstCitusTableId takes a given query, and finds the relationId
 * for the first distributed table in that query. If the function cannot find a
 * distributed table, it returns InvalidOid.
 *
 * We only use this function for modifications and fast path queries, which
 * should have the first distributed table in the top-level rtable.
 */
Oid
ExtractFirstCitusTableId(Query *query)
{
    List *rangeTableList = query->rtable;
    ListCell *rangeTableCell = NULL;
    Oid distributedTableId = InvalidOid;

    foreach(rangeTableCell, rangeTableList)
    {
        RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);

        if (IsCitusTable(rangeTableEntry->relid))
        {
            distributedTableId = rangeTableEntry->relid;
            break;
        }
    }

    return distributedTableId;
}


/*
 * RouterJob builds a Job to represent a single-shard SELECT/UPDATE/DELETE
 * query or a multi-shard UPDATE/DELETE query.
 */
static Job *
RouterJob(Query *originalQuery, PlannerRestrictionContext *plannerRestrictionContext,
          DeferredErrorMessage **planningError)
{
    uint64 shardId = INVALID_SHARD_ID;
    List *placementList = NIL;
    List *relationShardList = NIL;
    List *prunedShardIntervalListList = NIL;
    bool isMultiShardModifyQuery = false;
    Const *partitionKeyValue = NULL;

    /* router planner should create a task even if it doesn't hit a shard at all */
    bool replacePrunedQueryWithDummy = true;

    bool isLocalTableModification = false;

    /* check if this query requires coordinator evaluation */
    bool requiresCoordinatorEvaluation = RequiresCoordinatorEvaluation(originalQuery);
    FastPathRestrictionContext *fastPathRestrictionContext =
        plannerRestrictionContext->fastPathRestrictionContext;

    /*
     * We prefer to defer shard pruning/task generation to the
     * execution when the parameter on the distribution key
     * cannot be resolved.
     */
    if (fastPathRestrictionContext->fastPathRouterQuery &&
        fastPathRestrictionContext->distributionKeyHasParam)
    {
        Job *job = CreateJob(originalQuery);
        job->deferredPruning = true;

        ereport(DEBUG2, (errmsg("Deferred pruning for a fast-path router "
                                "query")));
        return job;
    }
    else
    {
        (*planningError) = PlanRouterQuery(originalQuery, plannerRestrictionContext,
                                           &placementList, &shardId, &relationShardList,
                                           &prunedShardIntervalListList,
                                           replacePrunedQueryWithDummy,
                                           &isMultiShardModifyQuery,
                                           &partitionKeyValue,
                                           &isLocalTableModification);
    }

    if (*planningError)
    {
        return NULL;
    }

    Job *job = CreateJob(originalQuery);
    job->partitionKeyValue = partitionKeyValue;

    if (originalQuery->resultRelation > 0)
    {
        RangeTblEntry *updateOrDeleteRTE = ExtractResultRelationRTE(originalQuery);

        /*
         * If all of the shards are pruned, we replace the relation RTE with
         * a subquery RTE that returns no results. However, this is not useful
         * for UPDATE and DELETE queries. Therefore, if we detect an UPDATE or
         * DELETE RTE with subquery type, we just set the task list to empty and
         * return the job.
         */
        if (updateOrDeleteRTE->rtekind == RTE_SUBQUERY)
        {
            job->taskList = NIL;
            return job;
        }
    }

    if (isMultiShardModifyQuery)
    {
        job->taskList = QueryPushdownSqlTaskList(originalQuery, job->jobId,
                                                 plannerRestrictionContext->
                                                 relationRestrictionContext,
                                                 prunedShardIntervalListList,
                                                 MODIFY_TASK,
                                                 requiresCoordinatorEvaluation,
                                                 planningError);
        if (*planningError)
        {
            return NULL;
        }
    }
    else
    {
        GenerateSingleShardRouterTaskList(job, relationShardList,
                                          placementList, shardId,
                                          isLocalTableModification);
    }

    job->requiresCoordinatorEvaluation = requiresCoordinatorEvaluation;
    return job;
}
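

/*
 * A prepared statement is the typical way to reach the deferred-pruning
 * branch above (illustrative, assuming a table "events" distributed on "id"):
 *
 *     PREPARE fetch_event(int) AS SELECT * FROM events WHERE id = $1;
 *     EXECUTE fetch_event(42);
 *
 * Once PostgreSQL switches to a generic plan, $1 is a Param rather than a
 * Const at planning time, distributionKeyHasParam is set, and shard pruning
 * moves to the executor.
 */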


/*
 * GenerateSingleShardRouterTaskList is a wrapper around the corresponding task
 * list generation functions specific to single-shard selects and modifications.
 *
 * The function updates the input job's taskList in-place.
 */
void
GenerateSingleShardRouterTaskList(Job *job, List *relationShardList,
                                  List *placementList, uint64 shardId,
                                  bool isLocalTableModification)
{
    Query *originalQuery = job->jobQuery;

    if (originalQuery->commandType == CMD_SELECT)
    {
        job->taskList = SingleShardTaskList(originalQuery, job->jobId,
                                            relationShardList, placementList,
                                            shardId,
                                            job->parametersInJobQueryResolved,
                                            isLocalTableModification);

        /*
         * Queries to reference tables, or distributed tables with multiple replicas,
         * have their task placements reordered according to the configured
         * task_assignment_policy. This is only applicable to select queries, as the
         * modify queries will _always_ be executed on all placements.
         *
         * We also ignore queries that are targeting only intermediate results (e.g., no
         * valid anchorShardId).
         */
        if (shardId != INVALID_SHARD_ID)
        {
            ReorderTaskPlacementsByTaskAssignmentPolicy(job, TaskAssignmentPolicy,
                                                        placementList);
        }
    }
    else if (shardId == INVALID_SHARD_ID && !isLocalTableModification)
    {
        /* modification that prunes to 0 shards */
        job->taskList = NIL;
    }
    else
    {
        job->taskList = SingleShardTaskList(originalQuery, job->jobId,
                                            relationShardList, placementList,
                                            shardId,
                                            job->parametersInJobQueryResolved,
                                            isLocalTableModification);
    }
}


/*
 * ReorderTaskPlacementsByTaskAssignmentPolicy applies selective reordering for supported
 * TaskAssignmentPolicyTypes.
 *
 * Supported Types
 * - TASK_ASSIGNMENT_ROUND_ROBIN: round-robin queries among placements
 *
 * By default it does not reorder the task list, implying a first-replica strategy.
 */
static void
ReorderTaskPlacementsByTaskAssignmentPolicy(Job *job,
                                            TaskAssignmentPolicyType taskAssignmentPolicy,
                                            List *placementList)
{
    if (taskAssignmentPolicy == TASK_ASSIGNMENT_ROUND_ROBIN)
    {
        /*
         * We hit a single shard on router plans, and there should be only
         * one task in the task list
         */
        Assert(list_length(job->taskList) == 1);
        Task *task = (Task *) linitial(job->taskList);

        /*
         * For round-robin SELECT queries, we don't want to include the coordinator
         * because the user is trying to distribute the load across nodes via the
         * round-robin policy. Otherwise, local execution would prioritize
         * executing the local tasks, and especially for reference tables on the
         * coordinator this would prevent load balancing across nodes.
         *
         * For other worker nodes in Citus MX, we let local execution kick in
         * even for the round-robin policy, because we expect the clients to
         * connect evenly to the worker nodes.
         */
        Assert(ReadOnlyTask(task->taskType));
        placementList = RemoveCoordinatorPlacementIfNotSingleNode(placementList);

        /* reorder the placement list */
        List *reorderedPlacementList = RoundRobinReorder(placementList);
        task->taskPlacementList = reorderedPlacementList;

        ShardPlacement *primaryPlacement = (ShardPlacement *) linitial(
            reorderedPlacementList);
        ereport(DEBUG3, (errmsg("assigned task %u to node %s:%u", task->taskId,
                                primaryPlacement->nodeName,
                                primaryPlacement->nodePort)));
    }
}
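

/*
 * The reordering above is driven by a GUC; a session that spreads reference
 * table reads over the nodes could look like this (illustrative):
 *
 *     SET citus.task_assignment_policy TO 'round-robin';
 *     SELECT count(*) FROM reference_table;  -- first placement
 *     SELECT count(*) FROM reference_table;  -- next placement, and so on
 *
 * With the default policy the task list is left as-is, so the same placement
 * is preferred every time.
 */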


/*
 * RemoveCoordinatorPlacementIfNotSingleNode gets a task placement list and returns
 * the list with the placement belonging to the coordinator (if any) removed.
 *
 * If the list has a single element or no placements on the coordinator, the list
 * is returned unmodified.
 */
List *
RemoveCoordinatorPlacementIfNotSingleNode(List *placementList)
{
    ListCell *placementCell = NULL;

    if (list_length(placementList) < 2)
    {
        return placementList;
    }

    foreach(placementCell, placementList)
    {
        ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell);

        if (placement->groupId == COORDINATOR_GROUP_ID)
        {
            return list_delete_ptr(placementList, placement);
        }
    }

    return placementList;
}


/*
 * SingleShardTaskList generates a task for a single-shard query
 * and returns it as a list.
 */
static List *
SingleShardTaskList(Query *query, uint64 jobId, List *relationShardList,
                    List *placementList, uint64 shardId,
                    bool parametersInQueryResolved,
                    bool isLocalTableModification)
{
    TaskType taskType = READ_TASK;
    char replicationModel = 0;

    if (query->commandType != CMD_SELECT)
    {
        List *rangeTableList = NIL;
        ExtractRangeTableEntryWalker((Node *) query, &rangeTableList);

        RangeTblEntry *updateOrDeleteRTE = ExtractResultRelationRTE(query);
        Assert(updateOrDeleteRTE != NULL);

        CitusTableCacheEntry *modificationTableCacheEntry = NULL;
        if (IsCitusTable(updateOrDeleteRTE->relid))
        {
            modificationTableCacheEntry = GetCitusTableCacheEntry(
                updateOrDeleteRTE->relid);
        }

        if (IsCitusTableType(updateOrDeleteRTE->relid, REFERENCE_TABLE) &&
            SelectsFromDistributedTable(rangeTableList, query))
        {
            ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                            errmsg("cannot perform select on a distributed table "
                                   "and modify a reference table")));
        }

        taskType = MODIFY_TASK;
        if (modificationTableCacheEntry)
        {
            replicationModel = modificationTableCacheEntry->replicationModel;
        }
    }

    if (taskType == READ_TASK && query->hasModifyingCTE)
    {
        /* assume ErrorIfQueryHasUnroutableModifyingCTE checked the query already */

        CommonTableExpr *cte = NULL;
        foreach_ptr(cte, query->cteList)
        {
            Query *cteQuery = (Query *) cte->ctequery;

            if (cteQuery->commandType != CMD_SELECT)
            {
                RangeTblEntry *updateOrDeleteRTE = ExtractResultRelationRTE(cteQuery);
                CitusTableCacheEntry *modificationTableCacheEntry =
                    GetCitusTableCacheEntry(
                        updateOrDeleteRTE->relid);

                taskType = MODIFY_TASK;
                replicationModel = modificationTableCacheEntry->replicationModel;
                break;
            }
        }
    }

    Task *task = CreateTask(taskType);
    task->isLocalTableModification = isLocalTableModification;
    List *relationRowLockList = NIL;

    RowLocksOnRelations((Node *) query, &relationRowLockList);

    /*
     * For performance reasons, we skip generating the queryString. For local
     * execution this is not needed, so we wait until the executor determines
     * that the query cannot be executed locally.
     */
    task->taskPlacementList = placementList;
    SetTaskQueryIfShouldLazyDeparse(task, query);
    task->anchorShardId = shardId;
    task->jobId = jobId;
    task->relationShardList = relationShardList;
    task->relationRowLockList = relationRowLockList;
    task->replicationModel = replicationModel;
    task->parametersInQueryStringResolved = parametersInQueryResolved;

    return list_make1(task);
}


/*
 * RowLocksOnRelations builds a list of RelationRowLock entries that map each
 * locked Citus relation to its row-lock strength.
 */
static bool
RowLocksOnRelations(Node *node, List **relationRowLockList)
{
    if (node == NULL)
    {
        return false;
    }

    if (IsA(node, Query))
    {
        Query *query = (Query *) node;
        ListCell *rowMarkCell = NULL;

        foreach(rowMarkCell, query->rowMarks)
        {
            RowMarkClause *rowMarkClause = (RowMarkClause *) lfirst(rowMarkCell);
            RangeTblEntry *rangeTable = rt_fetch(rowMarkClause->rti, query->rtable);
            Oid relationId = rangeTable->relid;

            if (IsCitusTable(relationId))
            {
                RelationRowLock *relationRowLock = CitusMakeNode(RelationRowLock);
                relationRowLock->relationId = relationId;
                relationRowLock->rowLockStrength = rowMarkClause->strength;
                *relationRowLockList = lappend(*relationRowLockList, relationRowLock);
            }
        }

        return query_tree_walker(query, RowLocksOnRelations, relationRowLockList, 0);
    }
    else
    {
        return expression_tree_walker(node, RowLocksOnRelations, relationRowLockList);
    }
}
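

/*
 * Row marks come from locking clauses; for example (illustrative, assuming a
 * distributed table "accounts"):
 *
 *     SELECT balance FROM accounts WHERE id = 7 FOR UPDATE;
 *
 * produces a RowMarkClause with strength LCS_FORUPDATE, which the walker
 * above records so that the corresponding row-level lock can be acquired on
 * the shard as well.
 */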


/*
 * SelectsFromDistributedTable checks if there is a select on a distributed
 * table by looking into range table entries.
 */
static bool
SelectsFromDistributedTable(List *rangeTableList, Query *query)
{
    ListCell *rangeTableCell = NULL;
    RangeTblEntry *resultRangeTableEntry = NULL;

    if (query->resultRelation > 0)
    {
        resultRangeTableEntry = ExtractResultRelationRTE(query);
    }

    foreach(rangeTableCell, rangeTableList)
    {
        RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell);

        if (rangeTableEntry->relid == InvalidOid)
        {
            continue;
        }

        if (rangeTableEntry->relkind == RELKIND_VIEW ||
            rangeTableEntry->relkind == RELKIND_MATVIEW)
        {
            /*
             * Skip over views, which would error out in GetCitusTableCacheEntry.
             * Distributed tables within (regular) views are already in rangeTableList.
             */
            continue;
        }

        CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(
            rangeTableEntry->relid);
        if (IsCitusTableTypeCacheEntry(cacheEntry, DISTRIBUTED_TABLE) &&
            (resultRangeTableEntry == NULL || resultRangeTableEntry->relid !=
             rangeTableEntry->relid))
        {
            return true;
        }
    }

    return false;
}


static bool ContainsOnlyLocalTables(RTEListProperties *rteProperties);

/*
 * PlanRouterQuery runs router pruning logic for SELECT, UPDATE and DELETE queries.
 * If there are shards present and the query is routable, all RTEs have been updated
 * to point to the relevant shards in the originalQuery. Also, placementList is
 * filled with the list of worker nodes that have all the required shard placements
 * for the query execution. anchorShardId is set to the first pruned shardId of
 * the given query. Finally, relationShardList is filled with the list of
 * relation-to-shard mappings for the query.
 *
 * If the given query is not routable, it fills planningError with the related
 * DeferredErrorMessage. The caller can check this error message to see whether the
 * query is routable or not.
 *
 * Note: If the query prunes down to 0 shards due to filters (e.g. WHERE false),
 * or the query has only read_intermediate_result calls (no relations left after
 * recursively planning CTEs and subqueries), then it will be assigned to an
 * arbitrary worker node in a round-robin fashion.
 *
 * Relations that prune down to 0 shards are replaced by subqueries returning
 * 0 values in UpdateRelationToShardNames.
 */
DeferredErrorMessage *
PlanRouterQuery(Query *originalQuery,
                PlannerRestrictionContext *plannerRestrictionContext,
                List **placementList, uint64 *anchorShardId, List **relationShardList,
                List **prunedShardIntervalListList,
                bool replacePrunedQueryWithDummy, bool *multiShardModifyQuery,
                Const **partitionValueConst,
                bool *isLocalTableModification)
{
    bool isMultiShardQuery = false;
    DeferredErrorMessage *planningError = NULL;
    bool shardsPresent = false;
    CmdType commandType = originalQuery->commandType;
    bool fastPathRouterQuery =
        plannerRestrictionContext->fastPathRestrictionContext->fastPathRouterQuery;

    *placementList = NIL;

    /*
     * When FastPathRouterQuery() returns true, we know that standard_planner() has
     * not been called. Thus, restriction information is not available and we do the
     * shard pruning based on the distribution column in the quals of the query.
     */
    if (fastPathRouterQuery)
    {
        Const *distributionKeyValue =
            plannerRestrictionContext->fastPathRestrictionContext->distributionKeyValue;

        List *shardIntervalList =
            TargetShardIntervalForFastPathQuery(originalQuery, &isMultiShardQuery,
                                                distributionKeyValue,
                                                partitionValueConst);

        /*
         * This could only happen when there is a parameter on the distribution key.
         * We defer the error here; later the planner is forced to use a generic plan
         * by assigning an arbitrarily high cost to the plan.
         */
        if (UpdateOrDeleteQuery(originalQuery) && isMultiShardQuery)
        {
            planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
                                          "Router planner cannot handle multi-shard "
                                          "modify queries", NULL, NULL);
            return planningError;
        }

        *prunedShardIntervalListList = shardIntervalList;

        if (!isMultiShardQuery)
        {
            ereport(DEBUG2, (errmsg("Distributed planning for a fast-path router "
                                    "query")));
        }
    }
    else
    {
        *prunedShardIntervalListList =
            TargetShardIntervalsForRestrictInfo(plannerRestrictionContext->
                                                relationRestrictionContext,
                                                &isMultiShardQuery,
                                                partitionValueConst);
    }

    if (isMultiShardQuery)
    {
        /*
         * If the query spans multiple shards and is a SELECT, return a
         * deferred error. We do not support multi-shard SELECT queries
         * with this code path.
         */
        if (commandType == CMD_SELECT)
        {
            return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
                                 "Router planner cannot handle multi-shard select queries",
                                 NULL, NULL);
        }

        Assert(UpdateOrDeleteQuery(originalQuery));
        planningError = ModifyQuerySupported(originalQuery, originalQuery,
                                             isMultiShardQuery,
                                             plannerRestrictionContext);
        if (planningError != NULL)
        {
            return planningError;
        }
        else
        {
            *multiShardModifyQuery = true;
            return planningError;
        }
    }

    *relationShardList =
        RelationShardListForShardIntervalList(*prunedShardIntervalListList,
                                              &shardsPresent);

    if (!shardsPresent && !replacePrunedQueryWithDummy)
    {
        /*
         * For INSERT ... SELECT, this query could still be valid for some other target
         * shard intervals. Thus, we should return an empty list if there aren't any
         * matching workers, so that the caller can decide what to do with this task.
         */
        return NULL;
    }

    /*
     * We bail out above if there are RTEs that prune to multiple shards, but
     * there can also be multiple RTEs that reference the same relation.
     */
    if (RelationPrunesToMultipleShards(*relationShardList))
    {
        planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
                                      "cannot run command which targets "
                                      "multiple shards", NULL, NULL);
        return planningError;
    }

    /* we need an anchor shard id for select queries with the router planner */
    uint64 shardId = GetAnchorShardId(*prunedShardIntervalListList);

    /* both Postgres tables and materialized views are locally available */
    RTEListProperties *rteProperties = GetRTEListPropertiesForQuery(originalQuery);
    if (shardId == INVALID_SHARD_ID && ContainsOnlyLocalTables(rteProperties))
    {
        if (commandType != CMD_SELECT)
        {
            *isLocalTableModification = true;
        }
    }
    bool hasPostgresLocalRelation =
        rteProperties->hasPostgresLocalTable || rteProperties->hasMaterializedView;
    List *taskPlacementList =
        CreateTaskPlacementListForShardIntervals(*prunedShardIntervalListList,
                                                 shardsPresent,
                                                 replacePrunedQueryWithDummy,
                                                 hasPostgresLocalRelation);
    if (taskPlacementList == NIL)
    {
        planningError = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
                                      "found no worker with all shard placements",
                                      NULL, NULL);
        return planningError;
    }

    /*
     * If this is an UPDATE or DELETE query which requires coordinator evaluation,
     * don't try to update shard names, and postpone that to the execution phase.
     */
    bool isUpdateOrDelete = UpdateOrDeleteQuery(originalQuery);
    if (!(isUpdateOrDelete && RequiresCoordinatorEvaluation(originalQuery)))
    {
        UpdateRelationToShardNames((Node *) originalQuery, *relationShardList);
    }

    *multiShardModifyQuery = false;
    *placementList = taskPlacementList;
    *anchorShardId = shardId;

    return planningError;
}
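

/*
 * The fast-path branch above can be observed from psql; for example
 * (illustrative, assuming a table "events" distributed on "id"):
 *
 *     SET client_min_messages TO debug2;
 *     SELECT * FROM events WHERE id = 42;
 *
 * should emit "Distributed planning for a fast-path router query", whereas
 * queries without an equality filter on the distribution column go through
 * TargetShardIntervalsForRestrictInfo and restriction-based pruning instead.
 */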


/*
 * ContainsOnlyLocalTables returns true if the query contains only local
 * tables and no distributed or reference tables.
 */
static bool
ContainsOnlyLocalTables(RTEListProperties *rteProperties)
{
    return !rteProperties->hasDistributedTable && !rteProperties->hasReferenceTable;
}


/*
 * CreateTaskPlacementListForShardIntervals returns a list of shard placements
 * on which it can access all shards in shardIntervalListList, which contains
 * a list of shards for each relation in the query.
 *
 * If the query contains a local table then hasLocalRelation should be set to
 * true. In that case, CreateTaskPlacementListForShardIntervals only returns
 * a placement for the local node, or an empty list if the shards cannot be
 * accessed locally.
 *
 * If generateDummyPlacement is true and there are no shards that need to be
 * accessed to answer the query (shardsPresent is false), then a single
 * placement is returned that is either local or follows a round-robin policy.
 * A typical example is a router query that only reads an intermediate result.
 * This will happen on the coordinator, unless the user wants to balance the
 * load by setting citus.task_assignment_policy.
 */
List *
CreateTaskPlacementListForShardIntervals(List *shardIntervalListList, bool shardsPresent,
                                         bool generateDummyPlacement,
                                         bool hasLocalRelation)
{
    List *placementList = NIL;

    if (shardsPresent)
    {
        /*
         * Determine the workers that have all shard placements, if any.
         */
        List *shardPlacementList =
            PlacementsForWorkersContainingAllShards(shardIntervalListList);

        if (hasLocalRelation)
        {
            ShardPlacement *taskPlacement = NULL;

            /*
             * If there is a local table, we only allow the local placement to
             * be used. If there is none, we disallow the query.
             */
            foreach_ptr(taskPlacement, shardPlacementList)
            {
                if (taskPlacement->groupId == GetLocalGroupId())
                {
                    placementList = lappend(placementList, taskPlacement);
                }
            }
        }
        else
        {
            placementList = shardPlacementList;
        }
    }
    else if (generateDummyPlacement)
    {
        ShardPlacement *dummyPlacement = CreateDummyPlacement(hasLocalRelation);

        placementList = list_make1(dummyPlacement);
    }

    return placementList;
}


/*
 * CreateLocalDummyPlacement creates a dummy placement for the local node that
 * can be used for queries that don't involve any shards. The typical examples
 * are:
 *       (a) queries that consist of only intermediate results
 *       (b) queries that hit zero shards (... WHERE false;)
 */
static ShardPlacement *
CreateLocalDummyPlacement()
{
    ShardPlacement *dummyPlacement = CitusMakeNode(ShardPlacement);
    dummyPlacement->nodeId = LOCAL_NODE_ID;
    dummyPlacement->nodeName = LocalHostName;
    dummyPlacement->nodePort = PostPortNumber;
    dummyPlacement->groupId = GetLocalGroupId();
    return dummyPlacement;
}


/*
 * CreateDummyPlacement creates a dummy placement that can be used for queries
 * that don't involve any shards. The typical examples are:
 *       (a) queries that consist of only intermediate results
 *       (b) queries that hit zero shards (... WHERE false;)
 *
 * If the round-robin policy is set, the placement could be on any node in
 * pg_dist_node. Otherwise, the local node is used for the placement.
 *
 * Queries can also involve local tables. In that case we always use the local
 * node.
 */
static ShardPlacement *
CreateDummyPlacement(bool hasLocalRelation)
{
    static uint32 zeroShardQueryRoundRobin = 0;

    if (TaskAssignmentPolicy != TASK_ASSIGNMENT_ROUND_ROBIN || hasLocalRelation)
    {
        return CreateLocalDummyPlacement();
    }

    List *workerNodeList = ActiveReadableNonCoordinatorNodeList();
    if (workerNodeList == NIL)
    {
        /*
         * We want to round-robin over the workers, but there are no workers.
         * To make sure the query can still succeed we fall back to returning
         * a local dummy placement.
         */
        return CreateLocalDummyPlacement();
    }

    int workerNodeCount = list_length(workerNodeList);
    int workerNodeIndex = zeroShardQueryRoundRobin % workerNodeCount;
    WorkerNode *workerNode = (WorkerNode *) list_nth(workerNodeList,
                                                     workerNodeIndex);

    ShardPlacement *dummyPlacement = CitusMakeNode(ShardPlacement);
    SetPlacementNodeMetadata(dummyPlacement, workerNode);

    zeroShardQueryRoundRobin++;

    return dummyPlacement;
}
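

/*
 * Queries that prune to zero shards still need somewhere to run; a minimal
 * way to exercise the dummy-placement path (illustrative, assuming a
 * distributed table "events"):
 *
 *     SET citus.task_assignment_policy TO 'round-robin';
 *     SELECT * FROM events WHERE false;
 *
 * Each such query is assigned to the next active worker in turn, courtesy of
 * the static zeroShardQueryRoundRobin counter above.
 */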


/*
 * RelationShardListForShardIntervalList is a utility function which takes a
 * list of pruned shard interval lists and returns a flat list of RelationShards.
 */
List *
RelationShardListForShardIntervalList(List *shardIntervalList, bool *shardsPresent)
{
    List *relationShardList = NIL;
    ListCell *shardIntervalListCell = NULL;

    foreach(shardIntervalListCell, shardIntervalList)
    {
        List *prunedShardIntervalList = (List *) lfirst(shardIntervalListCell);

        /* the case where no shards are present or all shards are pruned out is handled later */
        if (prunedShardIntervalList == NIL)
        {
            continue;
        }

        *shardsPresent = true;

        ListCell *shardIntervalCell = NULL;
        foreach(shardIntervalCell, prunedShardIntervalList)
        {
            ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
            RelationShard *relationShard = CitusMakeNode(RelationShard);

            relationShard->relationId = shardInterval->relationId;
            relationShard->shardId = shardInterval->shardId;

            relationShardList = lappend(relationShardList, relationShard);
        }
    }

    return relationShardList;
}


/*
 * GetAnchorShardId returns the anchor shard id for the given list of pruned
 * shard interval lists. The desired anchor shard is found as follows:
 *
 * - Return the first distributed table shard id in the list, if there is any.
 * - Return a reference table shard id if all the shards belong to
 *   reference tables.
 * - Return INVALID_SHARD_ID on empty lists.
 */
uint64
GetAnchorShardId(List *prunedShardIntervalListList)
{
    ListCell *prunedShardIntervalListCell = NULL;
    uint64 referenceShardId = INVALID_SHARD_ID;

    foreach(prunedShardIntervalListCell, prunedShardIntervalListList)
    {
        List *prunedShardIntervalList = (List *) lfirst(prunedShardIntervalListCell);

        /* the case where no shards are present or all shards are pruned out is handled later */
        if (prunedShardIntervalList == NIL)
        {
            continue;
        }

        ShardInterval *shardInterval = linitial(prunedShardIntervalList);

        if (ReferenceTableShardId(shardInterval->shardId))
        {
            referenceShardId = shardInterval->shardId;
        }
        else
        {
            return shardInterval->shardId;
        }
    }

    return referenceShardId;
}
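

/*
 * For instance, in a router query joining a distributed table with a
 * reference table (illustrative schema):
 *
 *     SELECT *
 *     FROM events e JOIN countries c ON (e.country = c.code)
 *     WHERE e.id = 42;
 *
 * the pruned lists contain one events shard and the single countries shard;
 * the events shard becomes the anchor because distributed table shards take
 * precedence over reference table shards.
 */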


/*
 * TargetShardIntervalForFastPathQuery gets a query which is in
 * the form defined by FastPathRouterQuery() and returns a list
 * containing a single shard interval list (see FastPathRouterQuery()
 * for the details).
 *
 * If the caller requests it via outputPartitionValueConst, the distribution
 * key value that this function prunes on is also returned.
 */
List *
TargetShardIntervalForFastPathQuery(Query *query, bool *isMultiShardQuery,
                                    Const *inputDistributionKeyValue,
                                    Const **outputPartitionValueConst)
{
    Oid relationId = ExtractFirstCitusTableId(query);

    if (IsCitusTableType(relationId, CITUS_TABLE_WITH_NO_DIST_KEY))
    {
        /* we don't need to do shard pruning for non-distributed tables */
        return list_make1(LoadShardIntervalList(relationId));
    }

    if (inputDistributionKeyValue && !inputDistributionKeyValue->constisnull)
    {
        CitusTableCacheEntry *cache = GetCitusTableCacheEntry(relationId);
        Var *distributionKey = cache->partitionColumn;

        /*
         * We currently don't allow implicitly coerced values to be handled by the
         * fast-path planner. Still, let's be defensive for any future changes.
         */
        if (inputDistributionKeyValue->consttype != distributionKey->vartype)
        {
            bool missingOk = false;
            inputDistributionKeyValue =
                TransformPartitionRestrictionValue(distributionKey,
                                                   inputDistributionKeyValue, missingOk);
        }

        ShardInterval *cachedShardInterval =
            FindShardInterval(inputDistributionKeyValue->constvalue, cache);
        if (cachedShardInterval == NULL)
        {
            ereport(ERROR, (errmsg(
                                "could not find shardinterval to which to send the query")));
        }

        if (outputPartitionValueConst != NULL)
        {
            /* set the outgoing partition column value if requested */
            *outputPartitionValueConst = inputDistributionKeyValue;
        }
        ShardInterval *shardInterval = CopyShardInterval(cachedShardInterval);
        List *shardIntervalList = list_make1(shardInterval);

        return list_make1(shardIntervalList);
    }

    Node *quals = query->jointree->quals;
    int relationIndex = 1;

    /*
     * We couldn't do the shard pruning based on inputDistributionKeyValue as it might
     * be passed as NULL. Still, we can search the quals for the distribution key.
     */
    Const *distributionKeyValueInQuals = NULL;
    List *prunedShardIntervalList =
        PruneShards(relationId, relationIndex, make_ands_implicit((Expr *) quals),
                    &distributionKeyValueInQuals);

    if (!distributionKeyValueInQuals || distributionKeyValueInQuals->constisnull)
    {
        /*
         * If the distribution key equals NULL, we prefer to treat it as a zero-shard
         * query as it cannot return any rows.
         */
        return NIL;
    }

    /* we're only expecting a single shard from a single table */
    Node *distKey PG_USED_FOR_ASSERTS_ONLY = NULL;
    Assert(FastPathRouterQuery(query, &distKey) || !EnableFastPathRouterPlanner);

    if (list_length(prunedShardIntervalList) > 1)
    {
        *isMultiShardQuery = true;
    }
    else if (list_length(prunedShardIntervalList) == 1 &&
             outputPartitionValueConst != NULL)
    {
        /* set the outgoing partition column value if requested */
        *outputPartitionValueConst = distributionKeyValueInQuals;
    }

    return list_make1(prunedShardIntervalList);
}
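

/*
 * Note the NULL handling above: the distribution column is only ever matched
 * with an equality operator, and equality with NULL cannot return any rows.
 * Hence a query like (illustrative)
 *
 *     SELECT * FROM events WHERE id = NULL;
 *
 * is treated as a zero-shard query rather than being routed to a shard.
 */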


/*
 * TargetShardIntervalsForRestrictInfo performs shard pruning for all referenced
 * relations in the relation restriction context and returns a list of shards per
 * relation. Shard pruning is done based on the provided restriction context per
 * relation. The function sets multiShardQuery to true if any of the relations
 * prunes down to more than one active shard. It also records pruned shard
 * intervals in the relation restriction context to be used later on. Some queries
 * may have contradiction clauses like 'and false' or 'and 1=0'; such queries are
 * treated as if all of the shards of the joining relations are pruned out.
 */
List *
TargetShardIntervalsForRestrictInfo(RelationRestrictionContext *restrictionContext,
                                    bool *multiShardQuery, Const **partitionValueConst)
{
    List *prunedShardIntervalListList = NIL;
    ListCell *restrictionCell = NULL;
    bool multiplePartitionValuesExist = false;
    Const *queryPartitionValueConst = NULL;

    Assert(restrictionContext != NULL);

    foreach(restrictionCell, restrictionContext->relationRestrictionList)
    {
        RelationRestriction *relationRestriction =
            (RelationRestriction *) lfirst(restrictionCell);
        Oid relationId = relationRestriction->relationId;

        if (!IsCitusTable(relationId))
        {
            /* ignore local tables for shard pruning purposes */
            continue;
        }

        Index tableId = relationRestriction->index;
        CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
        int shardCount = cacheEntry->shardIntervalArrayLength;
        List *baseRestrictionList = relationRestriction->relOptInfo->baserestrictinfo;
        List *restrictClauseList = get_all_actual_clauses(baseRestrictionList);
        List *prunedShardIntervalList = NIL;

        /*
         * Queries may have contradiction clauses like 'false', or '1=0' in
         * their filters. Such queries would have a pseudoconstant 'false'
         * inside the relOptInfo->joininfo list. We treat such cases as if all
         * shards of the table are pruned out.
         */
        bool joinFalseQuery = JoinConditionIsOnFalse(
            relationRestriction->relOptInfo->joininfo);
        if (!joinFalseQuery && shardCount > 0)
        {
            Const *restrictionPartitionValueConst = NULL;
            prunedShardIntervalList = PruneShards(relationId, tableId, restrictClauseList,
                                                  &restrictionPartitionValueConst);

            if (list_length(prunedShardIntervalList) > 1)
            {
                (*multiShardQuery) = true;
            }
            if (restrictionPartitionValueConst != NULL &&
                queryPartitionValueConst == NULL)
            {
                queryPartitionValueConst = restrictionPartitionValueConst;
            }
            else if (restrictionPartitionValueConst != NULL &&
                     !equal(queryPartitionValueConst, restrictionPartitionValueConst))
            {
                multiplePartitionValuesExist = true;
            }
        }

        prunedShardIntervalListList = lappend(prunedShardIntervalListList,
                                              prunedShardIntervalList);
    }

    /*
     * Different restrictions might have different partition columns.
     * We report the partition column value only if there is exactly one.
     */
    if (multiplePartitionValuesExist)
    {
        queryPartitionValueConst = NULL;
    }

    /* set the outgoing partition column value if requested */
    if (partitionValueConst != NULL)
    {
        *partitionValueConst = queryPartitionValueConst;
    }

    return prunedShardIntervalListList;
}


/*
 * JoinConditionIsOnFalse returns true for queries that
 * have contradiction clauses like 'false', or '1=0' in
 * their filters. Such queries would have a pseudoconstant 'false'
 * inside the joininfo list.
 */
bool
JoinConditionIsOnFalse(List *joinInfoList)
{
    List *pseudoJoinRestrictionList = extract_actual_clauses(joinInfoList, true);

    bool joinFalseQuery = ContainsFalseClause(pseudoJoinRestrictionList);
    return joinFalseQuery;
}
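

/*
 * A contradiction clause does not have to be spelled "false" literally; the
 * planner reduces it to a pseudoconstant, e.g. (illustrative):
 *
 *     SELECT *
 *     FROM events e JOIN users u ON (e.user_id = u.id)
 *     WHERE 1 = 0;
 *
 * Both relations are then treated as pruning to zero shards, matching the
 * behavior of a plain "... WHERE false".
 */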


/*
 * RelationPrunesToMultipleShards returns true if the given list of
 * relation-to-shard mappings contains at least two mappings with
 * the same relation, but different shards.
 */
static bool
RelationPrunesToMultipleShards(List *relationShardList)
{
    ListCell *relationShardCell = NULL;
    RelationShard *previousRelationShard = NULL;

    relationShardList = SortList(relationShardList, CompareRelationShards);

    foreach(relationShardCell, relationShardList)
    {
        RelationShard *relationShard = (RelationShard *) lfirst(relationShardCell);

        if (previousRelationShard != NULL &&
            relationShard->relationId == previousRelationShard->relationId &&
            relationShard->shardId != previousRelationShard->shardId)
        {
            return true;
        }

        previousRelationShard = relationShard;
    }

    return false;
}


/*
 * PlacementsForWorkersContainingAllShards returns a list of shard placements for
 * workers that contain all shard intervals in the given list of shard interval lists.
 */
List *
PlacementsForWorkersContainingAllShards(List *shardIntervalListList)
{
    bool firstShard = true;
    List *currentPlacementList = NIL;
    List *shardIntervalList = NIL;

    foreach_ptr(shardIntervalList, shardIntervalListList)
    {
        if (shardIntervalList == NIL)
        {
            continue;
        }

        Assert(list_length(shardIntervalList) == 1);

        ShardInterval *shardInterval = (ShardInterval *) linitial(shardIntervalList);
        uint64 shardId = shardInterval->shardId;

        /* retrieve all active shard placements for this shard */
        List *newPlacementList = ActiveShardPlacementList(shardId);

        if (firstShard)
        {
            firstShard = false;
            currentPlacementList = newPlacementList;
        }
        else
        {
            /* keep placements that still exist for this shard */
            currentPlacementList = IntersectPlacementList(currentPlacementList,
                                                          newPlacementList);
        }

        /*
         * Bail out if the placement list becomes empty. This means there is no worker
         * containing all shards referenced by the query, hence we cannot forward
         * this query directly to any worker.
         */
        if (currentPlacementList == NIL)
        {
            break;
        }
    }

    return currentPlacementList;
}
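

/*
 * The loop above is a running set intersection. As a minimal sketch, with
 * two shards placed as follows (hypothetical placements):
 *
 *     shard 102008 -> {worker1, worker2}
 *     shard 102010 -> {worker2, worker3}
 *
 * the first iteration seeds the result with {worker1, worker2} and the
 * second intersects it down to {worker2}, the only node able to answer a
 * query touching both shards.
 */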


/*
 * BuildRoutesForInsert returns a list of ModifyRoute objects for an INSERT
 * query, or an empty list if the partition column value is defined as an
 * expression that still needs to be evaluated. If any partition column value
 * falls within 0 or multiple (overlapping) shards, the planning error is set.
 *
 * Multi-row INSERTs are handled by grouping their rows by target shard. These
 * groups are returned in ascending order by shard id, ready for later deparsing
 * to shard-specific SQL.
 */
static List *
BuildRoutesForInsert(Query *query, DeferredErrorMessage **planningError)
{
    Oid distributedTableId = ExtractFirstCitusTableId(query);
    CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId);
    List *modifyRouteList = NIL;
    ListCell *insertValuesCell = NULL;

    Assert(query->commandType == CMD_INSERT);

    /* reference tables and citus local tables can only have one shard */
    if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
    {
        List *shardIntervalList = LoadShardIntervalList(distributedTableId);

        int shardCount = list_length(shardIntervalList);
        if (shardCount != 1)
        {
            if (IsCitusTableTypeCacheEntry(cacheEntry, REFERENCE_TABLE))
            {
                ereport(ERROR, (errmsg("reference table cannot have %d shards",
                                       shardCount)));
            }
            else if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_LOCAL_TABLE))
            {
                ereport(ERROR, (errmsg("local table cannot have %d shards",
                                       shardCount)));
            }
        }

        ShardInterval *shardInterval = linitial(shardIntervalList);
        ModifyRoute *modifyRoute = palloc(sizeof(ModifyRoute));

        modifyRoute->shardId = shardInterval->shardId;

        RangeTblEntry *valuesRTE = ExtractDistributedInsertValuesRTE(query);
        if (valuesRTE != NULL)
        {
            /* add the values list for a multi-row INSERT */
            modifyRoute->rowValuesLists = valuesRTE->values_lists;
        }
        else
        {
            modifyRoute->rowValuesLists = NIL;
        }

        modifyRouteList = lappend(modifyRouteList, modifyRoute);

        return modifyRouteList;
    }

    Var *partitionColumn = cacheEntry->partitionColumn;

    /* get the full list of insert values and iterate over them to prune */
    List *insertValuesList = ExtractInsertValuesList(query, partitionColumn);

    foreach(insertValuesCell, insertValuesList)
    {
        InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell);
        List *prunedShardIntervalList = NIL;
        Node *partitionValueExpr = (Node *) insertValues->partitionValueExpr;

        /*
         * We only support constant partition values at this point. Sometimes
         * they are wrapped in an implicit coercion though. Most notably
         * FuncExpr coercions for casts created with CREATE CAST ... WITH
         * FUNCTION .. AS IMPLICIT. To support this, we first strip them here.
         * Then we do the coercion manually below using
         * TransformPartitionRestrictionValue, if the types are not the same.
         *
         * NOTE: eval_const_expressions below would do some of these removals
         * too, but it's unclear if it would do all of them. It is possible
         * that there are no cases where this strip_implicit_coercions call is
         * really necessary at all, but currently that's hard to rule out.
         * So to be on the safe side we call strip_implicit_coercions too, to
         * be sure we support as much as possible.
         */
        partitionValueExpr = strip_implicit_coercions(partitionValueExpr);

        /*
         * By evaluating constant expressions, an expression such as 2 + 4
         * will become the const 6. That way we can use them as a partition
         * column value. Normally the planner evaluates constant expressions,
         * but we may be working on the original query tree here. So we do it
         * here explicitly before checking that the partition value is a const.
         *
         * NOTE: We do not use expression_planner here, since all it does
         * apart from calling eval_const_expressions is call fix_opfuncids.
         * This is not needed here, since it's a no-op for T_Const nodes and we
         * error out below in all other cases.
         */
        partitionValueExpr = eval_const_expressions(NULL, partitionValueExpr);

        if (!IsA(partitionValueExpr, Const))
        {
            ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                            errmsg("failed to evaluate partition key in insert"),
                            errhint("try using constant values for partition column")));
        }

        Const *partitionValueConst = (Const *) partitionValueExpr;
        if (partitionValueConst->constisnull)
        {
            ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
                            errmsg("cannot perform an INSERT with NULL in the partition "
                                   "column")));
        }

        /*
         * Now actually apply the coercions that we skipped above; if one
         * fails, throw an error.
         */
        if (partitionValueConst->consttype != partitionColumn->vartype)
        {
            bool missingOk = false;
            partitionValueConst =
                TransformPartitionRestrictionValue(partitionColumn,
                                                   partitionValueConst,
                                                   missingOk);
        }

        if (IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED) ||
            IsCitusTableTypeCacheEntry(cacheEntry, RANGE_DISTRIBUTED))
        {
            Datum partitionValue = partitionValueConst->constvalue;

            ShardInterval *shardInterval = FindShardInterval(partitionValue, cacheEntry);
            if (shardInterval != NULL)
            {
                prunedShardIntervalList = list_make1(shardInterval);
            }
        }
        else
        {
            Index tableId = 1;
            OpExpr *equalityExpr = MakeOpExpression(partitionColumn,
                                                    BTEqualStrategyNumber);
            Node *rightOp = get_rightop((Expr *) equalityExpr);

            Assert(rightOp != NULL);
            Assert(IsA(rightOp, Const));
            Const *rightConst = (Const *) rightOp;

            rightConst->constvalue = partitionValueConst->constvalue;
            rightConst->constisnull = partitionValueConst->constisnull;
            rightConst->constbyval = partitionValueConst->constbyval;

            List *restrictClauseList = list_make1(equalityExpr);

            prunedShardIntervalList = PruneShards(distributedTableId, tableId,
                                                  restrictClauseList, NULL);
        }

        int prunedShardIntervalCount = list_length(prunedShardIntervalList);
        if (prunedShardIntervalCount != 1)
        {
            char *partitionKeyString = cacheEntry->partitionKeyString;
            char *partitionColumnName = ColumnToColumnName(distributedTableId,
                                                           partitionKeyString);
            StringInfo errorMessage = makeStringInfo();
            StringInfo errorHint = makeStringInfo();
            const char *targetCountType = NULL;

            if (prunedShardIntervalCount == 0)
            {
                targetCountType = "no";
            }
            else
            {
                targetCountType = "multiple";
            }

            if (prunedShardIntervalCount == 0)
            {
                appendStringInfo(errorHint, "Make sure you have created a shard which "
                                            "can receive this partition column value.");
            }
            else
            {
                appendStringInfo(errorHint, "Make sure the value for partition column "
                                            "\"%s\" falls into a single shard.",
                                 partitionColumnName);
            }

            appendStringInfo(errorMessage, "cannot run INSERT command which targets %s "
                                           "shards", targetCountType);

            (*planningError) = DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
                                             errorMessage->data, NULL,
                                             errorHint->data);

            return NIL;
        }

        ShardInterval *targetShard = (ShardInterval *) linitial(prunedShardIntervalList);
        insertValues->shardId = targetShard->shardId;
    }

    modifyRouteList = GroupInsertValuesByShardId(insertValuesList);

    return modifyRouteList;
}
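

/*
 * The stripping and constant folding above are what allow computed partition
 * values; for example (illustrative, assuming "events" is hash-distributed
 * on the int column "id"):
 *
 *     INSERT INTO events (id) VALUES (2 + 4);          -- folded to 6, routable
 *     INSERT INTO events (id) VALUES ('7'::text::int); -- coerced, routable
 *
 * A partition value that cannot be folded to a constant at this point raises
 * the "failed to evaluate partition key in insert" error instead.
 */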
3101 
3102 
3103 /*
3104  * IsMultiRowInsert returns whether the given query is a multi-row INSERT.
3105  *
3106  * It does this by determining whether the query is an INSERT that has an
3107  * RTE_VALUES. Single-row INSERTs will have their RTE_VALUES optimised away
3108  * in transformInsertStmt, and instead use the target list.
3109  */
3110 bool
IsMultiRowInsert(Query * query)3111 IsMultiRowInsert(Query *query)
3112 {
3113 	return ExtractDistributedInsertValuesRTE(query) != NULL;
3114 }
3115 
3116 
3117 /*
3118  * ExtractDistributedInsertValuesRTE does precisely that. If the provided
3119  * query is not an INSERT, or if the INSERT does not have a VALUES RTE
3120  * (i.e. it is not a multi-row INSERT), this function returns NULL.
3121  * If all those conditions are met, an RTE representing the multiple values
3122  * of a multi-row INSERT is returned.
3123  */
3124 RangeTblEntry *
ExtractDistributedInsertValuesRTE(Query * query)3125 ExtractDistributedInsertValuesRTE(Query *query)
3126 {
3127 	ListCell *rteCell = NULL;
3128 
3129 	if (query->commandType != CMD_INSERT)
3130 	{
3131 		return NULL;
3132 	}
3133 
3134 	foreach(rteCell, query->rtable)
3135 	{
3136 		RangeTblEntry *rte = (RangeTblEntry *) lfirst(rteCell);
3137 
3138 		if (rte->rtekind == RTE_VALUES)
3139 		{
3140 			return rte;
3141 		}
3142 	}
3143 	return NULL;
3144 }


/*
 * NormalizeMultiRowInsertTargetList ensures all elements of multi-row INSERT target
 * lists are Vars. In multi-row INSERTs, most target list entries contain a Var
 * expression pointing to a position within the values_lists field of a VALUES
 * RTE, but non-NULL default columns are handled differently. Instead of adding
 * the default expression to each row, a single expression encoding the DEFAULT
 * appears in the target list. For consistency, we move these expressions into
 * values lists and replace them with an appropriately constructed Var.
 */
static void
NormalizeMultiRowInsertTargetList(Query *query)
{
	ListCell *valuesListCell = NULL;
	ListCell *targetEntryCell = NULL;
	int targetEntryNo = 0;

	RangeTblEntry *valuesRTE = ExtractDistributedInsertValuesRTE(query);
	if (valuesRTE == NULL)
	{
		return;
	}

	foreach(valuesListCell, valuesRTE->values_lists)
	{
		List *valuesList = (List *) lfirst(valuesListCell);
		Expr **valuesArray = (Expr **) PointerArrayFromList(valuesList);
		List *expandedValuesList = NIL;

		foreach(targetEntryCell, query->targetList)
		{
			TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
			Expr *targetExpr = targetEntry->expr;

			if (IsA(targetExpr, Var))
			{
				/* expression from the VALUES section */
				Var *targetListVar = (Var *) targetExpr;
				targetExpr = valuesArray[targetListVar->varattno - 1];
			}
			else
			{
				/* copy the column's default expression */
				targetExpr = copyObject(targetExpr);
			}

			expandedValuesList = lappend(expandedValuesList, targetExpr);
		}
		SetListCellPtr(valuesListCell, (void *) expandedValuesList);
	}

	/* reset coltypes, coltypmods, colcollations and rebuild them below */
	valuesRTE->coltypes = NIL;
	valuesRTE->coltypmods = NIL;
	valuesRTE->colcollations = NIL;

	foreach(targetEntryCell, query->targetList)
	{
		TargetEntry *targetEntry = lfirst(targetEntryCell);
		Node *targetExprNode = (Node *) targetEntry->expr;

		/* RTE_VALUES comes 2nd, after destination table */
		Index valuesVarno = 2;

		targetEntryNo++;

		Oid targetType = exprType(targetExprNode);
		int32 targetTypmod = exprTypmod(targetExprNode);
		Oid targetColl = exprCollation(targetExprNode);

		valuesRTE->coltypes = lappend_oid(valuesRTE->coltypes, targetType);
		valuesRTE->coltypmods = lappend_int(valuesRTE->coltypmods, targetTypmod);
		valuesRTE->colcollations = lappend_oid(valuesRTE->colcollations, targetColl);

		if (IsA(targetExprNode, Var))
		{
			Var *targetVar = (Var *) targetExprNode;
			targetVar->varattno = targetEntryNo;
			continue;
		}

		/* replace the original expression with a Var referencing values_lists */
		Var *syntheticVar = makeVar(valuesVarno, targetEntryNo, targetType, targetTypmod,
									targetColl, 0);
		targetEntry->expr = (Expr *) syntheticVar;

		/*
		 * Postgres appends a dummy column reference into valuesRTE->eref->colnames
		 * list in addRangeTableEntryForValues for each column specified in VALUES
		 * clause. Now that we replaced the DEFAULT column with a synthetic Var, we
		 * also need to add a dummy column reference for that column.
		 */
		AppendNextDummyColReference(valuesRTE->eref);
	}
}
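

/*
 * A small worked example of the normalization above, using a hypothetical
 * table with a DEFAULT column:
 *
 *   CREATE TABLE items (a int, b int DEFAULT 42);
 *   INSERT INTO items (a) VALUES (1), (2);
 *
 * Before normalization the target list holds a Var for "a" and the DEFAULT
 * expression (Const 42) for "b", while values_lists is ((1), (2)). Afterwards
 * values_lists is ((1, 42), (2, 42)) and both target entries are Vars that
 * point into the VALUES RTE.
 */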


/*
 * AppendNextDummyColReference appends a new dummy column reference to the
 * colnames list of the given Alias object.
 */
static void
AppendNextDummyColReference(Alias *expandedReferenceNames)
{
	int existingColReferences = list_length(expandedReferenceNames->colnames);
	int nextColReferenceId = existingColReferences + 1;
	Value *missingColumnString = MakeDummyColumnString(nextColReferenceId);
	expandedReferenceNames->colnames = lappend(expandedReferenceNames->colnames,
											   missingColumnString);
}


/*
 * MakeDummyColumnString returns a String (Value) object built by appending
 * the given integer to the string "column"; e.g. an id of 3 yields "column3".
 */
static Value *
MakeDummyColumnString(int dummyColumnId)
{
	StringInfo dummyColumnStringInfo = makeStringInfo();
	appendStringInfo(dummyColumnStringInfo, "column%d", dummyColumnId);
	Value *dummyColumnString = makeString(dummyColumnStringInfo->data);

	return dummyColumnString;
}


/*
 * IntersectPlacementList performs placement pruning based on matching on
 * nodeName:nodePort fields of shard placement data. We start pruning from all
 * placements of the first relation's shard. Then for each relation's shard, we
 * compute the intersection of that shard's placements with the existing
 * placement list. This operation could have been done using other methods, but
 * since we do not expect a very high replication factor, iterating over a list
 * and making string comparisons should be sufficient.
 */
List *
IntersectPlacementList(List *lhsPlacementList, List *rhsPlacementList)
{
	ListCell *lhsPlacementCell = NULL;
	List *placementList = NIL;

	/* keep an existing placement in the list if it is also present in the new placement list */
	foreach(lhsPlacementCell, lhsPlacementList)
	{
		ShardPlacement *lhsPlacement = (ShardPlacement *) lfirst(lhsPlacementCell);
		ListCell *rhsPlacementCell = NULL;
		foreach(rhsPlacementCell, rhsPlacementList)
		{
			ShardPlacement *rhsPlacement = (ShardPlacement *) lfirst(rhsPlacementCell);
			if (rhsPlacement->nodePort == lhsPlacement->nodePort &&
				strncmp(rhsPlacement->nodeName, lhsPlacement->nodeName,
						WORKER_LENGTH) == 0)
			{
				placementList = lappend(placementList, rhsPlacement);

				/*
				 * We don't need to add the same placement over and over again. This
				 * could happen if both placements of a shard appear on the same node.
				 */
				break;
			}
		}
	}

	return placementList;
}
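

/*
 * A worked example with hypothetical placements: if the left-hand list has
 * placements on worker1:5432 and worker2:5432 while the right-hand list has
 * placements on worker2:5432 and worker3:5432, the intersection contains
 * only the worker2:5432 placement (taken from the right-hand list).
 */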


/*
 * GroupInsertValuesByShardId takes care of grouping the rows from a multi-row
 * INSERT by target shard. At this point, all pruning has taken place and we
 * need only to build sets of rows for each destination. This is done by a
 * simple sort (by shard identifier) and gather step. The sort has the side-
 * effect of getting things in ascending order to avoid unnecessary deadlocks
 * during Task execution.
 */
static List *
GroupInsertValuesByShardId(List *insertValuesList)
{
	ModifyRoute *route = NULL;
	ListCell *insertValuesCell = NULL;
	List *modifyRouteList = NIL;

	insertValuesList = SortList(insertValuesList, CompareInsertValuesByShardId);
	foreach(insertValuesCell, insertValuesList)
	{
		InsertValues *insertValues = (InsertValues *) lfirst(insertValuesCell);
		int64 shardId = insertValues->shardId;
		bool foundSameShardId = false;

		if (route != NULL)
		{
			if (route->shardId == shardId)
			{
				foundSameShardId = true;
			}
			else
			{
				/* new shard id seen; current aggregation done; add to list */
				modifyRouteList = lappend(modifyRouteList, route);
			}
		}

		if (foundSameShardId)
		{
			/*
			 * Our current value has the same shard id as our aggregate object,
			 * so append the rowValues.
			 */
			route->rowValuesLists = lappend(route->rowValuesLists,
											insertValues->rowValues);
		}
		else
		{
			/* we encountered a new shard id; build a new aggregate object */
			route = (ModifyRoute *) palloc(sizeof(ModifyRoute));
			route->shardId = insertValues->shardId;
			route->rowValuesLists = list_make1(insertValues->rowValues);
		}
	}

	/* left holding one final aggregate object; add to list */
	modifyRouteList = lappend(modifyRouteList, route);

	return modifyRouteList;
}
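

/*
 * For illustration, given hypothetical InsertValues entries with shard ids
 * (102010, 102008, 102010), the sort above reorders them to
 * (102008, 102010, 102010) and the gather step then emits two ModifyRoutes:
 * one for shard 102008 carrying a single row, and one for shard 102010
 * carrying two rows.
 */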


/*
 * ExtractInsertValuesList extracts the partition column value for an INSERT
 * command and returns it within an InsertValues struct. For single-row INSERTs
 * this is simply a value extracted from the target list, but multi-row INSERTs
 * will generate a List of InsertValues, each with full row values in addition
 * to the partition value. If a partition value is NULL or missing altogether,
 * this function errors.
 */
static List *
ExtractInsertValuesList(Query *query, Var *partitionColumn)
{
	List *insertValuesList = NIL;
	TargetEntry *targetEntry = get_tle_by_resno(query->targetList,
												partitionColumn->varattno);

	if (targetEntry == NULL)
	{
		ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
						errmsg("cannot perform an INSERT without a partition column "
							   "value")));
	}

	/*
	 * If we have a multi-row INSERT: PostgreSQL internally represents such
	 * commands by linking Vars in the target list to lists of values within
	 * a special VALUES range table entry. By extracting the right positional
	 * expression from each list within that RTE, we will extract the partition
	 * values for each row within the multi-row INSERT.
	 */
	if (IsA(targetEntry->expr, Var))
	{
		Var *partitionVar = (Var *) targetEntry->expr;
		ListCell *valuesListCell = NULL;
		Index ivIndex = 0;

		RangeTblEntry *referencedRTE = rt_fetch(partitionVar->varno, query->rtable);
		foreach(valuesListCell, referencedRTE->values_lists)
		{
			InsertValues *insertValues = (InsertValues *) palloc(sizeof(InsertValues));
			insertValues->rowValues = (List *) lfirst(valuesListCell);
			insertValues->partitionValueExpr = list_nth(insertValues->rowValues,
														(partitionVar->varattno - 1));
			insertValues->shardId = INVALID_SHARD_ID;
			insertValues->listIndex = ivIndex;

			insertValuesList = lappend(insertValuesList, insertValues);
			ivIndex++;
		}
	}

	/* nothing's been found yet; this is a simple single-row INSERT */
	if (insertValuesList == NIL)
	{
		InsertValues *insertValues = (InsertValues *) palloc(sizeof(InsertValues));
		insertValues->rowValues = NIL;
		insertValues->partitionValueExpr = targetEntry->expr;
		insertValues->shardId = INVALID_SHARD_ID;

		/* single row, so the stable-sort index is trivially zero */
		insertValues->listIndex = 0;

		insertValuesList = lappend(insertValuesList, insertValues);
	}

	return insertValuesList;
}
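

/*
 * For example, "INSERT INTO events VALUES (1, 'a'), (2, 'b')" against a
 * hypothetical table distributed on its first column yields two InsertValues
 * entries whose partitionValueExprs are the constants 1 and 2, each still
 * carrying its full row values for later deparsing.
 */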


/*
 * ExtractInsertPartitionKeyValue extracts the partition column value
 * from an INSERT query. If the expression in the partition column is
 * non-constant, or if it is a multi-row INSERT with multiple different
 * partition column values, the function returns NULL.
 */
Const *
ExtractInsertPartitionKeyValue(Query *query)
{
	Oid distributedTableId = ExtractFirstCitusTableId(query);
	uint32 rangeTableId = 1;
	Const *singlePartitionValueConst = NULL;

	if (IsCitusTableType(distributedTableId, CITUS_TABLE_WITH_NO_DIST_KEY))
	{
		return NULL;
	}

	Var *partitionColumn = PartitionColumn(distributedTableId, rangeTableId);
	TargetEntry *targetEntry = get_tle_by_resno(query->targetList,
												partitionColumn->varattno);
	if (targetEntry == NULL)
	{
		/* partition column value not specified */
		return NULL;
	}

	Node *targetExpression = strip_implicit_coercions((Node *) targetEntry->expr);

	/*
	 * Multi-row INSERTs have a Var in the target list that points to
	 * an RTE_VALUES.
	 */
	if (IsA(targetExpression, Var))
	{
		Var *partitionVar = (Var *) targetExpression;
		ListCell *valuesListCell = NULL;

		RangeTblEntry *referencedRTE = rt_fetch(partitionVar->varno, query->rtable);

		foreach(valuesListCell, referencedRTE->values_lists)
		{
			List *rowValues = (List *) lfirst(valuesListCell);
			Node *partitionValueNode = list_nth(rowValues, partitionVar->varattno - 1);
			Expr *partitionValueExpr = (Expr *) strip_implicit_coercions(
				partitionValueNode);

			if (!IsA(partitionValueExpr, Const))
			{
				/* non-constant value in the partition column */
				singlePartitionValueConst = NULL;
				break;
			}

			Const *partitionValueConst = (Const *) partitionValueExpr;

			if (singlePartitionValueConst == NULL)
			{
				/* first row has a constant in the partition column, looks promising! */
				singlePartitionValueConst = partitionValueConst;
			}
			else if (!equal(partitionValueConst, singlePartitionValueConst))
			{
				/* multiple different values in the partition column, too bad */
				singlePartitionValueConst = NULL;
				break;
			}
			else
			{
				/* another row with the same partition column value! */
			}
		}
	}
	else if (IsA(targetExpression, Const))
	{
		/* single-row INSERT with a constant partition column value */
		singlePartitionValueConst = (Const *) targetExpression;
	}
	else
	{
		/* single-row INSERT with a non-constant partition column value */
		singlePartitionValueConst = NULL;
	}

	if (singlePartitionValueConst != NULL)
	{
		singlePartitionValueConst = copyObject(singlePartitionValueConst);
	}

	return singlePartitionValueConst;
}
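

/*
 * For illustration, against a hypothetical table distributed on its first
 * column:
 *
 *   INSERT INTO events VALUES (1, 'a'), (1, 'b');  -- returns Const 1
 *   INSERT INTO events VALUES (1, 'a'), (2, 'b');  -- returns NULL
 *   INSERT INTO events VALUES (random(), 'a');     -- returns NULL
 */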


/*
 * DeferErrorIfUnsupportedRouterPlannableSelectQuery checks if the given query
 * is a router-plannable SELECT query, returning a deferred error (which the
 * caller sets on distributedPlan->planningError) if it is not. Router
 * plannable checks for SELECT queries can be turned off by setting the
 * citus.enable_router_execution flag to false.
 */
static DeferredErrorMessage *
DeferErrorIfUnsupportedRouterPlannableSelectQuery(Query *query)
{
	List *rangeTableRelationList = NIL;
	ListCell *rangeTableRelationCell = NULL;

	if (query->commandType != CMD_SELECT)
	{
		return DeferredError(ERRCODE_ASSERT_FAILURE,
							 "Only SELECT query types are supported in this path",
							 NULL, NULL);
	}

	if (!EnableRouterExecution)
	{
		return DeferredError(ERRCODE_SUCCESSFUL_COMPLETION,
							 "Router planner not enabled.",
							 NULL, NULL);
	}

	if (contain_nextval_expression_walker((Node *) query->targetList, NULL))
	{
		/*
		 * We let queries with nextval in the target list fall through to
		 * the logical planner, which knows how to handle those queries.
		 */
		return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
							 "Sequences cannot be used in router queries",
							 NULL, NULL);
	}

	bool hasPostgresOrCitusLocalTable = false;
	bool hasDistributedTable = false;

	ExtractRangeTableRelationWalker((Node *) query, &rangeTableRelationList);
	foreach(rangeTableRelationCell, rangeTableRelationList)
	{
		RangeTblEntry *rte = (RangeTblEntry *) lfirst(rangeTableRelationCell);
		if (rte->rtekind == RTE_RELATION)
		{
			Oid distributedTableId = rte->relid;

			/* local tables are allowed if there are no distributed tables */
			if (!IsCitusTable(distributedTableId))
			{
				hasPostgresOrCitusLocalTable = true;
				continue;
			}
			else if (IsCitusTableType(distributedTableId, CITUS_LOCAL_TABLE))
			{
				hasPostgresOrCitusLocalTable = true;
				elog(DEBUG4, "Router planner finds a local table added to metadata");
				continue;
			}

			if (IsCitusTableType(distributedTableId, APPEND_DISTRIBUTED))
			{
				return DeferredError(
					ERRCODE_FEATURE_NOT_SUPPORTED,
					"Router planner does not support append-partitioned tables.",
					NULL, NULL);
			}

			if (IsCitusTableType(distributedTableId, DISTRIBUTED_TABLE))
			{
				hasDistributedTable = true;
			}

			/*
			 * Currently, we don't support tables with replication factor > 1,
			 * except reference tables with SELECT ... FOR UPDATE queries. It is
			 * also not supported from MX nodes.
			 */
			if (query->hasForUpdate)
			{
				uint32 tableReplicationFactor = TableShardReplicationFactor(
					distributedTableId);

				if (tableReplicationFactor > 1 && IsCitusTableType(distributedTableId,
																   DISTRIBUTED_TABLE))
				{
					return DeferredError(
						ERRCODE_FEATURE_NOT_SUPPORTED,
						"SELECT FOR UPDATE with table replication factor > 1 not supported for non-reference tables.",
						NULL, NULL);
				}
			}
		}
	}

	/* local tables are not allowed if there are distributed tables */
	if (hasPostgresOrCitusLocalTable && hasDistributedTable)
	{
		return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
							 "Local tables cannot be used in distributed queries.",
							 NULL, NULL);
	}

#if PG_VERSION_NUM >= PG_VERSION_14
	DeferredErrorMessage *CTEWithSearchClauseError =
		ErrorIfQueryHasCTEWithSearchClause(query);
	if (CTEWithSearchClauseError != NULL)
	{
		return CTEWithSearchClauseError;
	}
#endif

	return ErrorIfQueryHasUnroutableModifyingCTE(query);
}
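

/*
 * As the comment above notes, router execution of SELECT queries can be
 * disabled entirely:
 *
 *   SET citus.enable_router_execution TO off;
 *
 * after which this function returns the "Router planner not enabled."
 * deferred error and planning falls through to the other planner paths.
 */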


/*
 * CopyRelationRestrictionContext copies a RelationRestrictionContext. Note
 * that several subfields are copied shallowly, for lack of copyObject support.
 *
 * Per relation restriction, the function copies the following fields: index,
 * relationId, distributedRelation, rte, relOptInfo->baserestrictinfo and
 * relOptInfo->joininfo. It shallowly copies plannerInfo, which is read-only,
 * and all other parts of the relOptInfo are also shallowly copied.
 */
RelationRestrictionContext *
CopyRelationRestrictionContext(RelationRestrictionContext *oldContext)
{
	RelationRestrictionContext *newContext =
		(RelationRestrictionContext *) palloc(sizeof(RelationRestrictionContext));
	ListCell *relationRestrictionCell = NULL;

	newContext->allReferenceTables = oldContext->allReferenceTables;
	newContext->relationRestrictionList = NIL;

	foreach(relationRestrictionCell, oldContext->relationRestrictionList)
	{
		RelationRestriction *oldRestriction =
			(RelationRestriction *) lfirst(relationRestrictionCell);
		RelationRestriction *newRestriction = (RelationRestriction *)
											  palloc0(sizeof(RelationRestriction));

		newRestriction->index = oldRestriction->index;
		newRestriction->relationId = oldRestriction->relationId;
		newRestriction->distributedRelation = oldRestriction->distributedRelation;
		newRestriction->rte = copyObject(oldRestriction->rte);

		/* can't be copied as a whole; copy the RelOptInfo flatly, then decouple baserestrictinfo */
		newRestriction->relOptInfo = palloc(sizeof(RelOptInfo));
		*newRestriction->relOptInfo = *oldRestriction->relOptInfo;

		newRestriction->relOptInfo->baserestrictinfo =
			copyObject(oldRestriction->relOptInfo->baserestrictinfo);

		newRestriction->relOptInfo->joininfo =
			copyObject(oldRestriction->relOptInfo->joininfo);

		/* not copyable, but readonly */
		newRestriction->plannerInfo = oldRestriction->plannerInfo;

		newContext->relationRestrictionList =
			lappend(newContext->relationRestrictionList, newRestriction);
	}

	return newContext;
}


/*
 * ErrorIfQueryHasUnroutableModifyingCTE checks if the query contains modifying
 * common table expressions and returns a deferred error if any of them cannot
 * be router planned.
 */
static DeferredErrorMessage *
ErrorIfQueryHasUnroutableModifyingCTE(Query *queryTree)
{
	Assert(queryTree->commandType == CMD_SELECT);

	if (!queryTree->hasModifyingCTE)
	{
		return NULL;
	}

	/* we can't route conflicting replication models */
	char replicationModel = 0;

	CommonTableExpr *cte = NULL;
	foreach_ptr(cte, queryTree->cteList)
	{
		Query *cteQuery = (Query *) cte->ctequery;

		/*
		 * Here we only check the command type of the top level query. Normally
		 * there can be nested CTEs, however PostgreSQL dictates that
		 * data-modifying statements must be at the top level of a CTE.
		 * Therefore it is OK to just check the top level. Similarly, we do not
		 * need to check for subqueries.
		 */
		if (cteQuery->commandType != CMD_SELECT &&
			cteQuery->commandType != CMD_UPDATE &&
			cteQuery->commandType != CMD_DELETE)
		{
			return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
								 "only SELECT, UPDATE, or DELETE common table expressions "
								 "may be router planned",
								 NULL, NULL);
		}

		if (cteQuery->commandType != CMD_SELECT)
		{
			Oid distributedTableId = InvalidOid;
			DeferredErrorMessage *cteError =
				ModifyPartialQuerySupported(cteQuery, false, &distributedTableId);
			if (cteError)
			{
				return cteError;
			}

			CitusTableCacheEntry *modificationTableCacheEntry =
				GetCitusTableCacheEntry(distributedTableId);

			if (IsCitusTableTypeCacheEntry(modificationTableCacheEntry,
										   CITUS_TABLE_WITH_NO_DIST_KEY))
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "cannot router plan modification of a non-distributed table",
									 NULL, NULL);
			}

			if (replicationModel &&
				modificationTableCacheEntry->replicationModel != replicationModel)
			{
				return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
									 "cannot route mixed replication models",
									 NULL, NULL);
			}

			replicationModel = modificationTableCacheEntry->replicationModel;
		}
	}

	/* everything OK */
	return NULL;
}
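

/*
 * Two hypothetical examples of the checks above. A modifying CTE such as
 *
 *   WITH updated AS (UPDATE events SET payload = '' WHERE key = 1
 *                    RETURNING *)
 *   SELECT count(*) FROM updated;
 *
 * can pass, since UPDATE CTEs are allowed, whereas
 *
 *   WITH inserted AS (INSERT INTO events VALUES (1, 'a') RETURNING *)
 *   SELECT count(*) FROM inserted;
 *
 * is rejected with "only SELECT, UPDATE, or DELETE common table expressions
 * may be router planned".
 */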


#if PG_VERSION_NUM >= PG_VERSION_14

/*
 * ErrorIfQueryHasCTEWithSearchClause checks if the query contains any common
 * table expressions with a SEARCH clause and returns a deferred error if so.
 */
static DeferredErrorMessage *
ErrorIfQueryHasCTEWithSearchClause(Query *queryTree)
{
	if (ContainsSearchClauseWalker((Node *) queryTree))
	{
		return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
							 "CTEs with search clauses are not supported",
							 NULL, NULL);
	}
	return NULL;
}
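

/*
 * An example of the rejected syntax (PostgreSQL 14+), using a hypothetical
 * recursive CTE:
 *
 *   WITH RECURSIVE tree AS (
 *       SELECT id, parent_id FROM nodes WHERE parent_id IS NULL
 *     UNION ALL
 *       SELECT n.id, n.parent_id FROM nodes n JOIN tree t ON n.parent_id = t.id
 *   ) SEARCH DEPTH FIRST BY id SET ordercol
 *   SELECT * FROM tree ORDER BY ordercol;
 */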


/*
 * ContainsSearchClauseWalker walks over the node and checks whether it
 * contains any CommonTableExpr with a SEARCH clause.
 */
static bool
ContainsSearchClauseWalker(Node *node)
{
	if (node == NULL)
	{
		return false;
	}

	if (IsA(node, CommonTableExpr))
	{
		if (((CommonTableExpr *) node)->search_clause != NULL)
		{
			return true;
		}
	}

	if (IsA(node, Query))
	{
		return query_tree_walker((Query *) node, ContainsSearchClauseWalker, NULL, 0);
	}

	return expression_tree_walker(node, ContainsSearchClauseWalker, NULL);
}


#endif


/*
 * get_all_actual_clauses
 *
 * Returns a list containing the bare clauses from 'restrictinfo_list'.
 *
 * This loses the distinction between regular and pseudoconstant clauses,
 * so be careful what you use it for.
 */
List *
get_all_actual_clauses(List *restrictinfo_list)
{
	List *result = NIL;
	ListCell *l;

	foreach(l, restrictinfo_list)
	{
		RestrictInfo *rinfo = (RestrictInfo *) lfirst(l);

		Assert(IsA(rinfo, RestrictInfo));

		result = lappend(result, rinfo->clause);
	}
	return result;
}


/*
 * CompareInsertValuesByShardId compares two InsertValues objects by shard
 * identifier, falling back to their list index as a secondary sort key. Used
 * for sorting InsertValues objects by their target shard.
 */
static int
CompareInsertValuesByShardId(const void *leftElement, const void *rightElement)
{
	InsertValues *leftValue = *((InsertValues **) leftElement);
	InsertValues *rightValue = *((InsertValues **) rightElement);
	int64 leftShardId = leftValue->shardId;
	int64 rightShardId = rightValue->shardId;
	Index leftIndex = leftValue->listIndex;
	Index rightIndex = rightValue->listIndex;

	if (leftShardId > rightShardId)
	{
		return 1;
	}
	else if (leftShardId < rightShardId)
	{
		return -1;
	}
	else
	{
		/* shard identifiers are the same, list index is secondary sort key */
		if (leftIndex > rightIndex)
		{
			return 1;
		}
		else if (leftIndex < rightIndex)
		{
			return -1;
		}
		else
		{
			return 0;
		}
	}
}