1 /*-------------------------------------------------------------------------
2  * trigger.c
3  *
4  * This file contains functions to create and process trigger objects on
5  * citus tables.
6  *
7  * Copyright (c) Citus Data, Inc.
8  *
9  *-------------------------------------------------------------------------
10  */
11 #include "postgres.h"
12 #include "distributed/pg_version_constants.h"
13 
14 #include "access/genam.h"
15 #include "access/table.h"
16 #include "catalog/indexing.h"
17 #include "catalog/namespace.h"
18 #include "catalog/pg_trigger.h"
19 #include "commands/trigger.h"
20 #include "distributed/citus_ruleutils.h"
21 #include "distributed/commands.h"
22 #include "distributed/commands/utility_hook.h"
23 #include "distributed/coordinator_protocol.h"
24 #include "distributed/deparser.h"
25 #include "distributed/listutils.h"
26 #include "distributed/metadata_cache.h"
27 #include "distributed/namespace_utils.h"
28 #include "distributed/shard_utils.h"
29 #include "distributed/worker_protocol.h"
30 #include "utils/fmgroids.h"
31 #include "utils/lsyscache.h"
32 #include "utils/syscache.h"
33 
34 
/*
 * Lock modes taken on the relation that owns the trigger by the utility
 * hooks in this file, one per kind of trigger command. They are meant to
 * match the lock levels postgres itself uses for CREATE TRIGGER,
 * ALTER TRIGGER and DROP TRIGGER respectively.
 */
#define CREATE_TRIGGER_LOCK_MODE ShareRowExclusiveLock
#define ALTER_TRIGGER_LOCK_MODE AccessExclusiveLock
#define DROP_TRIGGER_LOCK_MODE AccessExclusiveLock


/* local (file-scope) function forward declarations */
static bool IsCreateCitusTruncateTriggerStmt(CreateTrigStmt *createTriggerStmt);
static Value * GetAlterTriggerDependsTriggerNameValue(AlterObjectDependsStmt *
													  alterTriggerDependsStmt);
static void ErrorIfUnsupportedDropTriggerCommand(DropStmt *dropTriggerStmt);
static RangeVar * GetDropTriggerStmtRelation(DropStmt *dropTriggerStmt);
static void ExtractDropStmtTriggerAndRelationName(DropStmt *dropTriggerStmt,
												  char **triggerName,
												  char **relationName);
