1 /*-------------------------------------------------------------------------
2  *
3  * function.c
4  *    Commands for FUNCTION statements.
5  *
6  *    We currently support replicating function definitions on the
7  *    coordinator in all the worker nodes in the form of
8  *
9  *    CREATE OR REPLACE FUNCTION ... queries.
10  *
11  *    ALTER or DROP operations are not yet propagated.
12  *
13  * Copyright (c) Citus Data, Inc.
14  *
15  *-------------------------------------------------------------------------
16  */
17 
18 #include "postgres.h"
19 #include "miscadmin.h"
20 #include "funcapi.h"
21 
22 #include "distributed/pg_version_constants.h"
23 
24 #include "access/genam.h"
25 #include "access/htup_details.h"
26 #include "access/xact.h"
27 #include "catalog/pg_aggregate.h"
28 #include "catalog/namespace.h"
29 #include "catalog/pg_proc.h"
30 #include "catalog/pg_type.h"
31 #include "commands/extension.h"
32 #include "distributed/citus_ruleutils.h"
33 #include "distributed/citus_safe_lib.h"
34 #include "distributed/colocation_utils.h"
35 #include "distributed/commands.h"
36 #include "distributed/commands/utility_hook.h"
37 #include "distributed/deparser.h"
38 #include "distributed/listutils.h"
39 #include "distributed/maintenanced.h"
40 #include "distributed/metadata_utility.h"
41 #include "distributed/coordinator_protocol.h"
42 #include "distributed/metadata/distobject.h"
43 #include "distributed/metadata/pg_dist_object.h"
44 #include "distributed/metadata_sync.h"
45 #include "distributed/multi_executor.h"
46 #include "distributed/namespace_utils.h"
47 #include "distributed/pg_dist_node.h"
48 #include "distributed/reference_table_utils.h"
49 #include "distributed/relation_access_tracking.h"
50 #include "distributed/version_compat.h"
51 #include "distributed/worker_create_or_replace.h"
52 #include "distributed/worker_transaction.h"
53 #include "nodes/makefuncs.h"
54 #include "parser/parse_coerce.h"
55 #include "parser/parse_type.h"
56 #include "storage/lmgr.h"
57 #include "utils/builtins.h"
58 #include "utils/fmgroids.h"
59 #include "utils/fmgrprotos.h"
60 #include "utils/lsyscache.h"
61 #include "utils/syscache.h"
62 #include "utils/regproc.h"
63 
64 #define DISABLE_LOCAL_CHECK_FUNCTION_BODIES "SET LOCAL check_function_bodies TO off;"
65 #define RESET_CHECK_FUNCTION_BODIES "RESET check_function_bodies;"
66 #define argumentStartsWith(arg, prefix) \
67 	(strncmp(arg, prefix, strlen(prefix)) == 0)
68 
69 /* forward declaration for helper functions*/
70 static char * GetAggregateDDLCommand(const RegProcedure funcOid, bool useCreateOrReplace);
71 static char * GetFunctionAlterOwnerCommand(const RegProcedure funcOid);
72 static int GetDistributionArgIndex(Oid functionOid, char *distributionArgumentName,
73 								   Oid *distributionArgumentOid);
74 static int GetFunctionColocationId(Oid functionOid, char *colocateWithName, Oid
75 								   distributionArgumentOid);
76 static void EnsureFunctionCanBeColocatedWithTable(Oid functionOid, Oid
77 												  distributionColumnType, Oid
78 												  sourceRelationId);
79 static void UpdateFunctionDistributionInfo(const ObjectAddress *distAddress,
80 										   int *distribution_argument_index,
81 										   int *colocationId);
82 static void EnsureSequentialModeForFunctionDDL(void);
83 static void TriggerSyncMetadataToPrimaryNodes(void);
84 static bool ShouldPropagateCreateFunction(CreateFunctionStmt *stmt);
85 static bool ShouldPropagateAlterFunction(const ObjectAddress *address);
86 static ObjectAddress FunctionToObjectAddress(ObjectType objectType,
87 											 ObjectWithArgs *objectWithArgs,
88 											 bool missing_ok);
89 static void ErrorIfUnsupportedAlterFunctionStmt(AlterFunctionStmt *stmt);
90 static void ErrorIfFunctionDependsOnExtension(const ObjectAddress *functionAddress);
91 static char * quote_qualified_func_name(Oid funcOid);
92 static void DistributeFunctionWithDistributionArgument(RegProcedure funcOid,
93 													   char *distributionArgumentName,
94 													   Oid distributionArgumentOid,
95 													   char *colocateWithTableName,
96 													   const ObjectAddress *
97 													   functionAddress);
98 static void DistributeFunctionColocatedWithDistributedTable(RegProcedure funcOid,
99 															char *colocateWithTableName,
100 															const ObjectAddress *
101 															functionAddress);
102 static void DistributeFunctionColocatedWithReferenceTable(const
103 														  ObjectAddress *functionAddress);
104 
105 
106 PG_FUNCTION_INFO_V1(create_distributed_function);
107 
108 
109 /*
110  * create_distributed_function gets a function or procedure name with their list of
111  * argument types in parantheses, then it creates a new distributed function.
112  */
113 Datum
create_distributed_function(PG_FUNCTION_ARGS)114 create_distributed_function(PG_FUNCTION_ARGS)
115 {
116 	RegProcedure funcOid = PG_GETARG_OID(0);
117 
118 	text *distributionArgumentNameText = NULL; /* optional */
119 	text *colocateWithText = NULL; /* optional */
120 
121 	StringInfoData ddlCommand = { 0 };
122 	ObjectAddress functionAddress = { 0 };
123 
124 	Oid distributionArgumentOid = InvalidOid;
125 	bool colocatedWithReferenceTable = false;
126 
127 	char *distributionArgumentName = NULL;
128 	char *colocateWithTableName = NULL;
129 
130 	/* if called on NULL input, error out */
131 	if (funcOid == InvalidOid)
132 	{
133 		ereport(ERROR, (errmsg("the first parameter for create_distributed_function() "
134 							   "should be a single a valid function or procedure name "
135 							   "followed by a list of parameters in parantheses"),
136 						errhint("skip the parameters with OUT argtype as they are not "
137 								"part of the signature in PostgreSQL")));
138 	}
139 
140 	if (PG_ARGISNULL(1))
141 	{
142 		/*
143 		 * Using the default value, so distribute the function but do not set
144 		 * the distribution argument.
145 		 */
146 		distributionArgumentName = NULL;
147 	}
148 	else
149 	{
150 		distributionArgumentNameText = PG_GETARG_TEXT_P(1);
151 		distributionArgumentName = text_to_cstring(distributionArgumentNameText);
152 	}
153 
154 	if (PG_ARGISNULL(2))
155 	{
156 		ereport(ERROR, (errmsg("colocate_with parameter should not be NULL"),
157 						errhint("To use the default value, set colocate_with option "
158 								"to \"default\"")));
159 	}
160 	else
161 	{
162 		colocateWithText = PG_GETARG_TEXT_P(2);
163 		colocateWithTableName = text_to_cstring(colocateWithText);
164 
165 		/* check if the colocation belongs to a reference table */
166 		if (pg_strncasecmp(colocateWithTableName, "default", NAMEDATALEN) != 0)
167 		{
168 			Oid colocationRelationId = ResolveRelationId(colocateWithText, false);
169 			colocatedWithReferenceTable = IsCitusTableType(colocationRelationId,
170 														   REFERENCE_TABLE);
171 		}
172 	}
173 
174 	EnsureCoordinator();
175 	EnsureFunctionOwner(funcOid);
176 
177 	ObjectAddressSet(functionAddress, ProcedureRelationId, funcOid);
178 	ErrorIfFunctionDependsOnExtension(&functionAddress);
179 
180 	/*
181 	 * when we allow propagation within a transaction block we should make sure to only
182 	 * allow this in sequential mode
183 	 */
184 	EnsureSequentialModeForFunctionDDL();
185 
186 	EnsureDependenciesExistOnAllNodes(&functionAddress);
187 
188 	const char *createFunctionSQL = GetFunctionDDLCommand(funcOid, true);
189 	const char *alterFunctionOwnerSQL = GetFunctionAlterOwnerCommand(funcOid);
190 	initStringInfo(&ddlCommand);
191 	appendStringInfo(&ddlCommand, "%s;%s", createFunctionSQL, alterFunctionOwnerSQL);
192 	SendCommandToWorkersAsUser(NON_COORDINATOR_NODES, CurrentUserName(), ddlCommand.data);
193 
194 	MarkObjectDistributed(&functionAddress);
195 
196 	if (distributionArgumentName != NULL)
197 	{
198 		DistributeFunctionWithDistributionArgument(funcOid, distributionArgumentName,
199 												   distributionArgumentOid,
200 												   colocateWithTableName,
201 												   &functionAddress);
202 	}
203 	else if (!colocatedWithReferenceTable)
204 	{
205 		DistributeFunctionColocatedWithDistributedTable(funcOid, colocateWithTableName,
206 														&functionAddress);
207 	}
208 	else if (colocatedWithReferenceTable)
209 	{
210 		DistributeFunctionColocatedWithReferenceTable(&functionAddress);
211 	}
212 
213 	PG_RETURN_VOID();
214 }
215 
216 
217 /*
218  * DistributeFunctionWithDistributionArgument updates pg_dist_object records for
219  * a function/procedure that has a distribution argument, and triggers metadata
220  * sync so that the functions can be delegated on workers.
221  */
222 static void
DistributeFunctionWithDistributionArgument(RegProcedure funcOid,char * distributionArgumentName,Oid distributionArgumentOid,char * colocateWithTableName,const ObjectAddress * functionAddress)223 DistributeFunctionWithDistributionArgument(RegProcedure funcOid,
224 										   char *distributionArgumentName,
225 										   Oid distributionArgumentOid,
226 										   char *colocateWithTableName,
227 										   const ObjectAddress *functionAddress)
228 {
229 	/* get the argument index, or error out if we cannot find a valid index */
230 	int distributionArgumentIndex =
231 		GetDistributionArgIndex(funcOid, distributionArgumentName,
232 								&distributionArgumentOid);
233 
234 	/* get the colocation id, or error out if we cannot find an appropriate one */
235 	int colocationId =
236 		GetFunctionColocationId(funcOid, colocateWithTableName,
237 								distributionArgumentOid);
238 
239 	/* record the distribution argument and colocationId */
240 	UpdateFunctionDistributionInfo(functionAddress, &distributionArgumentIndex,
241 								   &colocationId);
242 
243 	/*
244 	 * Once we have at least one distributed function/procedure with distribution
245 	 * argument, we sync the metadata to nodes so that the function/procedure
246 	 * delegation can be handled locally on the nodes.
247 	 */
248 	TriggerSyncMetadataToPrimaryNodes();
249 }
250 
251 
252 /*
253  * DistributeFunctionColocatedWithDistributedTable updates pg_dist_object records for
254  * a function/procedure that is colocated with a distributed table.
255  */
256 static void
DistributeFunctionColocatedWithDistributedTable(RegProcedure funcOid,char * colocateWithTableName,const ObjectAddress * functionAddress)257 DistributeFunctionColocatedWithDistributedTable(RegProcedure funcOid,
258 												char *colocateWithTableName,
259 												const ObjectAddress *functionAddress)
260 {
261 	/*
262 	 * cannot provide colocate_with without distribution_arg_name when the function
263 	 * is not collocated with a reference table
264 	 */
265 	if (pg_strncasecmp(colocateWithTableName, "default", NAMEDATALEN) != 0)
266 	{
267 		char *functionName = get_func_name(funcOid);
268 
269 
270 		ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
271 						errmsg("cannot distribute the function \"%s\" since the "
272 							   "distribution argument is not valid ", functionName),
273 						errhint("To provide \"colocate_with\" option with a"
274 								" distributed table, the distribution argument"
275 								" parameter should also be provided")));
276 	}
277 
278 	/* set distribution argument and colocationId to NULL */
279 	UpdateFunctionDistributionInfo(functionAddress, NULL, NULL);
280 }
281 
282 
283 /*
284  * DistributeFunctionColocatedWithReferenceTable updates pg_dist_object records for
285  * a function/procedure that is colocated with a reference table.
286  */
287 static void
DistributeFunctionColocatedWithReferenceTable(const ObjectAddress * functionAddress)288 DistributeFunctionColocatedWithReferenceTable(const ObjectAddress *functionAddress)
289 {
290 	/* get the reference table colocation id */
291 	int colocationId = CreateReferenceTableColocationId();
292 
293 	/* set distribution argument to NULL and colocationId to the reference table colocation id */
294 	int *distributionArgumentIndex = NULL;
295 	UpdateFunctionDistributionInfo(functionAddress, distributionArgumentIndex,
296 								   &colocationId);
297 
298 	/*
299 	 * Once we have at least one distributed function/procedure that reads
300 	 * from a reference table, we sync the metadata to nodes so that the
301 	 * function/procedure delegation can be handled locally on the nodes.
302 	 */
303 	TriggerSyncMetadataToPrimaryNodes();
304 }
305 
306 
307 /*
308  * CreateFunctionDDLCommandsIdempotent returns a list of DDL statements (const char *) to be
309  * executed on a node to recreate the function addressed by the functionAddress.
310  */
311 List *
CreateFunctionDDLCommandsIdempotent(const ObjectAddress * functionAddress)312 CreateFunctionDDLCommandsIdempotent(const ObjectAddress *functionAddress)
313 {
314 	Assert(functionAddress->classId == ProcedureRelationId);
315 
316 	char *ddlCommand = GetFunctionDDLCommand(functionAddress->objectId, true);
317 	char *alterFunctionOwnerSQL = GetFunctionAlterOwnerCommand(functionAddress->objectId);
318 
319 	return list_make4(
320 		DISABLE_LOCAL_CHECK_FUNCTION_BODIES,
321 		ddlCommand,
322 		alterFunctionOwnerSQL,
323 		RESET_CHECK_FUNCTION_BODIES);
324 }
325 
326 
327 /*
328  * GetDistributionArgIndex calculates the distribution argument with the given
329  * parameters. The function errors out if no valid argument is found.
330  */
331 static int
GetDistributionArgIndex(Oid functionOid,char * distributionArgumentName,Oid * distributionArgumentOid)332 GetDistributionArgIndex(Oid functionOid, char *distributionArgumentName,
333 						Oid *distributionArgumentOid)
334 {
335 	int distributionArgumentIndex = -1;
336 
337 	Oid *argTypes = NULL;
338 	char **argNames = NULL;
339 	char *argModes = NULL;
340 
341 
342 	*distributionArgumentOid = InvalidOid;
343 
344 	HeapTuple proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid));
345 	if (!HeapTupleIsValid(proctup))
346 	{
347 		elog(ERROR, "cache lookup failed for function %u", functionOid);
348 	}
349 
350 	int numberOfArgs = get_func_arg_info(proctup, &argTypes, &argNames, &argModes);
351 
352 	if (argumentStartsWith(distributionArgumentName, "$"))
353 	{
354 		/* skip the first character, we're safe because text_to_cstring pallocs */
355 		distributionArgumentName++;
356 
357 		/* throws error if the input is not an integer */
358 		distributionArgumentIndex = pg_atoi(distributionArgumentName, 4, 0);
359 
360 		if (distributionArgumentIndex < 1 || distributionArgumentIndex > numberOfArgs)
361 		{
362 			char *functionName = get_func_name(functionOid);
363 
364 			ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
365 							errmsg("cannot distribute the function \"%s\" since "
366 								   "the distribution argument is not valid",
367 								   functionName),
368 							errhint("Either provide a valid function argument name "
369 									"or a valid \"$paramIndex\" to "
370 									"create_distributed_function()")));
371 		}
372 
373 		/*
374 		 * Internal representation for the distributionArgumentIndex
375 		 * starts from 0 whereas user facing API starts from 1.
376 		 */
377 		distributionArgumentIndex -= 1;
378 		*distributionArgumentOid = argTypes[distributionArgumentIndex];
379 
380 		ReleaseSysCache(proctup);
381 
382 		Assert(*distributionArgumentOid != InvalidOid);
383 
384 		return distributionArgumentIndex;
385 	}
386 
387 	/*
388 	 * The user didn't provid "$paramIndex" but potentially the name of the parameter.
389 	 * So, loop over the arguments and try to find the argument name that matches
390 	 * the parameter that user provided.
391 	 */
392 	for (int argIndex = 0; argIndex < numberOfArgs; ++argIndex)
393 	{
394 		char *argNameOnIndex = argNames != NULL ? argNames[argIndex] : NULL;
395 
396 		if (argNameOnIndex != NULL &&
397 			pg_strncasecmp(argNameOnIndex, distributionArgumentName, NAMEDATALEN) == 0)
398 		{
399 			distributionArgumentIndex = argIndex;
400 
401 			*distributionArgumentOid = argTypes[argIndex];
402 
403 			/* we found, no need to continue */
404 			break;
405 		}
406 	}
407 
408 	/* we still couldn't find the argument, so error out */
409 	if (distributionArgumentIndex == -1)
410 	{
411 		char *functionName = get_func_name(functionOid);
412 
413 		ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
414 						errmsg("cannot distribute the function \"%s\" since the "
415 							   "distribution argument is not valid ", functionName),
416 						errhint("Either provide a valid function argument name "
417 								"or a valid \"$paramIndex\" to "
418 								"create_distributed_function()")));
419 	}
420 
421 	ReleaseSysCache(proctup);
422 
423 	Assert(*distributionArgumentOid != InvalidOid);
424 
425 	return distributionArgumentIndex;
426 }
427 
428 
429 /*
430  * GetFunctionColocationId gets the parameters for deciding the colocationId
431  * of the function that is being distributed. The function errors out if it is
432  * not possible to assign a colocationId to the input function.
433  */
434 static int
GetFunctionColocationId(Oid functionOid,char * colocateWithTableName,Oid distributionArgumentOid)435 GetFunctionColocationId(Oid functionOid, char *colocateWithTableName,
436 						Oid distributionArgumentOid)
437 {
438 	int colocationId = INVALID_COLOCATION_ID;
439 	Relation pgDistColocation = table_open(DistColocationRelationId(), ShareLock);
440 
441 	if (pg_strncasecmp(colocateWithTableName, "default", NAMEDATALEN) == 0)
442 	{
443 		/* check for default colocation group */
444 		colocationId = ColocationId(ShardCount, ShardReplicationFactor,
445 									distributionArgumentOid, get_typcollation(
446 										distributionArgumentOid));
447 
448 		if (colocationId == INVALID_COLOCATION_ID)
449 		{
450 			char *functionName = get_func_name(functionOid);
451 
452 			ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
453 							errmsg("cannot distribute the function \"%s\" since there "
454 								   "is no table to colocate with", functionName),
455 							errhint("Provide a distributed table via \"colocate_with\" "
456 									"option to create_distributed_function()")));
457 		}
458 
459 		Oid colocatedTableId = ColocatedTableId(colocationId);
460 		if (colocatedTableId != InvalidOid)
461 		{
462 			EnsureFunctionCanBeColocatedWithTable(functionOid, distributionArgumentOid,
463 												  colocatedTableId);
464 		}
465 	}
466 	else
467 	{
468 		Oid sourceRelationId =
469 			ResolveRelationId(cstring_to_text(colocateWithTableName), false);
470 
471 		EnsureFunctionCanBeColocatedWithTable(functionOid, distributionArgumentOid,
472 											  sourceRelationId);
473 
474 		colocationId = TableColocationId(sourceRelationId);
475 	}
476 
477 	/* keep the lock */
478 	table_close(pgDistColocation, NoLock);
479 
480 	return colocationId;
481 }
482 
483 
484 /*
485  * EnsureFunctionCanBeColocatedWithTable checks whether the given arguments are
486  * suitable to distribute the function to be colocated with given source table.
487  */
488 static void
EnsureFunctionCanBeColocatedWithTable(Oid functionOid,Oid distributionColumnType,Oid sourceRelationId)489 EnsureFunctionCanBeColocatedWithTable(Oid functionOid, Oid distributionColumnType,
490 									  Oid sourceRelationId)
491 {
492 	CitusTableCacheEntry *sourceTableEntry = GetCitusTableCacheEntry(sourceRelationId);
493 	char sourceReplicationModel = sourceTableEntry->replicationModel;
494 
495 	if (!IsCitusTableTypeCacheEntry(sourceTableEntry, HASH_DISTRIBUTED) &&
496 		!IsCitusTableTypeCacheEntry(sourceTableEntry, REFERENCE_TABLE))
497 	{
498 		char *functionName = get_func_name(functionOid);
499 		char *sourceRelationName = get_rel_name(sourceRelationId);
500 
501 		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
502 						errmsg("cannot colocate function \"%s\" and table \"%s\" because "
503 							   "colocate_with option is only supported for hash "
504 							   "distributed tables and reference tables.",
505 							   functionName, sourceRelationName)));
506 	}
507 
508 	if (IsCitusTableTypeCacheEntry(sourceTableEntry, REFERENCE_TABLE) &&
509 		distributionColumnType != InvalidOid)
510 	{
511 		char *functionName = get_func_name(functionOid);
512 		char *sourceRelationName = get_rel_name(sourceRelationId);
513 
514 		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
515 						errmsg("cannot colocate function \"%s\" and table \"%s\" because "
516 							   "distribution arguments are not supported when "
517 							   "colocating with reference tables.",
518 							   functionName, sourceRelationName)));
519 	}
520 
521 	if (sourceReplicationModel != REPLICATION_MODEL_STREAMING)
522 	{
523 		char *functionName = get_func_name(functionOid);
524 		char *sourceRelationName = get_rel_name(sourceRelationId);
525 
526 		ereport(ERROR, (errmsg("cannot colocate function \"%s\" and table \"%s\"",
527 							   functionName, sourceRelationName),
528 						errdetail("Citus currently only supports colocating function "
529 								  "with distributed tables that are created using "
530 								  "streaming replication model."),
531 						errhint("When distributing tables make sure that "
532 								"citus.shard_replication_factor = 1")));
533 	}
534 
535 	/*
536 	 * If the types are the same, we're good. If not, we still check if there
537 	 * is any coercion path between the types.
538 	 */
539 	Var *sourceDistributionColumn = DistPartitionKeyOrError(sourceRelationId);
540 	Oid sourceDistributionColumnType = sourceDistributionColumn->vartype;
541 	if (sourceDistributionColumnType != distributionColumnType)
542 	{
543 		Oid coercionFuncId = InvalidOid;
544 
545 		CoercionPathType coercionType =
546 			find_coercion_pathway(distributionColumnType, sourceDistributionColumnType,
547 								  COERCION_EXPLICIT, &coercionFuncId);
548 
549 		/* if there is no path for coercion, error out*/
550 		if (coercionType == COERCION_PATH_NONE)
551 		{
552 			char *functionName = get_func_name(functionOid);
553 			char *sourceRelationName = get_rel_name(sourceRelationId);
554 
555 			ereport(ERROR, (errmsg("cannot colocate function \"%s\" and table \"%s\" "
556 								   "because distribution column types don't match and "
557 								   "there is no coercion path", sourceRelationName,
558 								   functionName)));
559 		}
560 	}
561 }
562 
563 
564 /*
565  * UpdateFunctionDistributionInfo gets object address of a function and
566  * updates its distribution_argument_index and colocationId in pg_dist_object.
567  */
568 static void
UpdateFunctionDistributionInfo(const ObjectAddress * distAddress,int * distribution_argument_index,int * colocationId)569 UpdateFunctionDistributionInfo(const ObjectAddress *distAddress,
570 							   int *distribution_argument_index,
571 							   int *colocationId)
572 {
573 	const bool indexOK = true;
574 
575 	ScanKeyData scanKey[3];
576 	Datum values[Natts_pg_dist_object];
577 	bool isnull[Natts_pg_dist_object];
578 	bool replace[Natts_pg_dist_object];
579 
580 	Relation pgDistObjectRel = table_open(DistObjectRelationId(), RowExclusiveLock);
581 	TupleDesc tupleDescriptor = RelationGetDescr(pgDistObjectRel);
582 
583 	/* scan pg_dist_object for classid = $1 AND objid = $2 AND objsubid = $3 via index */
584 	ScanKeyInit(&scanKey[0], Anum_pg_dist_object_classid, BTEqualStrategyNumber, F_OIDEQ,
585 				ObjectIdGetDatum(distAddress->classId));
586 	ScanKeyInit(&scanKey[1], Anum_pg_dist_object_objid, BTEqualStrategyNumber, F_OIDEQ,
587 				ObjectIdGetDatum(distAddress->objectId));
588 	ScanKeyInit(&scanKey[2], Anum_pg_dist_object_objsubid, BTEqualStrategyNumber,
589 				F_INT4EQ, ObjectIdGetDatum(distAddress->objectSubId));
590 
591 	SysScanDesc scanDescriptor = systable_beginscan(pgDistObjectRel,
592 													DistObjectPrimaryKeyIndexId(),
593 													indexOK,
594 													NULL, 3, scanKey);
595 
596 	HeapTuple heapTuple = systable_getnext(scanDescriptor);
597 	if (!HeapTupleIsValid(heapTuple))
598 	{
599 		ereport(ERROR, (errmsg("could not find valid entry for node \"%d,%d,%d\" "
600 							   "in pg_dist_object", distAddress->classId,
601 							   distAddress->objectId, distAddress->objectSubId)));
602 	}
603 
604 	memset(replace, 0, sizeof(replace));
605 
606 	replace[Anum_pg_dist_object_distribution_argument_index - 1] = true;
607 
608 	if (distribution_argument_index != NULL)
609 	{
610 		values[Anum_pg_dist_object_distribution_argument_index - 1] = Int32GetDatum(
611 			*distribution_argument_index);
612 		isnull[Anum_pg_dist_object_distribution_argument_index - 1] = false;
613 	}
614 	else
615 	{
616 		isnull[Anum_pg_dist_object_distribution_argument_index - 1] = true;
617 	}
618 
619 	replace[Anum_pg_dist_object_colocationid - 1] = true;
620 	if (colocationId != NULL)
621 	{
622 		values[Anum_pg_dist_object_colocationid - 1] = Int32GetDatum(*colocationId);
623 		isnull[Anum_pg_dist_object_colocationid - 1] = false;
624 	}
625 	else
626 	{
627 		isnull[Anum_pg_dist_object_colocationid - 1] = true;
628 	}
629 
630 	heapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isnull, replace);
631 
632 	CatalogTupleUpdate(pgDistObjectRel, &heapTuple->t_self, heapTuple);
633 
634 	CitusInvalidateRelcacheByRelid(DistObjectRelationId());
635 
636 	CommandCounterIncrement();
637 
638 	systable_endscan(scanDescriptor);
639 
640 	table_close(pgDistObjectRel, NoLock);
641 }
642 
643 
644 /*
645  * GetFunctionDDLCommand returns the complete "CREATE OR REPLACE FUNCTION ..." statement for
646  * the specified function followed by "ALTER FUNCTION .. SET OWNER ..".
647  *
648  * useCreateOrReplace is ignored for non-aggregate functions.
649  */
650 char *
GetFunctionDDLCommand(const RegProcedure funcOid,bool useCreateOrReplace)651 GetFunctionDDLCommand(const RegProcedure funcOid, bool useCreateOrReplace)
652 {
653 	char *createFunctionSQL = NULL;
654 
655 	if (get_func_prokind(funcOid) == PROKIND_AGGREGATE)
656 	{
657 		createFunctionSQL = GetAggregateDDLCommand(funcOid, useCreateOrReplace);
658 	}
659 	else
660 	{
661 		Datum sqlTextDatum = (Datum) 0;
662 
663 		PushOverrideEmptySearchPath(CurrentMemoryContext);
664 
665 		sqlTextDatum = DirectFunctionCall1(pg_get_functiondef,
666 										   ObjectIdGetDatum(funcOid));
667 		createFunctionSQL = TextDatumGetCString(sqlTextDatum);
668 
669 		/* revert back to original search_path */
670 		PopOverrideSearchPath();
671 	}
672 
673 	return createFunctionSQL;
674 }
675 
676 
677 /*
678  * GetFunctionAlterOwnerCommand returns "ALTER FUNCTION .. SET OWNER .." statement for
679  * the specified function.
680  */
681 static char *
GetFunctionAlterOwnerCommand(const RegProcedure funcOid)682 GetFunctionAlterOwnerCommand(const RegProcedure funcOid)
683 {
684 	HeapTuple proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(funcOid));
685 	StringInfo alterCommand = makeStringInfo();
686 	Oid procOwner = InvalidOid;
687 
688 
689 	if (HeapTupleIsValid(proctup))
690 	{
691 		Form_pg_proc procform = (Form_pg_proc) GETSTRUCT(proctup);
692 
693 		procOwner = procform->proowner;
694 
695 		ReleaseSysCache(proctup);
696 	}
697 	else if (!OidIsValid(funcOid) || !HeapTupleIsValid(proctup))
698 	{
699 		ereport(ERROR, (errmsg("cannot find function with oid: %d", funcOid)));
700 	}
701 
702 	/*
703 	 * If the function exists we want to use format_procedure_qualified to
704 	 * serialize its canonical arguments
705 	 */
706 	char *functionSignature = format_procedure_qualified(funcOid);
707 	char *functionOwner = GetUserNameFromId(procOwner, false);
708 
709 	appendStringInfo(alterCommand, "ALTER ROUTINE %s OWNER TO %s;",
710 					 functionSignature,
711 					 quote_identifier(functionOwner));
712 
713 	return alterCommand->data;
714 }
715 
716 
717 /*
718  * GetAggregateDDLCommand returns a string for creating an aggregate.
719  * CREATE OR REPLACE AGGREGATE was only introduced in pg12,
720  * so a second parameter useCreateOrReplace signals whether to
721  * to create a plain CREATE AGGREGATE or not. In pg11 we return a string
722  * which is a call to worker_create_or_replace_object in lieu of
723  * CREATE OR REPLACE AGGREGATE.
724  */
725 static char *
GetAggregateDDLCommand(const RegProcedure funcOid,bool useCreateOrReplace)726 GetAggregateDDLCommand(const RegProcedure funcOid, bool useCreateOrReplace)
727 {
728 	StringInfoData buf = { 0 };
729 	int i = 0;
730 	Oid *argtypes = NULL;
731 	char **argnames = NULL;
732 	char *argmodes = NULL;
733 	int insertorderbyat = -1;
734 	int argsprinted = 0;
735 	int inputargno = 0;
736 
737 	HeapTuple proctup = SearchSysCache1(PROCOID, funcOid);
738 	if (!HeapTupleIsValid(proctup))
739 	{
740 		elog(ERROR, "cache lookup failed for %d", funcOid);
741 	}
742 
743 	Form_pg_proc proc = (Form_pg_proc) GETSTRUCT(proctup);
744 
745 	Assert(proc->prokind == PROKIND_AGGREGATE);
746 
747 	initStringInfo(&buf);
748 
749 	const char *name = NameStr(proc->proname);
750 	const char *nsp = get_namespace_name(proc->pronamespace);
751 
752 	if (useCreateOrReplace)
753 	{
754 		appendStringInfo(&buf, "CREATE OR REPLACE AGGREGATE %s(",
755 						 quote_qualified_identifier(nsp, name));
756 	}
757 	else
758 	{
759 		appendStringInfo(&buf, "CREATE AGGREGATE %s(",
760 						 quote_qualified_identifier(nsp, name));
761 	}
762 
763 	/* Parameters, borrows heavily from print_function_arguments in postgres */
764 	int numargs = get_func_arg_info(proctup, &argtypes, &argnames, &argmodes);
765 
766 	HeapTuple aggtup = SearchSysCache1(AGGFNOID, funcOid);
767 	if (!HeapTupleIsValid(aggtup))
768 	{
769 		elog(ERROR, "cache lookup failed for %d", funcOid);
770 	}
771 	Form_pg_aggregate agg = (Form_pg_aggregate) GETSTRUCT(aggtup);
772 
773 	if (AGGKIND_IS_ORDERED_SET(agg->aggkind))
774 	{
775 		insertorderbyat = agg->aggnumdirectargs;
776 	}
777 
778 	for (i = 0; i < numargs; i++)
779 	{
780 		Oid argtype = argtypes[i];
781 		char *argname = argnames ? argnames[i] : NULL;
782 		char argmode = argmodes ? argmodes[i] : PROARGMODE_IN;
783 		const char *modename;
784 
785 		switch (argmode)
786 		{
787 			case PROARGMODE_IN:
788 			{
789 				modename = "";
790 				break;
791 			}
792 
793 			case PROARGMODE_VARIADIC:
794 			{
795 				modename = "VARIADIC ";
796 				break;
797 			}
798 
799 			default:
800 			{
801 				elog(ERROR, "unexpected parameter mode '%c'", argmode);
802 				modename = NULL;
803 				break;
804 			}
805 		}
806 
807 		inputargno++;       /* this is a 1-based counter */
808 		if (argsprinted == insertorderbyat)
809 		{
810 			appendStringInfoString(&buf, " ORDER BY ");
811 		}
812 		else if (argsprinted)
813 		{
814 			appendStringInfoString(&buf, ", ");
815 		}
816 
817 		appendStringInfoString(&buf, modename);
818 
819 		if (argname && argname[0])
820 		{
821 			appendStringInfo(&buf, "%s ", quote_identifier(argname));
822 		}
823 
824 		appendStringInfoString(&buf, format_type_be_qualified(argtype));
825 
826 		argsprinted++;
827 
828 		/* nasty hack: print the last arg twice for variadic ordered-set agg */
829 		if (argsprinted == insertorderbyat && i == numargs - 1)
830 		{
831 			i--;
832 		}
833 	}
834 
835 	appendStringInfo(&buf, ") (STYPE = %s,SFUNC = %s",
836 					 format_type_be_qualified(agg->aggtranstype),
837 					 quote_qualified_func_name(agg->aggtransfn));
838 
839 	if (agg->aggtransspace != 0)
840 	{
841 		appendStringInfo(&buf, ", SSPACE = %d", agg->aggtransspace);
842 	}
843 
844 	if (agg->aggfinalfn != InvalidOid)
845 	{
846 		const char *finalmodifystring = NULL;
847 		switch (agg->aggfinalmodify)
848 		{
849 			case AGGMODIFY_READ_ONLY:
850 			{
851 				finalmodifystring = "READ_ONLY";
852 				break;
853 			}
854 
855 			case AGGMODIFY_SHAREABLE:
856 			{
857 				finalmodifystring = "SHAREABLE";
858 				break;
859 			}
860 
861 			case AGGMODIFY_READ_WRITE:
862 			{
863 				finalmodifystring = "READ_WRITE";
864 				break;
865 			}
866 		}
867 
868 		appendStringInfo(&buf, ", FINALFUNC = %s",
869 						 quote_qualified_func_name(agg->aggfinalfn));
870 
871 		if (finalmodifystring != NULL)
872 		{
873 			appendStringInfo(&buf, ", FINALFUNC_MODIFY = %s", finalmodifystring);
874 		}
875 
876 		if (agg->aggfinalextra)
877 		{
878 			appendStringInfoString(&buf, ", FINALFUNC_EXTRA");
879 		}
880 	}
881 
882 	if (agg->aggmtransspace != 0)
883 	{
884 		appendStringInfo(&buf, ", MSSPACE = %d", agg->aggmtransspace);
885 	}
886 
887 	if (agg->aggmfinalfn)
888 	{
889 		const char *mfinalmodifystring = NULL;
890 		switch (agg->aggfinalmodify)
891 		{
892 			case AGGMODIFY_READ_ONLY:
893 			{
894 				mfinalmodifystring = "READ_ONLY";
895 				break;
896 			}
897 
898 			case AGGMODIFY_SHAREABLE:
899 			{
900 				mfinalmodifystring = "SHAREABLE";
901 				break;
902 			}
903 
904 			case AGGMODIFY_READ_WRITE:
905 			{
906 				mfinalmodifystring = "READ_WRITE";
907 				break;
908 			}
909 		}
910 
911 		appendStringInfo(&buf, ", MFINALFUNC = %s",
912 						 quote_qualified_func_name(agg->aggmfinalfn));
913 
914 		if (mfinalmodifystring != NULL)
915 		{
916 			appendStringInfo(&buf, ", MFINALFUNC_MODIFY = %s", mfinalmodifystring);
917 		}
918 
919 		if (agg->aggmfinalextra)
920 		{
921 			appendStringInfoString(&buf, ", MFINALFUNC_EXTRA");
922 		}
923 	}
924 
925 	if (agg->aggmtransfn)
926 	{
927 		appendStringInfo(&buf, ", MSFUNC = %s",
928 						 quote_qualified_func_name(agg->aggmtransfn));
929 
930 		if (agg->aggmtranstype)
931 		{
932 			appendStringInfo(&buf, ", MSTYPE = %s",
933 							 format_type_be_qualified(agg->aggmtranstype));
934 		}
935 	}
936 
937 	if (agg->aggtransspace != 0)
938 	{
939 		appendStringInfo(&buf, ", SSPACE = %d", agg->aggtransspace);
940 	}
941 
942 	if (agg->aggminvtransfn)
943 	{
944 		appendStringInfo(&buf, ", MINVFUNC = %s",
945 						 quote_qualified_func_name(agg->aggminvtransfn));
946 	}
947 
948 	if (agg->aggcombinefn)
949 	{
950 		appendStringInfo(&buf, ", COMBINEFUNC = %s",
951 						 quote_qualified_func_name(agg->aggcombinefn));
952 	}
953 
954 	if (agg->aggserialfn)
955 	{
956 		appendStringInfo(&buf, ", SERIALFUNC = %s",
957 						 quote_qualified_func_name(agg->aggserialfn));
958 	}
959 
960 	if (agg->aggdeserialfn)
961 	{
962 		appendStringInfo(&buf, ", DESERIALFUNC = %s",
963 						 quote_qualified_func_name(agg->aggdeserialfn));
964 	}
965 
966 	if (agg->aggsortop != InvalidOid)
967 	{
968 		appendStringInfo(&buf, ", SORTOP = %s",
969 						 generate_operator_name(agg->aggsortop, argtypes[0],
970 												argtypes[0]));
971 	}
972 
973 	{
974 		const char *parallelstring = NULL;
975 		switch (proc->proparallel)
976 		{
977 			case PROPARALLEL_SAFE:
978 			{
979 				parallelstring = "SAFE";
980 				break;
981 			}
982 
983 			case PROPARALLEL_RESTRICTED:
984 			{
985 				parallelstring = "RESTRICTED";
986 				break;
987 			}
988 
989 			case PROPARALLEL_UNSAFE:
990 			{
991 				break;
992 			}
993 
994 			default:
995 			{
996 				elog(WARNING, "Unknown parallel option, ignoring: %c", proc->proparallel);
997 				break;
998 			}
999 		}
1000 
1001 		if (parallelstring != NULL)
1002 		{
1003 			appendStringInfo(&buf, ", PARALLEL = %s", parallelstring);
1004 		}
1005 	}
1006 
1007 	{
1008 		bool isNull = false;
1009 		Datum textInitVal = SysCacheGetAttr(AGGFNOID, aggtup,
1010 											Anum_pg_aggregate_agginitval,
1011 											&isNull);
1012 		if (!isNull)
1013 		{
1014 			char *strInitVal = TextDatumGetCString(textInitVal);
1015 			char *strInitValQuoted = quote_literal_cstr(strInitVal);
1016 
1017 			appendStringInfo(&buf, ", INITCOND = %s", strInitValQuoted);
1018 
1019 			pfree(strInitValQuoted);
1020 			pfree(strInitVal);
1021 		}
1022 	}
1023 
1024 	{
1025 		bool isNull = false;
1026 		Datum textInitVal = SysCacheGetAttr(AGGFNOID, aggtup,
1027 											Anum_pg_aggregate_aggminitval,
1028 											&isNull);
1029 		if (!isNull)
1030 		{
1031 			char *strInitVal = TextDatumGetCString(textInitVal);
1032 			char *strInitValQuoted = quote_literal_cstr(strInitVal);
1033 
1034 			appendStringInfo(&buf, ", MINITCOND = %s", strInitValQuoted);
1035 
1036 			pfree(strInitValQuoted);
1037 			pfree(strInitVal);
1038 		}
1039 	}
1040 
1041 	if (agg->aggkind == AGGKIND_HYPOTHETICAL)
1042 	{
1043 		appendStringInfoString(&buf, ", HYPOTHETICAL");
1044 	}
1045 
1046 	appendStringInfoChar(&buf, ')');
1047 
1048 	ReleaseSysCache(aggtup);
1049 	ReleaseSysCache(proctup);
1050 
1051 	return buf.data;
1052 }
1053 
1054 
1055 /*
1056  * EnsureSequentialModeForFunctionDDL makes sure that the current transaction is already in
1057  * sequential mode, or can still safely be put in sequential mode, it errors if that is
1058  * not possible. The error contains information for the user to retry the transaction with
1059  * sequential mode set from the beginning.
1060  *
1061  * As functions are node scoped objects there exists only 1 instance of the function used by
1062  * potentially multiple shards. To make sure all shards in the transaction can interact
1063  * with the function the function needs to be visible on all connections used by the transaction,
1064  * meaning we can only use 1 connection per node.
1065  */
1066 static void
EnsureSequentialModeForFunctionDDL(void)1067 EnsureSequentialModeForFunctionDDL(void)
1068 {
1069 	if (ParallelQueryExecutedInTransaction())
1070 	{
1071 		ereport(ERROR, (errmsg("cannot create function because there was a "
1072 							   "parallel operation on a distributed table in the "
1073 							   "transaction"),
1074 						errdetail("When creating a distributed function, Citus needs to "
1075 								  "perform all operations over a single connection per "
1076 								  "node to ensure consistency."),
1077 						errhint("Try re-running the transaction with "
1078 								"\"SET LOCAL citus.multi_shard_modify_mode TO "
1079 								"\'sequential\';\"")));
1080 	}
1081 
1082 	ereport(DEBUG1, (errmsg("switching to sequential query execution mode"),
1083 					 errdetail(
1084 						 "A distributed function is created. To make sure subsequent "
1085 						 "commands see the type correctly we need to make sure to "
1086 						 "use only one connection for all future commands")));
1087 	SetLocalMultiShardModifyModeToSequential();
1088 }
1089 
1090 
1091 /*
1092  * TriggerSyncMetadataToPrimaryNodes iterates over the active primary nodes,
1093  * and triggers the metadata syncs if the node has not the metadata. Later,
1094  * maintenance daemon will sync the metadata to nodes.
1095  */
1096 static void
TriggerSyncMetadataToPrimaryNodes(void)1097 TriggerSyncMetadataToPrimaryNodes(void)
1098 {
1099 	List *workerList = ActivePrimaryNonCoordinatorNodeList(ShareLock);
1100 	bool triggerMetadataSync = false;
1101 
1102 	WorkerNode *workerNode = NULL;
1103 	foreach_ptr(workerNode, workerList)
1104 	{
1105 		/* if already has metadata, no need to do it again */
1106 		if (!workerNode->hasMetadata)
1107 		{
1108 			/*
1109 			 * Let the maintanince deamon do the hard work of syncing the metadata. We prefer
1110 			 * this because otherwise node activation might fail withing transaction blocks.
1111 			 */
1112 			LockRelationOid(DistNodeRelationId(), ExclusiveLock);
1113 			SetWorkerColumnLocalOnly(workerNode, Anum_pg_dist_node_hasmetadata,
1114 									 BoolGetDatum(true));
1115 
1116 			triggerMetadataSync = true;
1117 		}
1118 		else if (!workerNode->metadataSynced)
1119 		{
1120 			triggerMetadataSync = true;
1121 		}
1122 	}
1123 
1124 	/* let the maintanince deamon know about the metadata sync */
1125 	if (triggerMetadataSync)
1126 	{
1127 		TriggerMetadataSyncOnCommit();
1128 	}
1129 }
1130 
1131 
1132 /*
1133  * ShouldPropagateCreateFunction tests if we need to propagate a CREATE FUNCTION
1134  * statement. We only propagate replace's of distributed functions to keep the function on
1135  * the workers in sync with the one on the coordinator.
1136  */
1137 static bool
ShouldPropagateCreateFunction(CreateFunctionStmt * stmt)1138 ShouldPropagateCreateFunction(CreateFunctionStmt *stmt)
1139 {
1140 	if (creating_extension)
1141 	{
1142 		/*
1143 		 * extensions should be created separately on the workers, functions cascading
1144 		 * from an extension should therefore not be propagated.
1145 		 */
1146 		return false;
1147 	}
1148 
1149 	if (!EnableDependencyCreation)
1150 	{
1151 		/*
1152 		 * we are configured to disable object propagation, should not propagate anything
1153 		 */
1154 		return false;
1155 	}
1156 
1157 	if (!stmt->replace)
1158 	{
1159 		/*
1160 		 * Since we only care for a replace of distributed functions if the statement is
1161 		 * not a replace we are going to ignore.
1162 		 */
1163 		return false;
1164 	}
1165 
1166 	/*
1167 	 * Even though its a replace we should accept an non-existing function, it will just
1168 	 * not be distributed
1169 	 */
1170 	ObjectAddress address = GetObjectAddressFromParseTree((Node *) stmt, true);
1171 	if (!IsObjectDistributed(&address))
1172 	{
1173 		/* do not propagate alter function for non-distributed functions */
1174 		return false;
1175 	}
1176 
1177 	return true;
1178 }
1179 
1180 
1181 /*
1182  * ShouldPropagateAlterFunction returns, based on the address of a function, if alter
1183  * statements targeting the function should be propagated.
1184  */
1185 static bool
ShouldPropagateAlterFunction(const ObjectAddress * address)1186 ShouldPropagateAlterFunction(const ObjectAddress *address)
1187 {
1188 	if (creating_extension)
1189 	{
1190 		/*
1191 		 * extensions should be created separately on the workers, functions cascading
1192 		 * from an extension should therefore not be propagated.
1193 		 */
1194 		return false;
1195 	}
1196 
1197 	if (!EnableDependencyCreation)
1198 	{
1199 		/*
1200 		 * we are configured to disable object propagation, should not propagate anything
1201 		 */
1202 		return false;
1203 	}
1204 
1205 	if (!IsObjectDistributed(address))
1206 	{
1207 		/* do not propagate alter function for non-distributed functions */
1208 		return false;
1209 	}
1210 
1211 	return true;
1212 }
1213 
1214 
1215 /*
1216  * PreprocessCreateFunctionStmt is called during the planning phase for CREATE [OR REPLACE]
1217  * FUNCTION. We primarily care for the replace variant of this statement to keep
1218  * distributed functions in sync. We bail via a check on ShouldPropagateCreateFunction
1219  * which checks for the OR REPLACE modifier.
1220  *
1221  * Since we use pg_get_functiondef to get the ddl command we actually do not do any
1222  * planning here, instead we defer the plan creation to the processing step.
1223  *
1224  * Instead we do our basic housekeeping where we make sure we are on the coordinator and
1225  * can propagate the function in sequential mode.
1226  */
1227 List *
PreprocessCreateFunctionStmt(Node * node,const char * queryString,ProcessUtilityContext processUtilityContext)1228 PreprocessCreateFunctionStmt(Node *node, const char *queryString,
1229 							 ProcessUtilityContext processUtilityContext)
1230 {
1231 	CreateFunctionStmt *stmt = castNode(CreateFunctionStmt, node);
1232 
1233 	if (!ShouldPropagateCreateFunction(stmt))
1234 	{
1235 		return NIL;
1236 	}
1237 
1238 	EnsureCoordinator();
1239 
1240 	EnsureSequentialModeForFunctionDDL();
1241 
1242 	/*
1243 	 * ddl jobs will be generated during the Processing phase as we need the function to
1244 	 * be updated in the catalog to get its sql representation
1245 	 */
1246 	return NIL;
1247 }
1248 
1249 
1250 /*
1251  * PostprocessCreateFunctionStmt actually creates the plan we need to execute for function
1252  * propagation. This is the downside of using pg_get_functiondef to get the sql statement.
1253  *
1254  * Besides creating the plan we also make sure all (new) dependencies of the function are
1255  * created on all nodes.
1256  */
1257 List *
PostprocessCreateFunctionStmt(Node * node,const char * queryString)1258 PostprocessCreateFunctionStmt(Node *node, const char *queryString)
1259 {
1260 	CreateFunctionStmt *stmt = castNode(CreateFunctionStmt, node);
1261 
1262 	if (!ShouldPropagateCreateFunction(stmt))
1263 	{
1264 		return NIL;
1265 	}
1266 
1267 	ObjectAddress address = GetObjectAddressFromParseTree((Node *) stmt, false);
1268 	EnsureDependenciesExistOnAllNodes(&address);
1269 
1270 	List *commands = list_make4(DISABLE_DDL_PROPAGATION,
1271 								GetFunctionDDLCommand(address.objectId, true),
1272 								GetFunctionAlterOwnerCommand(address.objectId),
1273 								ENABLE_DDL_PROPAGATION);
1274 
1275 	return NodeDDLTaskList(NON_COORDINATOR_NODES, commands);
1276 }
1277 
1278 
1279 /*
1280  * CreateFunctionStmtObjectAddress returns the ObjectAddress for the subject of the
1281  * CREATE [OR REPLACE] FUNCTION statement. If missing_ok is false it will error with the
1282  * normal postgres error for unfound functions.
1283  */
1284 ObjectAddress
CreateFunctionStmtObjectAddress(Node * node,bool missing_ok)1285 CreateFunctionStmtObjectAddress(Node *node, bool missing_ok)
1286 {
1287 	CreateFunctionStmt *stmt = castNode(CreateFunctionStmt, node);
1288 	ObjectType objectType = OBJECT_FUNCTION;
1289 
1290 	if (stmt->is_procedure)
1291 	{
1292 		objectType = OBJECT_PROCEDURE;
1293 	}
1294 
1295 	ObjectWithArgs *objectWithArgs = makeNode(ObjectWithArgs);
1296 	objectWithArgs->objname = stmt->funcname;
1297 
1298 	FunctionParameter *funcParam = NULL;
1299 	foreach_ptr(funcParam, stmt->parameters)
1300 	{
1301 		objectWithArgs->objargs = lappend(objectWithArgs->objargs, funcParam->argType);
1302 	}
1303 
1304 	return FunctionToObjectAddress(objectType, objectWithArgs, missing_ok);
1305 }
1306 
1307 
1308 /*
1309  * DefineAggregateStmtObjectAddress finds the ObjectAddress for the composite type described
1310  * by the DefineStmtObjectAddress. If missing_ok is false this function throws an error if the
1311  * aggregate does not exist.
1312  *
1313  * objectId in the address can be invalid if missing_ok was set to true.
1314  */
1315 ObjectAddress
DefineAggregateStmtObjectAddress(Node * node,bool missing_ok)1316 DefineAggregateStmtObjectAddress(Node *node, bool missing_ok)
1317 {
1318 	DefineStmt *stmt = castNode(DefineStmt, node);
1319 
1320 	Assert(stmt->kind == OBJECT_AGGREGATE);
1321 
1322 	ObjectWithArgs *objectWithArgs = makeNode(ObjectWithArgs);
1323 	objectWithArgs->objname = stmt->defnames;
1324 
1325 	FunctionParameter *funcParam = NULL;
1326 	foreach_ptr(funcParam, linitial(stmt->args))
1327 	{
1328 		objectWithArgs->objargs = lappend(objectWithArgs->objargs, funcParam->argType);
1329 	}
1330 
1331 	return FunctionToObjectAddress(OBJECT_AGGREGATE, objectWithArgs, missing_ok);
1332 }
1333 
1334 
1335 /*
1336  * PreprocessAlterFunctionStmt is invoked for alter function statements with actions. Here we
1337  * plan the jobs to be executed on the workers for functions that have been distributed in
1338  * the cluster.
1339  */
1340 List *
PreprocessAlterFunctionStmt(Node * node,const char * queryString,ProcessUtilityContext processUtilityContext)1341 PreprocessAlterFunctionStmt(Node *node, const char *queryString,
1342 							ProcessUtilityContext processUtilityContext)
1343 {
1344 	AlterFunctionStmt *stmt = castNode(AlterFunctionStmt, node);
1345 	AssertObjectTypeIsFunctional(stmt->objtype);
1346 
1347 	ObjectAddress address = GetObjectAddressFromParseTree((Node *) stmt, false);
1348 	if (!ShouldPropagateAlterFunction(&address))
1349 	{
1350 		return NIL;
1351 	}
1352 
1353 	EnsureCoordinator();
1354 	ErrorIfUnsupportedAlterFunctionStmt(stmt);
1355 	EnsureSequentialModeForFunctionDDL();
1356 	QualifyTreeNode((Node *) stmt);
1357 	const char *sql = DeparseTreeNode((Node *) stmt);
1358 
1359 	List *commands = list_make3(DISABLE_DDL_PROPAGATION,
1360 								(void *) sql,
1361 								ENABLE_DDL_PROPAGATION);
1362 
1363 	return NodeDDLTaskList(NON_COORDINATOR_NODES, commands);
1364 }
1365 
1366 
1367 /*
1368  * PreprocessRenameFunctionStmt is called when the user is renaming a function. The invocation
1369  * happens before the statement is applied locally.
1370  *
1371  * As the function already exists we have access to the ObjectAddress, this is used to
1372  * check if it is distributed. If so the rename is executed on all the workers to keep the
1373  * types in sync across the cluster.
1374  */
1375 List *
PreprocessRenameFunctionStmt(Node * node,const char * queryString,ProcessUtilityContext processUtilityContext)1376 PreprocessRenameFunctionStmt(Node *node, const char *queryString,
1377 							 ProcessUtilityContext processUtilityContext)
1378 {
1379 	RenameStmt *stmt = castNode(RenameStmt, node);
1380 	AssertObjectTypeIsFunctional(stmt->renameType);
1381 
1382 	ObjectAddress address = GetObjectAddressFromParseTree((Node *) stmt, false);
1383 	if (!ShouldPropagateAlterFunction(&address))
1384 	{
1385 		return NIL;
1386 	}
1387 
1388 	EnsureCoordinator();
1389 	EnsureSequentialModeForFunctionDDL();
1390 	QualifyTreeNode((Node *) stmt);
1391 	const char *sql = DeparseTreeNode((Node *) stmt);
1392 
1393 	List *commands = list_make3(DISABLE_DDL_PROPAGATION,
1394 								(void *) sql,
1395 								ENABLE_DDL_PROPAGATION);
1396 
1397 	return NodeDDLTaskList(NON_COORDINATOR_NODES, commands);
1398 }
1399 
1400 
1401 /*
1402  * PreprocessAlterFunctionSchemaStmt is executed before the statement is applied to the local
1403  * postgres instance.
1404  *
1405  * In this stage we can prepare the commands that need to be run on all workers.
1406  */
1407 List *
PreprocessAlterFunctionSchemaStmt(Node * node,const char * queryString,ProcessUtilityContext processUtilityContext)1408 PreprocessAlterFunctionSchemaStmt(Node *node, const char *queryString,
1409 								  ProcessUtilityContext processUtilityContext)
1410 {
1411 	AlterObjectSchemaStmt *stmt = castNode(AlterObjectSchemaStmt, node);
1412 	AssertObjectTypeIsFunctional(stmt->objectType);
1413 
1414 	ObjectAddress address = GetObjectAddressFromParseTree((Node *) stmt, false);
1415 	if (!ShouldPropagateAlterFunction(&address))
1416 	{
1417 		return NIL;
1418 	}
1419 
1420 	EnsureCoordinator();
1421 	EnsureSequentialModeForFunctionDDL();
1422 	QualifyTreeNode((Node *) stmt);
1423 	const char *sql = DeparseTreeNode((Node *) stmt);
1424 
1425 	List *commands = list_make3(DISABLE_DDL_PROPAGATION,
1426 								(void *) sql,
1427 								ENABLE_DDL_PROPAGATION);
1428 
1429 	return NodeDDLTaskList(NON_COORDINATOR_NODES, commands);
1430 }
1431 
1432 
1433 /*
1434  * PreprocessAlterFunctionOwnerStmt is called for change of owner ship of functions before the owner
1435  * ship is changed on the local instance.
1436  *
1437  * If the function for which the owner is changed is distributed we execute the change on
1438  * all the workers to keep the type in sync across the cluster.
1439  */
1440 List *
PreprocessAlterFunctionOwnerStmt(Node * node,const char * queryString,ProcessUtilityContext processUtilityContext)1441 PreprocessAlterFunctionOwnerStmt(Node *node, const char *queryString,
1442 								 ProcessUtilityContext processUtilityContext)
1443 {
1444 	AlterOwnerStmt *stmt = castNode(AlterOwnerStmt, node);
1445 	AssertObjectTypeIsFunctional(stmt->objectType);
1446 
1447 	ObjectAddress address = GetObjectAddressFromParseTree((Node *) stmt, false);
1448 	if (!ShouldPropagateAlterFunction(&address))
1449 	{
1450 		return NIL;
1451 	}
1452 
1453 	EnsureCoordinator();
1454 	EnsureSequentialModeForFunctionDDL();
1455 	QualifyTreeNode((Node *) stmt);
1456 	const char *sql = DeparseTreeNode((Node *) stmt);
1457 
1458 	List *commands = list_make3(DISABLE_DDL_PROPAGATION,
1459 								(void *) sql,
1460 								ENABLE_DDL_PROPAGATION);
1461 
1462 	return NodeDDLTaskList(NON_COORDINATOR_NODES, commands);
1463 }
1464 
1465 
1466 /*
1467  * PreprocessDropFunctionStmt gets called during the planning phase of a DROP FUNCTION statement
1468  * and returns a list of DDLJob's that will drop any distributed functions from the
1469  * workers.
1470  *
1471  * The DropStmt could have multiple objects to drop, the list of objects will be filtered
1472  * to only keep the distributed functions for deletion on the workers. Non-distributed
1473  * functions will still be dropped locally but not on the workers.
1474  */
1475 List *
PreprocessDropFunctionStmt(Node * node,const char * queryString,ProcessUtilityContext processUtilityContext)1476 PreprocessDropFunctionStmt(Node *node, const char *queryString,
1477 						   ProcessUtilityContext processUtilityContext)
1478 {
1479 	DropStmt *stmt = castNode(DropStmt, node);
1480 	List *deletingObjectWithArgsList = stmt->objects;
1481 	List *distributedObjectWithArgsList = NIL;
1482 	List *distributedFunctionAddresses = NIL;
1483 
1484 	AssertObjectTypeIsFunctional(stmt->removeType);
1485 
1486 	if (creating_extension)
1487 	{
1488 		/*
1489 		 * extensions should be created separately on the workers, types cascading from an
1490 		 * extension should therefor not be propagated here.
1491 		 */
1492 		return NIL;
1493 	}
1494 
1495 	if (!EnableDependencyCreation)
1496 	{
1497 		/*
1498 		 * we are configured to disable object propagation, should not propagate anything
1499 		 */
1500 		return NIL;
1501 	}
1502 
1503 
1504 	/*
1505 	 * Our statements need to be fully qualified so we can drop them from the right schema
1506 	 * on the workers
1507 	 */
1508 	QualifyTreeNode((Node *) stmt);
1509 
1510 	/*
1511 	 * iterate over all functions to be dropped and filter to keep only distributed
1512 	 * functions.
1513 	 */
1514 	ObjectWithArgs *func = NULL;
1515 	foreach_ptr(func, deletingObjectWithArgsList)
1516 	{
1517 		ObjectAddress address = FunctionToObjectAddress(stmt->removeType, func,
1518 														stmt->missing_ok);
1519 
1520 		if (!IsObjectDistributed(&address))
1521 		{
1522 			continue;
1523 		}
1524 
1525 		/* collect information for all distributed functions */
1526 		ObjectAddress *addressp = palloc(sizeof(ObjectAddress));
1527 		*addressp = address;
1528 		distributedFunctionAddresses = lappend(distributedFunctionAddresses, addressp);
1529 		distributedObjectWithArgsList = lappend(distributedObjectWithArgsList, func);
1530 	}
1531 
1532 	if (list_length(distributedObjectWithArgsList) <= 0)
1533 	{
1534 		/* no distributed functions to drop */
1535 		return NIL;
1536 	}
1537 
1538 	/*
1539 	 * managing types can only be done on the coordinator if ddl propagation is on. when
1540 	 * it is off we will never get here. MX workers don't have a notion of distributed
1541 	 * types, so we block the call.
1542 	 */
1543 	EnsureCoordinator();
1544 	EnsureSequentialModeForFunctionDDL();
1545 
1546 	/* remove the entries for the distributed objects on dropping */
1547 	ObjectAddress *address = NULL;
1548 	foreach_ptr(address, distributedFunctionAddresses)
1549 	{
1550 		UnmarkObjectDistributed(address);
1551 	}
1552 
1553 	/*
1554 	 * Swap the list of objects before deparsing and restore the old list after. This
1555 	 * ensures we only have distributed functions in the deparsed drop statement.
1556 	 */
1557 	DropStmt *stmtCopy = copyObject(stmt);
1558 	stmtCopy->objects = distributedObjectWithArgsList;
1559 	const char *dropStmtSql = DeparseTreeNode((Node *) stmtCopy);
1560 
1561 	List *commands = list_make3(DISABLE_DDL_PROPAGATION,
1562 								(void *) dropStmtSql,
1563 								ENABLE_DDL_PROPAGATION);
1564 
1565 	return NodeDDLTaskList(NON_COORDINATOR_NODES, commands);
1566 }
1567 
1568 
1569 /*
1570  * PreprocessAlterFunctionDependsStmt is called during the planning phase of an
1571  * ALTER FUNCION ... DEPENDS ON EXTENSION ... statement. Since functions depending on
1572  * extensions are assumed to be Owned by an extension we assume the extension to keep the
1573  * function in sync.
1574  *
1575  * If we would allow users to create a dependency between a distributed function and an
1576  * extension our pruning logic for which objects to distribute as dependencies of other
1577  * objects will change significantly which could cause issues adding new workers. Hence we
1578  * don't allow this dependency to be created.
1579  */
1580 List *
PreprocessAlterFunctionDependsStmt(Node * node,const char * queryString,ProcessUtilityContext processUtilityContext)1581 PreprocessAlterFunctionDependsStmt(Node *node, const char *queryString,
1582 								   ProcessUtilityContext processUtilityContext)
1583 {
1584 	AlterObjectDependsStmt *stmt = castNode(AlterObjectDependsStmt, node);
1585 	AssertObjectTypeIsFunctional(stmt->objectType);
1586 
1587 	if (creating_extension)
1588 	{
1589 		/*
1590 		 * extensions should be created separately on the workers, types cascading from an
1591 		 * extension should therefor not be propagated here.
1592 		 */
1593 		return NIL;
1594 	}
1595 
1596 	if (!EnableDependencyCreation)
1597 	{
1598 		/*
1599 		 * we are configured to disable object propagation, should not propagate anything
1600 		 */
1601 		return NIL;
1602 	}
1603 
1604 	ObjectAddress address = GetObjectAddressFromParseTree((Node *) stmt, true);
1605 	if (!IsObjectDistributed(&address))
1606 	{
1607 		return NIL;
1608 	}
1609 
1610 	/*
1611 	 * Distributed objects should not start depending on an extension, this will break
1612 	 * the dependency resolving mechanism we use to replicate distributed objects to new
1613 	 * workers
1614 	 */
1615 
1616 	const char *functionName =
1617 		getObjectIdentity_compat(&address, /* missingOk: */ false);
1618 	ereport(ERROR, (errmsg("distrtibuted functions are not allowed to depend on an "
1619 						   "extension"),
1620 					errdetail("Function \"%s\" is already distributed. Functions from "
1621 							  "extensions are expected to be created on the workers by "
1622 							  "the extension they depend on.", functionName)));
1623 }
1624 
1625 
1626 /*
1627  * AlterFunctionDependsStmtObjectAddress resolves the ObjectAddress of the function that
1628  * is the subject of an ALTER FUNCTION ... DEPENS ON EXTENSION ... statement. If
1629  * missing_ok is set to false the lookup will raise an error.
1630  */
1631 ObjectAddress
AlterFunctionDependsStmtObjectAddress(Node * node,bool missing_ok)1632 AlterFunctionDependsStmtObjectAddress(Node *node, bool missing_ok)
1633 {
1634 	AlterObjectDependsStmt *stmt = castNode(AlterObjectDependsStmt, node);
1635 	AssertObjectTypeIsFunctional(stmt->objectType);
1636 
1637 	return FunctionToObjectAddress(stmt->objectType,
1638 								   castNode(ObjectWithArgs, stmt->object), missing_ok);
1639 }
1640 
1641 
1642 /*
1643  * PostprocessAlterFunctionSchemaStmt is executed after the change has been applied locally,
1644  * we can now use the new dependencies of the function to ensure all its dependencies
1645  * exist on the workers before we apply the commands remotely.
1646  */
1647 List *
PostprocessAlterFunctionSchemaStmt(Node * node,const char * queryString)1648 PostprocessAlterFunctionSchemaStmt(Node *node, const char *queryString)
1649 {
1650 	AlterObjectSchemaStmt *stmt = castNode(AlterObjectSchemaStmt, node);
1651 	AssertObjectTypeIsFunctional(stmt->objectType);
1652 
1653 	ObjectAddress address = GetObjectAddressFromParseTree((Node *) stmt, false);
1654 	if (!ShouldPropagateAlterFunction(&address))
1655 	{
1656 		return NIL;
1657 	}
1658 
1659 	/* dependencies have changed (schema) let's ensure they exist */
1660 	EnsureDependenciesExistOnAllNodes(&address);
1661 
1662 	return NIL;
1663 }
1664 
1665 
1666 /*
1667  * AlterFunctionStmtObjectAddress returns the ObjectAddress of the subject in the
1668  * AlterFunctionStmt. If missing_ok is set to false an error will be raised if postgres
1669  * was unable to find the function/procedure that was the target of the statement.
1670  */
1671 ObjectAddress
AlterFunctionStmtObjectAddress(Node * node,bool missing_ok)1672 AlterFunctionStmtObjectAddress(Node *node, bool missing_ok)
1673 {
1674 	AlterFunctionStmt *stmt = castNode(AlterFunctionStmt, node);
1675 	return FunctionToObjectAddress(stmt->objtype, stmt->func, missing_ok);
1676 }
1677 
1678 
1679 /*
1680  * RenameFunctionStmtObjectAddress returns the ObjectAddress of the function that is the
1681  * subject of the RenameStmt. Errors if missing_ok is false.
1682  */
1683 ObjectAddress
RenameFunctionStmtObjectAddress(Node * node,bool missing_ok)1684 RenameFunctionStmtObjectAddress(Node *node, bool missing_ok)
1685 {
1686 	RenameStmt *stmt = castNode(RenameStmt, node);
1687 	return FunctionToObjectAddress(stmt->renameType,
1688 								   castNode(ObjectWithArgs, stmt->object), missing_ok);
1689 }
1690 
1691 
1692 /*
1693  * AlterFunctionOwnerObjectAddress returns the ObjectAddress of the function that is the
1694  * subject of the AlterOwnerStmt. Errors if missing_ok is false.
1695  */
1696 ObjectAddress
AlterFunctionOwnerObjectAddress(Node * node,bool missing_ok)1697 AlterFunctionOwnerObjectAddress(Node *node, bool missing_ok)
1698 {
1699 	AlterOwnerStmt *stmt = castNode(AlterOwnerStmt, node);
1700 	return FunctionToObjectAddress(stmt->objectType,
1701 								   castNode(ObjectWithArgs, stmt->object), missing_ok);
1702 }
1703 
1704 
1705 /*
1706  * AlterFunctionSchemaStmtObjectAddress returns the ObjectAddress of the function that is
1707  * the subject of the AlterObjectSchemaStmt. Errors if missing_ok is false.
1708  *
1709  * This could be called both before or after it has been applied locally. It will look in
1710  * the old schema first, if the function cannot be found in that schema it will look in
1711  * the new schema. Errors if missing_ok is false and the type cannot be found in either of
1712  * the schemas.
1713  */
1714 ObjectAddress
AlterFunctionSchemaStmtObjectAddress(Node * node,bool missing_ok)1715 AlterFunctionSchemaStmtObjectAddress(Node *node, bool missing_ok)
1716 {
1717 	AlterObjectSchemaStmt *stmt = castNode(AlterObjectSchemaStmt, node);
1718 	AssertObjectTypeIsFunctional(stmt->objectType);
1719 
1720 	ObjectWithArgs *objectWithArgs = castNode(ObjectWithArgs, stmt->object);
1721 	Oid funcOid = LookupFuncWithArgs(stmt->objectType, objectWithArgs, true);
1722 	List *names = objectWithArgs->objname;
1723 
1724 	if (funcOid == InvalidOid)
1725 	{
1726 		/*
1727 		 * couldn't find the function, might have already been moved to the new schema, we
1728 		 * construct a new objname that uses the new schema to search in.
1729 		 */
1730 
1731 		/* the name of the function is the last in the list of names */
1732 		Value *funcNameStr = lfirst(list_tail(names));
1733 		List *newNames = list_make2(makeString(stmt->newschema), funcNameStr);
1734 
1735 		/*
1736 		 * we don't error here either, as the error would be not a good user facing
1737 		 * error if the type didn't exist in the first place.
1738 		 */
1739 		objectWithArgs->objname = newNames;
1740 		funcOid = LookupFuncWithArgs(stmt->objectType, objectWithArgs, true);
1741 		objectWithArgs->objname = names; /* restore the original names */
1742 
1743 		/*
1744 		 * if the function is still invalid we couldn't find the function, cause postgres
1745 		 * to error by preforming a lookup once more. Since we know the
1746 		 */
1747 		if (!missing_ok && funcOid == InvalidOid)
1748 		{
1749 			/*
1750 			 * this will most probably throw an error, unless for some reason the function
1751 			 * has just been created (if possible at all). For safety we assign the
1752 			 * funcOid.
1753 			 */
1754 			funcOid = LookupFuncWithArgs(stmt->objectType, objectWithArgs,
1755 										 missing_ok);
1756 		}
1757 	}
1758 
1759 	ObjectAddress address = { 0 };
1760 	ObjectAddressSet(address, ProcedureRelationId, funcOid);
1761 
1762 	return address;
1763 }
1764 
1765 
1766 /*
1767  * GenerateBackupNameForProcCollision generates a new proc name for an existing proc. The
1768  * name is generated in such a way that the new name doesn't overlap with an existing proc
1769  * by adding a suffix with incrementing number after the new name.
1770  */
1771 char *
GenerateBackupNameForProcCollision(const ObjectAddress * address)1772 GenerateBackupNameForProcCollision(const ObjectAddress *address)
1773 {
1774 	char *newName = palloc0(NAMEDATALEN);
1775 	char suffix[NAMEDATALEN] = { 0 };
1776 	int count = 0;
1777 	Value *namespace = makeString(get_namespace_name(get_func_namespace(
1778 														 address->objectId)));
1779 	char *baseName = get_func_name(address->objectId);
1780 	int baseLength = strlen(baseName);
1781 	Oid *argtypes = NULL;
1782 	char **argnames = NULL;
1783 	char *argmodes = NULL;
1784 	HeapTuple proctup = SearchSysCache1(PROCOID, address->objectId);
1785 
1786 	if (!HeapTupleIsValid(proctup))
1787 	{
1788 		elog(ERROR, "citus cache lookup failed.");
1789 	}
1790 
1791 	int numargs = get_func_arg_info(proctup, &argtypes, &argnames, &argmodes);
1792 	ReleaseSysCache(proctup);
1793 
1794 	while (true)
1795 	{
1796 		int suffixLength = SafeSnprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)",
1797 										count);
1798 
1799 		/* trim the base name at the end to leave space for the suffix and trailing \0 */
1800 		baseLength = Min(baseLength, NAMEDATALEN - suffixLength - 1);
1801 
1802 		/* clear newName before copying the potentially trimmed baseName and suffix */
1803 		memset(newName, 0, NAMEDATALEN);
1804 		strncpy_s(newName, NAMEDATALEN, baseName, baseLength);
1805 		strncpy_s(newName + baseLength, NAMEDATALEN - baseLength, suffix,
1806 				  suffixLength);
1807 
1808 		List *newProcName = list_make2(namespace, makeString(newName));
1809 
1810 		/* don't need to rename if the input arguments don't match */
1811 		FuncCandidateList clist = FuncnameGetCandidates_compat(newProcName, numargs, NIL,
1812 															   false, false, false, true);
1813 		for (; clist; clist = clist->next)
1814 		{
1815 			if (memcmp(clist->args, argtypes, sizeof(Oid) * numargs) == 0)
1816 			{
1817 				break;
1818 			}
1819 		}
1820 
1821 		if (!clist)
1822 		{
1823 			return newName;
1824 		}
1825 
1826 		count++;
1827 	}
1828 }
1829 
1830 
1831 /*
1832  * ObjectWithArgsFromOid returns the corresponding ObjectWithArgs node for a given pg_proc oid
1833  */
1834 ObjectWithArgs *
ObjectWithArgsFromOid(Oid funcOid)1835 ObjectWithArgsFromOid(Oid funcOid)
1836 {
1837 	ObjectWithArgs *objectWithArgs = makeNode(ObjectWithArgs);
1838 	List *objargs = NIL;
1839 	Oid *argTypes = NULL;
1840 	char **argNames = NULL;
1841 	char *argModes = NULL;
1842 	HeapTuple proctup = SearchSysCache1(PROCOID, funcOid);
1843 
1844 	if (!HeapTupleIsValid(proctup))
1845 	{
1846 		elog(ERROR, "citus cache lookup failed.");
1847 	}
1848 
1849 	int numargs = get_func_arg_info(proctup, &argTypes, &argNames, &argModes);
1850 
1851 	objectWithArgs->objname = list_make2(
1852 		makeString(get_namespace_name(get_func_namespace(funcOid))),
1853 		makeString(get_func_name(funcOid))
1854 		);
1855 
1856 	for (int i = 0; i < numargs; i++)
1857 	{
1858 		if (argModes == NULL ||
1859 			argModes[i] != PROARGMODE_OUT || argModes[i] != PROARGMODE_TABLE)
1860 		{
1861 			objargs = lappend(objargs, makeTypeNameFromOid(argTypes[i], -1));
1862 		}
1863 	}
1864 	objectWithArgs->objargs = objargs;
1865 
1866 	ReleaseSysCache(proctup);
1867 
1868 	return objectWithArgs;
1869 }
1870 
1871 
1872 /*
1873  * FunctionToObjectAddress returns the ObjectAddress of a Function or Procedure based on
1874  * its type and ObjectWithArgs describing the Function/Procedure. If missing_ok is set to
1875  * false an error will be raised by postgres explaining the Function/Procedure could not
1876  * be found.
1877  */
1878 static ObjectAddress
FunctionToObjectAddress(ObjectType objectType,ObjectWithArgs * objectWithArgs,bool missing_ok)1879 FunctionToObjectAddress(ObjectType objectType, ObjectWithArgs *objectWithArgs,
1880 						bool missing_ok)
1881 {
1882 	AssertObjectTypeIsFunctional(objectType);
1883 
1884 	Oid funcOid = LookupFuncWithArgs(objectType, objectWithArgs, missing_ok);
1885 	ObjectAddress address = { 0 };
1886 	ObjectAddressSet(address, ProcedureRelationId, funcOid);
1887 
1888 	return address;
1889 }
1890 
1891 
1892 /*
1893  * ErrorIfUnsupportedAlterFunctionStmt raises an error if the AlterFunctionStmt contains a
1894  * construct that is not supported to be altered on a distributed function. It is assumed
1895  * the statement passed in is already tested to be targeting a distributed function, and
1896  * will only execute the checks to error on unsupported constructs.
1897  *
1898  * Unsupported Constructs:
1899  *  - ALTER FUNCTION ... SET ... FROM CURRENT
1900  */
1901 static void
ErrorIfUnsupportedAlterFunctionStmt(AlterFunctionStmt * stmt)1902 ErrorIfUnsupportedAlterFunctionStmt(AlterFunctionStmt *stmt)
1903 {
1904 	DefElem *action = NULL;
1905 	foreach_ptr(action, stmt->actions)
1906 	{
1907 		if (strcmp(action->defname, "set") == 0)
1908 		{
1909 			VariableSetStmt *setStmt = castNode(VariableSetStmt, action->arg);
1910 			if (setStmt->kind == VAR_SET_CURRENT)
1911 			{
1912 				/* check if the set action is a SET ... FROM CURRENT */
1913 				ereport(ERROR, (errmsg("unsupported ALTER FUNCTION ... SET ... FROM "
1914 									   "CURRENT for a distributed function"),
1915 								errhint("SET FROM CURRENT is not supported for "
1916 										"distributed functions, instead use the SET ... "
1917 										"TO ... syntax with a constant value.")));
1918 			}
1919 		}
1920 	}
1921 }
1922 
1923 
1924 /*
1925  * ErrorIfFunctionDependsOnExtension functions depending on extensions should raise an
1926  * error informing the user why they can't be distributed.
1927  */
1928 static void
ErrorIfFunctionDependsOnExtension(const ObjectAddress * functionAddress)1929 ErrorIfFunctionDependsOnExtension(const ObjectAddress *functionAddress)
1930 {
1931 	/* captures the extension address during lookup */
1932 	ObjectAddress extensionAddress = { 0 };
1933 
1934 	if (IsObjectAddressOwnedByExtension(functionAddress, &extensionAddress))
1935 	{
1936 		char *functionName =
1937 			getObjectIdentity_compat(functionAddress, /* missingOk: */ false);
1938 		char *extensionName =
1939 			getObjectIdentity_compat(&extensionAddress, /* missingOk: */ false);
1940 		ereport(ERROR, (errmsg("unable to create a distributed function from functions "
1941 							   "owned by an extension"),
1942 						errdetail("Function \"%s\" has a dependency on extension \"%s\". "
1943 								  "Functions depending on an extension cannot be "
1944 								  "distributed. Create the function by creating the "
1945 								  "extension on the workers.", functionName,
1946 								  extensionName)));
1947 	}
1948 }
1949 
1950 
1951 /* returns the quoted qualified name of a given function oid */
1952 static char *
quote_qualified_func_name(Oid funcOid)1953 quote_qualified_func_name(Oid funcOid)
1954 {
1955 	return quote_qualified_identifier(
1956 		get_namespace_name(get_func_namespace(funcOid)),
1957 		get_func_name(funcOid));
1958 }
1959