static void ErrorIfDropStmtDropsMultipleTriggers(DropStmt *dropTriggerStmt);
static int16 GetTriggerTypeById(Oid triggerId);
52 
53 
54 /*
55  * GetExplicitTriggerCommandList returns the list of DDL commands to create
56  * triggers that are explicitly created for the table with relationId. See
57  * comment of GetExplicitTriggerIdList function.
58  */
59 List *
GetExplicitTriggerCommandList(Oid relationId)60 GetExplicitTriggerCommandList(Oid relationId)
61 {
62 	List *createTriggerCommandList = NIL;
63 
64 	PushOverrideEmptySearchPath(CurrentMemoryContext);
65 
66 	List *triggerIdList = GetExplicitTriggerIdList(relationId);
67 
68 	Oid triggerId = InvalidOid;
69 	foreach_oid(triggerId, triggerIdList)
70 	{
71 		char *createTriggerCommand = pg_get_triggerdef_command(triggerId);
72 
73 		createTriggerCommandList = lappend(
74 			createTriggerCommandList,
75 			makeTableDDLCommandString(createTriggerCommand));
76 	}
77 
78 	/* revert back to original search_path */
79 	PopOverrideSearchPath();
80 
81 	return createTriggerCommandList;
82 }
83 
84 
85 /*
86  * GetTriggerTupleById returns copy of the heap tuple from pg_trigger for
87  * the trigger with triggerId. If no such trigger exists, this function returns
88  * NULL or errors out depending on missingOk.
89  */
90 HeapTuple
GetTriggerTupleById(Oid triggerId,bool missingOk)91 GetTriggerTupleById(Oid triggerId, bool missingOk)
92 {
93 	Relation pgTrigger = table_open(TriggerRelationId, AccessShareLock);
94 
95 	int scanKeyCount = 1;
96 	ScanKeyData scanKey[1];
97 
98 	AttrNumber attrNumber = Anum_pg_trigger_oid;
99 
100 	ScanKeyInit(&scanKey[0], attrNumber, BTEqualStrategyNumber,
101 				F_OIDEQ, ObjectIdGetDatum(triggerId));
102 
103 	bool useIndex = true;
104 	SysScanDesc scanDescriptor = systable_beginscan(pgTrigger, TriggerOidIndexId,
105 													useIndex, NULL, scanKeyCount,
106 													scanKey);
107 
108 	HeapTuple targetHeapTuple = NULL;
109 
110 	HeapTuple heapTuple = systable_getnext(scanDescriptor);
111 	if (HeapTupleIsValid(heapTuple))
112 	{
113 		targetHeapTuple = heap_copytuple(heapTuple);
114 	}
115 
116 	systable_endscan(scanDescriptor);
117 	table_close(pgTrigger, NoLock);
118 
119 	if (targetHeapTuple == NULL && missingOk == false)
120 	{
121 		ereport(ERROR, (errmsg("could not find heap tuple for trigger with "
122 							   "OID %d", triggerId)));
123 	}
124 
125 	return targetHeapTuple;
126 }
127 
128 
129 /*
130  * GetExplicitTriggerIdList returns a list of OIDs corresponding to the triggers
131  * that are explicitly created on the relation with relationId. That means,
132  * this function discards internal triggers implicitly created by postgres for
133  * foreign key constraint validation and the citus_truncate_trigger.
134  */
135 List *
GetExplicitTriggerIdList(Oid relationId)136 GetExplicitTriggerIdList(Oid relationId)
137 {
138 	List *triggerIdList = NIL;
139 
140 	Relation pgTrigger = table_open(TriggerRelationId, AccessShareLock);
141 
142 	int scanKeyCount = 1;
143 	ScanKeyData scanKey[1];
144 
145 	ScanKeyInit(&scanKey[0], Anum_pg_trigger_tgrelid,
146 				BTEqualStrategyNumber, F_OIDEQ, relationId);
147 
148 	bool useIndex = true;
149 	SysScanDesc scanDescriptor = systable_beginscan(pgTrigger, TriggerRelidNameIndexId,
150 													useIndex, NULL, scanKeyCount,
151 													scanKey);
152 
153 	HeapTuple heapTuple = systable_getnext(scanDescriptor);
154 	while (HeapTupleIsValid(heapTuple))
155 	{
156 		Form_pg_trigger triggerForm = (Form_pg_trigger) GETSTRUCT(heapTuple);
157 
158 		/*
159 		 * Note that we mark truncate trigger that we create on citus tables as
160 		 * internal. Hence, below we discard citus_truncate_trigger as well as
161 		 * the implicit triggers created by postgres for foreign key validation.
162 		 */
163 		if (!triggerForm->tgisinternal)
164 		{
165 			Oid triggerId = get_relation_trigger_oid_compat(heapTuple);
166 			triggerIdList = lappend_oid(triggerIdList, triggerId);
167 		}
168 
169 		heapTuple = systable_getnext(scanDescriptor);
170 	}
171 
172 	systable_endscan(scanDescriptor);
173 	table_close(pgTrigger, NoLock);
174 
175 	return triggerIdList;
176 }
177 
178 
179 /*
180  * get_relation_trigger_oid_compat returns OID of the trigger represented
181  * by the constraintForm, which is passed as an heapTuple. OID of the
182  * trigger is already stored in the triggerForm struct if major PostgreSQL
183  * version is 12. However, in the older versions, we should utilize
184  * HeapTupleGetOid to deduce that OID with no cost.
185  */
186 Oid
get_relation_trigger_oid_compat(HeapTuple heapTuple)187 get_relation_trigger_oid_compat(HeapTuple heapTuple)
188 {
189 	Assert(HeapTupleIsValid(heapTuple));
190 
191 
192 	Form_pg_trigger triggerForm = (Form_pg_trigger) GETSTRUCT(heapTuple);
193 	Oid triggerOid = triggerForm->oid;
194 
195 	return triggerOid;
196 }
197 
198 
199 /*
200  * PostprocessCreateTriggerStmt is called after a CREATE TRIGGER command has
201  * been executed by standard process utility. This function errors out for
202  * unsupported commands or creates ddl job for supported CREATE TRIGGER commands.
203  */
204 List *
PostprocessCreateTriggerStmt(Node * node,const char * queryString)205 PostprocessCreateTriggerStmt(Node *node, const char *queryString)
206 {
207 	CreateTrigStmt *createTriggerStmt = castNode(CreateTrigStmt, node);
208 	if (IsCreateCitusTruncateTriggerStmt(createTriggerStmt))
209 	{
210 		return NIL;
211 	}
212 
213 	RangeVar *relation = createTriggerStmt->relation;
214 	bool missingOk = false;
215 	Oid relationId = RangeVarGetRelid(relation, CREATE_TRIGGER_LOCK_MODE, missingOk);
216 
217 	if (!IsCitusTable(relationId))
218 	{
219 		return NIL;
220 	}
221 
222 	EnsureCoordinator();
223 
224 	ErrorOutForTriggerIfNotCitusLocalTable(relationId);
225 
226 	if (IsCitusTableType(relationId, CITUS_LOCAL_TABLE))
227 	{
228 		ObjectAddress objectAddress = GetObjectAddressFromParseTree(node, missingOk);
229 		EnsureDependenciesExistOnAllNodes(&objectAddress);
230 
231 		char *triggerName = createTriggerStmt->trigname;
232 		return CitusLocalTableTriggerCommandDDLJob(relationId, triggerName,
233 												   queryString);
234 	}
235 
236 	return NIL;
237 }
238 
239 
240 /*
241  * CreateTriggerStmtObjectAddress finds the ObjectAddress for the trigger that
242  * is created by given CreateTriggerStmt. If missingOk is false and if trigger
243  * does not exist, then it errors out.
244  *
245  * Never returns NULL, but the objid in the address can be invalid if missingOk
246  * was set to true.
247  */
248 ObjectAddress
CreateTriggerStmtObjectAddress(Node * node,bool missingOk)249 CreateTriggerStmtObjectAddress(Node *node, bool missingOk)
250 {
251 	CreateTrigStmt *createTriggerStmt = castNode(CreateTrigStmt, node);
252 
253 	RangeVar *relation = createTriggerStmt->relation;
254 	Oid relationId = RangeVarGetRelid(relation, CREATE_TRIGGER_LOCK_MODE, missingOk);
255 
256 	char *triggerName = createTriggerStmt->trigname;
257 	Oid triggerId = get_trigger_oid(relationId, triggerName, missingOk);
258 
259 	if (triggerId == InvalidOid && missingOk == false)
260 	{
261 		char *relationName = get_rel_name(relationId);
262 		ereport(ERROR, (errcode(ERRCODE_UNDEFINED_OBJECT),
263 						errmsg("trigger \"%s\" on relation \"%s\" does not exist",
264 							   triggerName, relationName)));
265 	}
266 
267 	ObjectAddress address = { 0 };
268 	ObjectAddressSet(address, TriggerRelationId, triggerId);
269 	return address;
270 }
271 
272 
273 /*
274  * IsCreateCitusTruncateTriggerStmt returns true if given createTriggerStmt
275  * creates citus_truncate_trigger.
276  */
277 static bool
IsCreateCitusTruncateTriggerStmt(CreateTrigStmt * createTriggerStmt)278 IsCreateCitusTruncateTriggerStmt(CreateTrigStmt *createTriggerStmt)
279 {
280 	List *functionNameList = createTriggerStmt->funcname;
281 	RangeVar *functionRangeVar = makeRangeVarFromNameList(functionNameList);
282 	char *functionName = functionRangeVar->relname;
283 	if (strncmp(functionName, CITUS_TRUNCATE_TRIGGER_NAME, NAMEDATALEN) == 0)
284 	{
285 		return true;
286 	}
287 
288 	return false;
289 }
290 
291 
292 /*
293  * CreateTriggerEventExtendNames extends relation name and trigger name with
294  * shardId, and sets schema name in given CreateTrigStmt.
295  */
296 void
CreateTriggerEventExtendNames(CreateTrigStmt * createTriggerStmt,char * schemaName,uint64 shardId)297 CreateTriggerEventExtendNames(CreateTrigStmt *createTriggerStmt, char *schemaName,
298 							  uint64 shardId)
299 {
300 	RangeVar *relation = createTriggerStmt->relation;
301 
302 	char **relationName = &(relation->relname);
303 	AppendShardIdToName(relationName, shardId);
304 
305 	char **triggerName = &(createTriggerStmt->trigname);
306 	AppendShardIdToName(triggerName, shardId);
307 
308 	char **relationSchemaName = &(relation->schemaname);
309 	SetSchemaNameIfNotExist(relationSchemaName, schemaName);
310 }
311 
312 
313 /*
314  * PostprocessAlterTriggerRenameStmt is called after a ALTER TRIGGER RENAME
315  * command has been executed by standard process utility. This function errors
316  * out for unsupported commands or creates ddl job for supported ALTER TRIGGER
317  * RENAME commands.
318  */
319 List *
PostprocessAlterTriggerRenameStmt(Node * node,const char * queryString)320 PostprocessAlterTriggerRenameStmt(Node *node, const char *queryString)
321 {
322 	RenameStmt *renameTriggerStmt = castNode(RenameStmt, node);
323 	Assert(renameTriggerStmt->renameType == OBJECT_TRIGGER);
324 
325 	RangeVar *relation = renameTriggerStmt->relation;
326 
327 	bool missingOk = false;
328 	Oid relationId = RangeVarGetRelid(relation, ALTER_TRIGGER_LOCK_MODE, missingOk);
329 
330 	if (!IsCitusTable(relationId))
331 	{
332 		return NIL;
333 	}
334 
335 	EnsureCoordinator();
336 	ErrorOutForTriggerIfNotCitusLocalTable(relationId);
337 
338 	if (IsCitusTableType(relationId, CITUS_LOCAL_TABLE))
339 	{
340 		/* use newname as standard process utility already renamed it */
341 		char *triggerName = renameTriggerStmt->newname;
342 		return CitusLocalTableTriggerCommandDDLJob(relationId, triggerName,
343 												   queryString);
344 	}
345 
346 	return NIL;
347 }
348 
349 
350 /*
351  * AlterTriggerRenameEventExtendNames extends relation name, old and new trigger
352  * name with shardId, and sets schema name in given RenameStmt.
353  */
354 void
AlterTriggerRenameEventExtendNames(RenameStmt * renameTriggerStmt,char * schemaName,uint64 shardId)355 AlterTriggerRenameEventExtendNames(RenameStmt *renameTriggerStmt, char *schemaName,
356 								   uint64 shardId)
357 {
358 	Assert(renameTriggerStmt->renameType == OBJECT_TRIGGER);
359 
360 	RangeVar *relation = renameTriggerStmt->relation;
361 
362 	char **relationName = &(relation->relname);
363 	AppendShardIdToName(relationName, shardId);
364 
365 	char **triggerOldName = &(renameTriggerStmt->subname);
366 	AppendShardIdToName(triggerOldName, shardId);
367 
368 	char **triggerNewName = &(renameTriggerStmt->newname);
369 	AppendShardIdToName(triggerNewName, shardId);
370 
371 	char **relationSchemaName = &(relation->schemaname);
372 	SetSchemaNameIfNotExist(relationSchemaName, schemaName);
373 }
374 
375 
376 /*
377  * PostprocessAlterTriggerDependsStmt is called after a ALTER TRIGGER DEPENDS ON
378  * command has been executed by standard process utility. This function errors out
379  * for unsupported commands or creates ddl job for supported ALTER TRIGGER DEPENDS
380  * ON commands.
381  */
382 List *
PostprocessAlterTriggerDependsStmt(Node * node,const char * queryString)383 PostprocessAlterTriggerDependsStmt(Node *node, const char *queryString)
384 {
385 	AlterObjectDependsStmt *alterTriggerDependsStmt =
386 		castNode(AlterObjectDependsStmt, node);
387 	Assert(alterTriggerDependsStmt->objectType == OBJECT_TRIGGER);
388 
389 	RangeVar *relation = alterTriggerDependsStmt->relation;
390 
391 	bool missingOk = false;
392 	Oid relationId = RangeVarGetRelid(relation, ALTER_TRIGGER_LOCK_MODE, missingOk);
393 
394 	if (!IsCitusTable(relationId))
395 	{
396 		return NIL;
397 	}
398 
399 	EnsureCoordinator();
400 	ErrorOutForTriggerIfNotCitusLocalTable(relationId);
401 
402 	if (IsCitusTableType(relationId, CITUS_LOCAL_TABLE))
403 	{
404 		Value *triggerNameValue =
405 			GetAlterTriggerDependsTriggerNameValue(alterTriggerDependsStmt);
406 		return CitusLocalTableTriggerCommandDDLJob(relationId, strVal(triggerNameValue),
407 												   queryString);
408 	}
409 
410 	return NIL;
411 }
412 
413 
414 /*
415  * AlterTriggerDependsEventExtendNames extends relation name and trigger name
416  * with shardId, and sets schema name in given AlterObjectDependsStmt.
417  */
418 void
AlterTriggerDependsEventExtendNames(AlterObjectDependsStmt * alterTriggerDependsStmt,char * schemaName,uint64 shardId)419 AlterTriggerDependsEventExtendNames(AlterObjectDependsStmt *alterTriggerDependsStmt,
420 									char *schemaName, uint64 shardId)
421 {
422 	Assert(alterTriggerDependsStmt->objectType == OBJECT_TRIGGER);
423 
424 	RangeVar *relation = alterTriggerDependsStmt->relation;
425 
426 	char **relationName = &(relation->relname);
427 	AppendShardIdToName(relationName, shardId);
428 
429 	Value *triggerNameValue =
430 		GetAlterTriggerDependsTriggerNameValue(alterTriggerDependsStmt);
431 	AppendShardIdToName(&strVal(triggerNameValue), shardId);
432 
433 	char **relationSchemaName = &(relation->schemaname);
434 	SetSchemaNameIfNotExist(relationSchemaName, schemaName);
435 }
436 
437 
438 /*
439  * GetAlterTriggerDependsTriggerName returns Value object for the trigger
440  * name that given AlterObjectDependsStmt is executed for.
441  */
442 static Value *
GetAlterTriggerDependsTriggerNameValue(AlterObjectDependsStmt * alterTriggerDependsStmt)443 GetAlterTriggerDependsTriggerNameValue(AlterObjectDependsStmt *alterTriggerDependsStmt)
444 {
445 	List *triggerObjectNameList = (List *) alterTriggerDependsStmt->object;
446 
447 	/*
448 	 * Before standard process utility, we only have trigger name in "object"
449 	 * list. However, standard process utility prepends that list with the
450 	 * relationNameList retrieved from AlterObjectDependsStmt->RangeVar and
451 	 * we call this method after standard process utility. So, for the further
452 	 * usages, it is certain that the last element in "object" list will always
453 	 * be the name of the trigger in either before or after standard process
454 	 * utility.
455 	 */
456 	Value *triggerNameValue = llast(triggerObjectNameList);
457 	return triggerNameValue;
458 }
459 
460 
461 /*
462  * PreprocessDropTriggerStmt is called before a DROP TRIGGER command has been
463  * executed by standard process utility. This function errors out for
464  * unsupported commands or creates ddl job for supported DROP TRIGGER commands.
465  * The reason we process drop trigger commands before standard process utility
466  * (unlike the other type of trigger commands) is that we act according to trigger
467  * type in CitusLocalTableTriggerCommandDDLJob but trigger wouldn't exist after
468  * standard process utility.
469  */
470 List *
PreprocessDropTriggerStmt(Node * node,const char * queryString,ProcessUtilityContext processUtilityContext)471 PreprocessDropTriggerStmt(Node *node, const char *queryString,
472 						  ProcessUtilityContext processUtilityContext)
473 {
474 	DropStmt *dropTriggerStmt = castNode(DropStmt, node);
475 	Assert(dropTriggerStmt->removeType == OBJECT_TRIGGER);
476 
477 	RangeVar *relation = GetDropTriggerStmtRelation(dropTriggerStmt);
478 
479 	bool missingOk = true;
480 	Oid relationId = RangeVarGetRelid(relation, DROP_TRIGGER_LOCK_MODE, missingOk);
481 
482 	if (!OidIsValid(relationId))
483 	{
484 		/* let standard process utility to error out */
485 		return NIL;
486 	}
487 
488 	if (!IsCitusTable(relationId))
489 	{
490 		return NIL;
491 	}
492 
493 	ErrorIfUnsupportedDropTriggerCommand(dropTriggerStmt);
494 
495 	if (IsCitusTableType(relationId, CITUS_LOCAL_TABLE))
496 	{
497 		char *triggerName = NULL;
498 		ExtractDropStmtTriggerAndRelationName(dropTriggerStmt, &triggerName, NULL);
499 		return CitusLocalTableTriggerCommandDDLJob(relationId, triggerName,
500 												   queryString);
501 	}
502 
503 	return NIL;
504 }
505 
506 
507 /*
508  * ErrorIfUnsupportedDropTriggerCommand errors out for unsupported
509  * "DROP TRIGGER triggerName ON relationName" commands.
510  */
511 static void
ErrorIfUnsupportedDropTriggerCommand(DropStmt * dropTriggerStmt)512 ErrorIfUnsupportedDropTriggerCommand(DropStmt *dropTriggerStmt)
513 {
514 	RangeVar *relation = GetDropTriggerStmtRelation(dropTriggerStmt);
515 
516 	bool missingOk = false;
517 	Oid relationId = RangeVarGetRelid(relation, DROP_TRIGGER_LOCK_MODE, missingOk);
518 
519 	if (!IsCitusTable(relationId))
520 	{
521 		return;
522 	}
523 
524 	EnsureCoordinator();
525 	ErrorOutForTriggerIfNotCitusLocalTable(relationId);
526 }
527 
528 
529 /*
530  * ErrorOutForTriggerIfNotCitusLocalTable is a helper function to error
531  * out for unsupported trigger commands depending on the citus table type.
532  */
533 void
ErrorOutForTriggerIfNotCitusLocalTable(Oid relationId)534 ErrorOutForTriggerIfNotCitusLocalTable(Oid relationId)
535 {
536 	if (IsCitusTableType(relationId, CITUS_LOCAL_TABLE))
537 	{
538 		return;
539 	}
540 
541 	ereport(ERROR, (errmsg("triggers are only supported for local tables added "
542 						   "to metadata")));
543 }
544 
545 
546 /*
547  * GetDropTriggerStmtRelation takes a DropStmt for a trigger object and returns
548  * RangeVar for the relation that owns the trigger.
549  */
550 static RangeVar *
GetDropTriggerStmtRelation(DropStmt * dropTriggerStmt)551 GetDropTriggerStmtRelation(DropStmt *dropTriggerStmt)
552 {
553 	Assert(dropTriggerStmt->removeType == OBJECT_TRIGGER);
554 
555 	ErrorIfDropStmtDropsMultipleTriggers(dropTriggerStmt);
556 
557 	List *targetObjectList = dropTriggerStmt->objects;
558 	List *triggerObjectNameList = linitial(targetObjectList);
559 
560 	/*
561 	 * The name list that identifies the trigger to be dropped looks like:
562 	 * [catalogName, schemaName, relationName, triggerName], where, the first
563 	 * two elements are optional. We should take all elements except the
564 	 * triggerName to create the range var object that defines the owner
565 	 * relation.
566 	 */
567 	int relationNameListLength = list_length(triggerObjectNameList) - 1;
568 	List *relationNameList = list_truncate(list_copy(triggerObjectNameList),
569 										   relationNameListLength);
570 
571 	return makeRangeVarFromNameList(relationNameList);
572 }
573 
574 
575 /*
576  * DropTriggerEventExtendNames extends relation name and trigger name with
577  * shardId, and sets schema name in given DropStmt by recreating "objects"
578  * list.
579  */
580 void
DropTriggerEventExtendNames(DropStmt * dropTriggerStmt,char * schemaName,uint64 shardId)581 DropTriggerEventExtendNames(DropStmt *dropTriggerStmt, char *schemaName, uint64 shardId)
582 {
583 	Assert(dropTriggerStmt->removeType == OBJECT_TRIGGER);
584 
585 	char *triggerName = NULL;
586 	char *relationName = NULL;
587 	ExtractDropStmtTriggerAndRelationName(dropTriggerStmt, &triggerName, &relationName);
588 
589 	AppendShardIdToName(&triggerName, shardId);
590 	Value *triggerNameValue = makeString(triggerName);
591 
592 	AppendShardIdToName(&relationName, shardId);
593 	Value *relationNameValue = makeString(relationName);
594 
595 	Value *schemaNameValue = makeString(pstrdup(schemaName));
596 
597 	List *shardTriggerNameList =
598 		list_make3(schemaNameValue, relationNameValue, triggerNameValue);
599 	dropTriggerStmt->objects = list_make1(shardTriggerNameList);
600 }
601 
602 
603 /*
604  * ExtractDropStmtTriggerAndRelationName extracts triggerName and relationName
605  * from given dropTriggerStmt if arguments are passed as non-null pointers.
606  */
607 static void
ExtractDropStmtTriggerAndRelationName(DropStmt * dropTriggerStmt,char ** triggerName,char ** relationName)608 ExtractDropStmtTriggerAndRelationName(DropStmt *dropTriggerStmt, char **triggerName,
609 									  char **relationName)
610 {
611 	ErrorIfDropStmtDropsMultipleTriggers(dropTriggerStmt);
612 
613 	List *targetObjectList = dropTriggerStmt->objects;
614 	List *triggerObjectNameList = linitial(targetObjectList);
615 	int objectNameListLength = list_length(triggerObjectNameList);
616 
617 	if (triggerName != NULL)
618 	{
619 		int triggerNameindex = objectNameListLength - 1;
620 		*triggerName = strVal(safe_list_nth(triggerObjectNameList, triggerNameindex));
621 	}
622 
623 	if (relationName != NULL)
624 	{
625 		int relationNameIndex = objectNameListLength - 2;
626 		*relationName = strVal(safe_list_nth(triggerObjectNameList, relationNameIndex));
627 	}
628 }
629 
630 
631 /*
632  * ErrorIfDropStmtDropsMultipleTriggers errors out if given drop trigger
633  * command drops more than one trigger. Actually, this can't be the case
634  * as postgres doesn't support dropping multiple triggers, but we should
635  * be on the safe side.
636  */
637 static void
ErrorIfDropStmtDropsMultipleTriggers(DropStmt * dropTriggerStmt)638 ErrorIfDropStmtDropsMultipleTriggers(DropStmt *dropTriggerStmt)
639 {
640 	List *targetObjectList = dropTriggerStmt->objects;
641 	if (list_length(targetObjectList) > 1)
642 	{
643 		ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR),
644 						errmsg("cannot execute DROP TRIGGER command for multiple "
645 							   "triggers")));
646 	}
647 }
648 
649 
650 /*
651  * CitusLocalTableTriggerCommandDDLJob creates a ddl job to execute given
652  * queryString trigger command on shell relation(s) in mx worker(s) and to
653  * execute necessary ddl task on citus local table shard (if needed).
654  */
655 List *
CitusLocalTableTriggerCommandDDLJob(Oid relationId,char * triggerName,const char * queryString)656 CitusLocalTableTriggerCommandDDLJob(Oid relationId, char *triggerName,
657 									const char *queryString)
658 {
659 	DDLJob *ddlJob = palloc0(sizeof(DDLJob));
660 	ddlJob->targetRelationId = relationId;
661 	ddlJob->commandString = queryString;
662 
663 	if (!triggerName)
664 	{
665 		/*
666 		 * ENABLE/DISABLE TRIGGER ALL/USER commands do not specify trigger
667 		 * name.
668 		 */
669 		ddlJob->taskList = DDLTaskList(relationId, queryString);
670 		return list_make1(ddlJob);
671 	}
672 
673 	bool missingOk = true;
674 	Oid triggerId = get_trigger_oid(relationId, triggerName, missingOk);
675 	if (!OidIsValid(triggerId))
676 	{
677 		/*
678 		 * For DROP, ENABLE/DISABLE, ENABLE REPLICA/ALWAYS TRIGGER commands,
679 		 * we create ddl job in preprocess. So trigger may not exist.
680 		 */
681 		return NIL;
682 	}
683 
684 	int16 triggerType = GetTriggerTypeById(triggerId);
685 
686 	/* we don't have truncate triggers on shard relations */
687 	if (!TRIGGER_FOR_TRUNCATE(triggerType))
688 	{
689 		ddlJob->taskList = DDLTaskList(relationId, queryString);
690 	}
691 
692 	return list_make1(ddlJob);
693 }
694 
695 
696 /*
697  * GetTriggerTypeById returns trigger type (tgtype) of the trigger identified
698  * by triggerId if it exists. Otherwise, errors out.
699  */
700 static int16
GetTriggerTypeById(Oid triggerId)701 GetTriggerTypeById(Oid triggerId)
702 {
703 	bool missingOk = false;
704 	HeapTuple triggerTuple = GetTriggerTupleById(triggerId, missingOk);
705 
706 	Form_pg_trigger triggerForm = (Form_pg_trigger) GETSTRUCT(triggerTuple);
707 	int16 triggerType = triggerForm->tgtype;
708 	heap_freetuple(triggerTuple);
709 
710 	return triggerType;
711 }
712 
713 
714 /*
715  * GetTriggerFunctionId returns OID of the function that the trigger with
716  * triggerId executes if the trigger exists. Otherwise, errors out.
717  */
718 Oid
GetTriggerFunctionId(Oid triggerId)719 GetTriggerFunctionId(Oid triggerId)
720 {
721 	bool missingOk = false;
722 	HeapTuple triggerTuple = GetTriggerTupleById(triggerId, missingOk);
723 
724 	Form_pg_trigger triggerForm = (Form_pg_trigger) GETSTRUCT(triggerTuple);
725 	Oid functionId = triggerForm->tgfoid;
726 	heap_freetuple(triggerTuple);
727 
728 	return functionId;
729 }
730