/*-------------------------------------------------------------------------
 *
 * metadata_cache.c
 *	  Distributed table metadata cache
 *
 * Copyright (c) Citus Data, Inc.
 *-------------------------------------------------------------------------
 */

#include "distributed/pg_version_constants.h"

#include "stdint.h"
#include "postgres.h"
#include "libpq-fe.h"
#include "miscadmin.h"

#include "access/genam.h"
#include "access/heapam.h"
#include "access/htup_details.h"
#include "access/nbtree.h"
#include "access/xact.h"
#include "access/sysattr.h"
#include "catalog/indexing.h"
#include "catalog/pg_am.h"
#include "catalog/pg_collation.h"
#include "catalog/pg_enum.h"
#include "catalog/pg_extension.h"
#include "catalog/pg_namespace.h"
#include "catalog/pg_type.h"
#include "citus_version.h"
#include "commands/dbcommands.h"
#include "commands/extension.h"
#include "commands/trigger.h"
#include "distributed/colocation_utils.h"
#include "distributed/connection_management.h"
#include "distributed/citus_ruleutils.h"
#include "distributed/function_utils.h"
#include "distributed/foreign_key_relationship.h"
#include "distributed/listutils.h"
#include "distributed/metadata_utility.h"
#include "distributed/metadata/pg_dist_object.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_executor.h"
#include "distributed/multi_physical_planner.h"
#include "distributed/pg_dist_local_group.h"
#include "distributed/pg_dist_node_metadata.h"
#include "distributed/pg_dist_node.h"
#include "distributed/pg_dist_partition.h"
#include "distributed/pg_dist_shard.h"
#include "distributed/pg_dist_placement.h"
#include "distributed/shared_library_init.h"
#include "distributed/shardinterval_utils.h"
#include "distributed/version_compat.h"
#include "distributed/worker_manager.h"
#include "distributed/worker_protocol.h"
#include "executor/executor.h"
#include "nodes/makefuncs.h"
#include "nodes/memnodes.h"
#include "nodes/pg_list.h"
#include "parser/parse_func.h"
#include "parser/parse_type.h"
#include "storage/lmgr.h"
#include "utils/builtins.h"
#include "utils/catcache.h"
#include "utils/datum.h"
#include "utils/elog.h"
#include "utils/hsearch.h"
#if PG_VERSION_NUM >= PG_VERSION_13
#include "common/hashfn.h"
#endif
#include "utils/inval.h"
#include "utils/fmgroids.h"
#include "utils/lsyscache.h"
#include "utils/memutils.h"
#include "utils/palloc.h"
#include "utils/rel.h"
#include "utils/relfilenodemap.h"
#include "utils/relmapper.h"
#include "utils/resowner.h"
#include "utils/syscache.h"
#include "utils/typcache.h"


/* user configuration */
int ReadFromSecondaries = USE_SECONDARY_NODES_NEVER;


/*
 * CitusTableCacheEntrySlot is the entry type for DistTableCacheHash. The
 * entry's data outlives the slot on invalidation, so it requires a level
 * of indirection.
 */
typedef struct CitusTableCacheEntrySlot
{
	/* lookup key - must be first. A pg_class.oid value. */
	Oid relationId;

	/* Citus table metadata (NULL for local tables) */
	CitusTableCacheEntry *citusTableMetadata;

	/*
	 * If isValid is false, we need to recheck whether the relation ID
	 * still belongs to a Citus table.
	 */
	bool isValid;
} CitusTableCacheEntrySlot;


/*
 * ShardIdCacheEntry is the entry type for ShardIdCacheHash.
 *
 * This should never be used outside of this file. Use ShardInterval instead.
 */
typedef struct ShardIdCacheEntry
{
	/* hash key, needs to be first */
	uint64 shardId;

	/* pointer to the table entry to which this shard currently belongs */
	CitusTableCacheEntry *tableEntry;

	/* index of the shard interval in the sortedShardIntervalArray of the table entry */
	int shardIndex;
} ShardIdCacheEntry;


/*
 * State which should be cleared upon DROP EXTENSION. When the configuration
 * changes, e.g. because the extension is dropped, these summarily get set to 0.
 */
typedef struct MetadataCacheData
{
	bool extensionLoaded;
	Oid distShardRelationId;
	Oid distPlacementRelationId;
	Oid distRebalanceStrategyRelationId;
	Oid distNodeRelationId;
	Oid distNodeNodeIdIndexId;
	Oid distLocalGroupRelationId;
	Oid distObjectRelationId;
	Oid distObjectPrimaryKeyIndexId;
	Oid distColocationRelationId;
	Oid distColocationConfigurationIndexId;
	Oid distPartitionRelationId;
	Oid distPartitionLogicalRelidIndexId;
	Oid distPartitionColocationidIndexId;
	Oid distShardLogicalRelidIndexId;
	Oid distShardShardidIndexId;
	Oid distPlacementShardidIndexId;
	Oid distPlacementPlacementidIndexId;
	Oid distPlacementGroupidIndexId;
	Oid distTransactionRelationId;
	Oid distTransactionGroupIndexId;
	Oid citusCatalogNamespaceId;
	Oid copyFormatTypeId;
	Oid readIntermediateResultFuncId;
	Oid readIntermediateResultArrayFuncId;
	Oid extraDataContainerFuncId;
	Oid workerHashFunctionId;
	Oid anyValueFunctionId;
	Oid textSendAsJsonbFunctionId;
	Oid extensionOwner;
	Oid binaryCopyFormatId;
	Oid textCopyFormatId;
	Oid primaryNodeRoleId;
	Oid secondaryNodeRoleId;
	Oid pgTableIsVisibleFuncId;
	Oid citusTableIsVisibleFuncId;
	Oid jsonbExtractPathFuncId;
	bool databaseNameValid;
	char databaseName[NAMEDATALEN];
} MetadataCacheData;


static MetadataCacheData MetadataCache;

/* Citus extension version variables */
bool EnableVersionChecks = true; /* version checks are enabled */

static bool citusVersionKnownCompatible = false;

/* Hash table for information about each partition */
static HTAB *DistTableCacheHash = NULL;
static List *DistTableCacheExpired = NIL;

/* Hash table for information about each shard */
static HTAB *ShardIdCacheHash = NULL;

static MemoryContext MetadataCacheMemoryContext = NULL;

/* Hash table for information about each object */
static HTAB *DistObjectCacheHash = NULL;

/* Hash table for information about worker nodes */
static HTAB *WorkerNodeHash = NULL;
static WorkerNode **WorkerNodeArray = NULL;
static int WorkerNodeCount = 0;
static bool workerNodeHashValid = false;

/* default value is -1, for the coordinator it's 0 and for worker nodes > 0 */
static int32 LocalGroupId = -1;

/* built first time through in InitializeDistCache */
static ScanKeyData DistPartitionScanKey[1];
static ScanKeyData DistShardScanKey[1];
static ScanKeyData DistObjectScanKey[3];


/* local function forward declarations */
static bool IsCitusTableViaCatalog(Oid relationId);
static HeapTuple PgDistPartitionTupleViaCatalog(Oid relationId);
static ShardIdCacheEntry * LookupShardIdCacheEntry(int64 shardId);
static CitusTableCacheEntry * BuildCitusTableCacheEntry(Oid relationId);
static void BuildCachedShardList(CitusTableCacheEntry *cacheEntry);
static void PrepareWorkerNodeCache(void);
static bool CheckInstalledVersion(int elevel);
static char * AvailableExtensionVersion(void);
static char * InstalledExtensionVersion(void);
static bool CitusHasBeenLoadedInternal(void);
static void InitializeCaches(void);
static void InitializeDistCache(void);
static void InitializeDistObjectCache(void);
static void InitializeWorkerNodeCache(void);
static void RegisterForeignKeyGraphCacheCallbacks(void);
static void RegisterWorkerNodeCacheCallbacks(void);
static void RegisterLocalGroupIdCacheCallbacks(void);
static void RegisterCitusTableCacheEntryReleaseCallbacks(void);
static uint32 WorkerNodeHashCode(const void *key, Size keySize);
static void ResetCitusTableCacheEntry(CitusTableCacheEntry *cacheEntry);
static void RemoveStaleShardIdCacheEntries(CitusTableCacheEntry *tableEntry);
static void CreateDistTableCache(void);
static void CreateShardIdCache(void);
static void CreateDistObjectCache(void);
static void InvalidateForeignRelationGraphCacheCallback(Datum argument, Oid relationId);
static void InvalidateDistRelationCacheCallback(Datum argument, Oid relationId);
static void InvalidateNodeRelationCacheCallback(Datum argument, Oid relationId);
static void InvalidateLocalGroupIdRelationCacheCallback(Datum argument, Oid relationId);
static void CitusTableCacheEntryReleaseCallback(ResourceReleasePhase phase, bool isCommit,
												bool isTopLevel, void *arg);
static HeapTuple LookupDistPartitionTuple(Relation pgDistPartition, Oid relationId);
static void GetPartitionTypeInputInfo(char *partitionKeyString, char partitionMethod,
									  Oid *columnTypeId, int32 *columnTypeMod,
									  Oid *intervalTypeId, int32 *intervalTypeMod);
static void CachedNamespaceLookup(const char *nspname, Oid *cachedOid);
static void CachedRelationLookup(const char *relationName, Oid *cachedOid);
static void CachedRelationNamespaceLookup(const char *relationName, Oid relnamespace,
										  Oid *cachedOid);
static ShardPlacement * ResolveGroupShardPlacement(
	GroupShardPlacement *groupShardPlacement, CitusTableCacheEntry *tableEntry,
	int shardIndex);
static Oid LookupEnumValueId(Oid typeId, char *valueName);
static void InvalidateCitusTableCacheEntrySlot(CitusTableCacheEntrySlot *cacheSlot);
static void InvalidateDistTableCache(void);
static void InvalidateDistObjectCache(void);
static void InitializeTableCacheEntry(int64 shardId);
static bool IsCitusTableTypeInternal(char partitionMethod, char replicationModel,
									 CitusTableType tableType);
static bool RefreshTableCacheEntryIfInvalid(ShardIdCacheEntry *shardEntry);


/* exports for SQL callable functions */
PG_FUNCTION_INFO_V1(citus_dist_partition_cache_invalidate);
PG_FUNCTION_INFO_V1(master_dist_partition_cache_invalidate);
PG_FUNCTION_INFO_V1(citus_dist_shard_cache_invalidate);
PG_FUNCTION_INFO_V1(master_dist_shard_cache_invalidate);
PG_FUNCTION_INFO_V1(citus_dist_placement_cache_invalidate);
PG_FUNCTION_INFO_V1(master_dist_placement_cache_invalidate);
PG_FUNCTION_INFO_V1(citus_dist_node_cache_invalidate);
PG_FUNCTION_INFO_V1(master_dist_node_cache_invalidate);
PG_FUNCTION_INFO_V1(citus_dist_local_group_cache_invalidate);
PG_FUNCTION_INFO_V1(master_dist_local_group_cache_invalidate);
PG_FUNCTION_INFO_V1(citus_conninfo_cache_invalidate);
PG_FUNCTION_INFO_V1(master_dist_authinfo_cache_invalidate);
PG_FUNCTION_INFO_V1(citus_dist_object_cache_invalidate);
PG_FUNCTION_INFO_V1(master_dist_object_cache_invalidate);
PG_FUNCTION_INFO_V1(role_exists);
PG_FUNCTION_INFO_V1(authinfo_valid);
PG_FUNCTION_INFO_V1(poolinfo_valid);


/*
 * EnsureModificationsCanRun checks if the current node is in recovery mode or
 * citus.use_secondary_nodes is 'always'. If either is true the function errors out.
 */
void
EnsureModificationsCanRun(void)
{
	if (RecoveryInProgress() && !WritableStandbyCoordinator)
	{
		ereport(ERROR, (errmsg("writing to worker nodes is not currently allowed"),
						errdetail("the database is read-only")));
	}

	if (ReadFromSecondaries == USE_SECONDARY_NODES_ALWAYS)
	{
		ereport(ERROR, (errmsg("writing to worker nodes is not currently allowed"),
						errdetail("citus.use_secondary_nodes is set to 'always'")));
	}
}


/*
 * IsCitusTableType returns true if the table with the given relationId is a
 * Citus table that matches the given table type. If a cache entry already
 * exists, prefer IsCitusTableTypeCacheEntry to avoid an extra lookup.
 */
bool
IsCitusTableType(Oid relationId, CitusTableType tableType)
{
	CitusTableCacheEntry *tableEntry = LookupCitusTableCacheEntry(relationId);

	/* we are not interested in postgres tables */
	if (tableEntry == NULL)
	{
		return false;
	}
	return IsCitusTableTypeCacheEntry(tableEntry, tableType);
}
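

/*
 * Example (a minimal usage sketch; relationId is assumed to be a valid
 * pg_class OID, and HASH_DISTRIBUTED is one of the CitusTableType values
 * handled by IsCitusTableTypeInternal() below):
 *
 *   if (IsCitusTableType(relationId, HASH_DISTRIBUTED))
 *   {
 *       ... code that may assume a hash-partitioned shard layout ...
 *   }
 */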


/*
 * IsCitusTableTypeCacheEntry returns true if the given table cache entry
 * belongs to a citus table that matches the given table type.
 */
bool
IsCitusTableTypeCacheEntry(CitusTableCacheEntry *tableEntry, CitusTableType tableType)
{
	return IsCitusTableTypeInternal(tableEntry->partitionMethod,
									tableEntry->replicationModel, tableType);
}


/*
 * IsCitusTableTypeInternal returns true if the given table entry belongs to
 * the given table type group. For definition of table types, see CitusTableType.
 */
static bool
IsCitusTableTypeInternal(char partitionMethod, char replicationModel,
						 CitusTableType tableType)
{
	switch (tableType)
	{
		case HASH_DISTRIBUTED:
		{
			return partitionMethod == DISTRIBUTE_BY_HASH;
		}

		case APPEND_DISTRIBUTED:
		{
			return partitionMethod == DISTRIBUTE_BY_APPEND;
		}

		case RANGE_DISTRIBUTED:
		{
			return partitionMethod == DISTRIBUTE_BY_RANGE;
		}

		case DISTRIBUTED_TABLE:
		{
			return partitionMethod == DISTRIBUTE_BY_HASH ||
				   partitionMethod == DISTRIBUTE_BY_RANGE ||
				   partitionMethod == DISTRIBUTE_BY_APPEND;
		}

		case STRICTLY_PARTITIONED_DISTRIBUTED_TABLE:
		{
			return partitionMethod == DISTRIBUTE_BY_HASH ||
				   partitionMethod == DISTRIBUTE_BY_RANGE;
		}

		case REFERENCE_TABLE:
		{
			return partitionMethod == DISTRIBUTE_BY_NONE &&
				   replicationModel == REPLICATION_MODEL_2PC;
		}

		case CITUS_LOCAL_TABLE:
		{
			return partitionMethod == DISTRIBUTE_BY_NONE &&
				   replicationModel != REPLICATION_MODEL_2PC;
		}

		case CITUS_TABLE_WITH_NO_DIST_KEY:
		{
			return partitionMethod == DISTRIBUTE_BY_NONE;
		}

		case ANY_CITUS_TABLE_TYPE:
		{
			return true;
		}

		default:
		{
			ereport(ERROR, (errmsg("Unknown table type %d", tableType)));
		}
	}
	return false;
}


/*
 * IsCitusTable returns whether the relation with relationId is a Citus table.
 */
bool
IsCitusTable(Oid relationId)
{
	return LookupCitusTableCacheEntry(relationId) != NULL;
}


/*
 * IsCitusTableViaCatalog returns whether the given relation is a
 * distributed table or not.
 *
 * It does so by searching pg_dist_partition, explicitly bypassing caches,
 * because this function is designed to be used in cases where accessing
 * metadata tables is not safe.
 *
 * NB: Currently this still hardcodes pg_dist_partition logicalrelid column
 * offset and the corresponding index.  If we ever come close to changing
 * that, we'll have to work a bit harder.
 */
static bool
IsCitusTableViaCatalog(Oid relationId)
{
	HeapTuple partitionTuple = PgDistPartitionTupleViaCatalog(relationId);

	bool heapTupleIsValid = HeapTupleIsValid(partitionTuple);

	if (heapTupleIsValid)
	{
		heap_freetuple(partitionTuple);
	}
	return heapTupleIsValid;
}


/*
 * PartitionMethodViaCatalog gets a relationId and returns the partition
 * method column from pg_dist_partition via reading from catalog.
 */
char
PartitionMethodViaCatalog(Oid relationId)
{
	HeapTuple partitionTuple = PgDistPartitionTupleViaCatalog(relationId);
	if (!HeapTupleIsValid(partitionTuple))
	{
		return DISTRIBUTE_BY_INVALID;
	}

	Datum datumArray[Natts_pg_dist_partition];
	bool isNullArray[Natts_pg_dist_partition];

	Relation pgDistPartition = table_open(DistPartitionRelationId(), AccessShareLock);

	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPartition);
	heap_deform_tuple(partitionTuple, tupleDescriptor, datumArray, isNullArray);

	if (isNullArray[Anum_pg_dist_partition_partmethod - 1])
	{
		/* partition method cannot be NULL, still let's make sure */
		heap_freetuple(partitionTuple);
		table_close(pgDistPartition, NoLock);
		return DISTRIBUTE_BY_INVALID;
	}

	Datum partitionMethodDatum = datumArray[Anum_pg_dist_partition_partmethod - 1];
	char partitionMethodChar = DatumGetChar(partitionMethodDatum);

	heap_freetuple(partitionTuple);
	table_close(pgDistPartition, NoLock);

	return partitionMethodChar;
}


/*
 * PgDistPartitionTupleViaCatalog is a helper function that searches
 * pg_dist_partition for the given relationId. The caller is responsible
 * for ensuring that the returned heap tuple is valid before accessing
 * its fields.
 */
static HeapTuple
PgDistPartitionTupleViaCatalog(Oid relationId)
{
	const int scanKeyCount = 1;
	ScanKeyData scanKey[1];
	bool indexOK = true;

	Relation pgDistPartition = table_open(DistPartitionRelationId(), AccessShareLock);

	ScanKeyInit(&scanKey[0], Anum_pg_dist_partition_logicalrelid,
				BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(relationId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistPartition,
													DistPartitionLogicalRelidIndexId(),
													indexOK, NULL, scanKeyCount, scanKey);

	HeapTuple partitionTuple = systable_getnext(scanDescriptor);

	if (HeapTupleIsValid(partitionTuple))
	{
		/* callers should have the tuple in their memory contexts */
		partitionTuple = heap_copytuple(partitionTuple);
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistPartition, AccessShareLock);

	return partitionTuple;
}


/*
 * IsCitusLocalTableByDistParams returns true if given partitionMethod and
 * replicationModel would identify a citus local table.
 */
bool
IsCitusLocalTableByDistParams(char partitionMethod, char replicationModel)
{
	return partitionMethod == DISTRIBUTE_BY_NONE &&
		   replicationModel != REPLICATION_MODEL_2PC;
}
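

/*
 * For instance (an illustrative sketch, assuming the usual pg_dist_partition
 * encodings where DISTRIBUTE_BY_NONE is 'n' and REPLICATION_MODEL_2PC is 't'):
 *
 *   IsCitusLocalTableByDistParams('n', 's')  -> true  (citus local table)
 *   IsCitusLocalTableByDistParams('n', 't')  -> false (reference table)
 */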


/*
 * CitusTableList returns a list that includes all the valid distributed table
 * cache entries.
 */
List *
CitusTableList(void)
{
	List *distributedTableList = NIL;

	Assert(CitusHasBeenLoaded() && CheckCitusVersion(WARNING));

	/* first, we need to iterate over pg_dist_partition */
	List *citusTableIdList = CitusTableTypeIdList(ANY_CITUS_TABLE_TYPE);

	Oid relationId = InvalidOid;
	foreach_oid(relationId, citusTableIdList)
	{
		CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);

		distributedTableList = lappend(distributedTableList, cacheEntry);
	}

	return distributedTableList;
}


/*
 * LoadShardInterval returns the cached metadata for a shard.
 *
 * The return value is a copy of the cached ShardInterval struct and may
 * therefore be modified and/or freed.
 */
ShardInterval *
LoadShardInterval(uint64 shardId)
{
	ShardIdCacheEntry *shardIdEntry = LookupShardIdCacheEntry(shardId);
	CitusTableCacheEntry *tableEntry = shardIdEntry->tableEntry;
	int shardIndex = shardIdEntry->shardIndex;

	/* the offset better be in a valid range */
	Assert(shardIndex < tableEntry->shardIntervalArrayLength);

	ShardInterval *sourceShardInterval =
		tableEntry->sortedShardIntervalArray[shardIndex];

	/* copy value to return */
	ShardInterval *shardInterval = CopyShardInterval(sourceShardInterval);

	return shardInterval;
}
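

/*
 * Since LoadShardInterval hands back a private copy, callers may mutate or
 * free it without touching the cache, e.g. (a sketch):
 *
 *   ShardInterval *shardInterval = LoadShardInterval(shardId);
 *   ... inspect shardInterval->minValue / shardInterval->maxValue ...
 *   pfree(shardInterval);
 */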


/*
 * RelationIdForShard returns the relationId of the given shardId.
 */
Oid
RelationIdForShard(uint64 shardId)
{
	ShardIdCacheEntry *shardIdEntry = LookupShardIdCacheEntry(shardId);
	CitusTableCacheEntry *tableEntry = shardIdEntry->tableEntry;
	return tableEntry->relationId;
}


/*
 * ReferenceTableShardId returns true if the given shardId belongs to
 * a reference table.
 */
bool
ReferenceTableShardId(uint64 shardId)
{
	ShardIdCacheEntry *shardIdEntry = LookupShardIdCacheEntry(shardId);
	CitusTableCacheEntry *tableEntry = shardIdEntry->tableEntry;
	return IsCitusTableTypeCacheEntry(tableEntry, REFERENCE_TABLE);
}


/*
 * LoadGroupShardPlacement returns the cached shard placement metadata.
 *
 * The return value is a copy of the cached GroupShardPlacement struct and may
 * therefore be modified and/or freed.
 */
GroupShardPlacement *
LoadGroupShardPlacement(uint64 shardId, uint64 placementId)
{
	ShardIdCacheEntry *shardIdEntry = LookupShardIdCacheEntry(shardId);
	CitusTableCacheEntry *tableEntry = shardIdEntry->tableEntry;
	int shardIndex = shardIdEntry->shardIndex;

	/* the offset better be in a valid range */
	Assert(shardIndex < tableEntry->shardIntervalArrayLength);

	GroupShardPlacement *placementArray =
		tableEntry->arrayOfPlacementArrays[shardIndex];
	int numberOfPlacements =
		tableEntry->arrayOfPlacementArrayLengths[shardIndex];

	for (int i = 0; i < numberOfPlacements; i++)
	{
		if (placementArray[i].placementId == placementId)
		{
			GroupShardPlacement *shardPlacement = CitusMakeNode(GroupShardPlacement);

			*shardPlacement = placementArray[i];

			return shardPlacement;
		}
	}

	ereport(ERROR, (errmsg("could not find valid entry for shard placement "
						   UINT64_FORMAT, placementId)));
}


/*
 * LoadShardPlacement returns a shard placement for the primary node.
 */
ShardPlacement *
LoadShardPlacement(uint64 shardId, uint64 placementId)
{
	ShardIdCacheEntry *shardIdEntry = LookupShardIdCacheEntry(shardId);
	CitusTableCacheEntry *tableEntry = shardIdEntry->tableEntry;
	int shardIndex = shardIdEntry->shardIndex;
	GroupShardPlacement *groupPlacement = LoadGroupShardPlacement(shardId, placementId);
	ShardPlacement *nodePlacement = ResolveGroupShardPlacement(groupPlacement,
															   tableEntry, shardIndex);

	return nodePlacement;
}


/*
 * ShardPlacementOnGroupIncludingOrphanedPlacements returns the shard placement
 * for the given shard on the given group, or returns NULL if no placement for
 * the shard exists on the group.
 *
 * NOTE: This can return inactive or orphaned placements.
 */
ShardPlacement *
ShardPlacementOnGroupIncludingOrphanedPlacements(int32 groupId, uint64 shardId)
{
	ShardPlacement *placementOnNode = NULL;

	ShardIdCacheEntry *shardIdEntry = LookupShardIdCacheEntry(shardId);
	CitusTableCacheEntry *tableEntry = shardIdEntry->tableEntry;
	int shardIndex = shardIdEntry->shardIndex;
	GroupShardPlacement *placementArray =
		tableEntry->arrayOfPlacementArrays[shardIndex];
	int numberOfPlacements =
		tableEntry->arrayOfPlacementArrayLengths[shardIndex];

	for (int placementIndex = 0; placementIndex < numberOfPlacements; placementIndex++)
	{
		GroupShardPlacement *placement = &placementArray[placementIndex];
		if (placement->groupId == groupId)
		{
			placementOnNode = ResolveGroupShardPlacement(placement, tableEntry,
														 shardIndex);
			break;
		}
	}

	return placementOnNode;
}


/*
 * ActiveShardPlacementOnGroup returns the active shard placement for the
 * given shard on the given group, or returns NULL if no active placement for
 * the shard exists on the group.
 */
ShardPlacement *
ActiveShardPlacementOnGroup(int32 groupId, uint64 shardId)
{
	ShardPlacement *placement =
		ShardPlacementOnGroupIncludingOrphanedPlacements(groupId, shardId);
	if (placement == NULL)
	{
		return NULL;
	}
	if (placement->shardState != SHARD_STATE_ACTIVE)
	{
		return NULL;
	}
	return placement;
}


/*
 * ResolveGroupShardPlacement takes a GroupShardPlacement and adds additional data to it,
 * such as the node we should consider it to be on.
 */
static ShardPlacement *
ResolveGroupShardPlacement(GroupShardPlacement *groupShardPlacement,
						   CitusTableCacheEntry *tableEntry,
						   int shardIndex)
{
	ShardInterval *shardInterval = tableEntry->sortedShardIntervalArray[shardIndex];

	ShardPlacement *shardPlacement = CitusMakeNode(ShardPlacement);
	int32 groupId = groupShardPlacement->groupId;
	WorkerNode *workerNode = LookupNodeForGroup(groupId);

	/* copy everything into shardPlacement but preserve the header */
	CitusNode header = shardPlacement->type;
	GroupShardPlacement *shardPlacementAsGroupPlacement =
		(GroupShardPlacement *) shardPlacement;
	*shardPlacementAsGroupPlacement = *groupShardPlacement;
	shardPlacement->type = header;

	SetPlacementNodeMetadata(shardPlacement, workerNode);

	/* fill in remaining fields */
	Assert(tableEntry->partitionMethod != 0);
	shardPlacement->partitionMethod = tableEntry->partitionMethod;
	shardPlacement->colocationGroupId = tableEntry->colocationId;
	if (tableEntry->partitionMethod == DISTRIBUTE_BY_HASH)
	{
		Assert(shardInterval->minValueExists);
		Assert(shardInterval->valueTypeId == INT4OID);

		/*
		 * Use the lower boundary of the interval's range to identify
		 * it for colocation purposes. That remains meaningful even if
		 * a concurrent session splits a shard.
		 */
		shardPlacement->representativeValue = DatumGetInt32(shardInterval->minValue);
	}
	else
	{
		shardPlacement->representativeValue = 0;
	}

	return shardPlacement;
}


/*
 * HasAnyNodes returns whether there are any nodes in pg_dist_node.
 */
bool
HasAnyNodes(void)
{
	PrepareWorkerNodeCache();

	return WorkerNodeCount > 0;
}


/*
 * LookupNodeByNodeId returns a worker node by nodeId or NULL if the node
 * cannot be found.
 */
WorkerNode *
LookupNodeByNodeId(uint32 nodeId)
{
	PrepareWorkerNodeCache();

	for (int workerNodeIndex = 0; workerNodeIndex < WorkerNodeCount; workerNodeIndex++)
	{
		WorkerNode *workerNode = WorkerNodeArray[workerNodeIndex];
		if (workerNode->nodeId == nodeId)
		{
			WorkerNode *workerNodeCopy = palloc0(sizeof(WorkerNode));
			*workerNodeCopy = *workerNode;

			return workerNodeCopy;
		}
	}

	return NULL;
}


/*
 * LookupNodeByNodeIdOrError returns a worker node by nodeId or errors out if the
 * node cannot be found.
 */
WorkerNode *
LookupNodeByNodeIdOrError(uint32 nodeId)
{
	WorkerNode *node = LookupNodeByNodeId(nodeId);
	if (node == NULL)
	{
		ereport(ERROR, (errmsg("node %d could not be found", nodeId)));
	}
	return node;
}


/*
 * LookupNodeForGroup searches the WorkerNodeHash for a worker which is a member of the
 * given group and also readable (a primary if we're reading from primaries, a secondary
 * if we're reading from secondaries). If such a node does not exist it emits an
 * appropriate error message.
 */
WorkerNode *
LookupNodeForGroup(int32 groupId)
{
	bool foundAnyNodes = false;

	PrepareWorkerNodeCache();

	for (int workerNodeIndex = 0; workerNodeIndex < WorkerNodeCount; workerNodeIndex++)
	{
		WorkerNode *workerNode = WorkerNodeArray[workerNodeIndex];
		int32 workerNodeGroupId = workerNode->groupId;
		if (workerNodeGroupId != groupId)
		{
			continue;
		}

		foundAnyNodes = true;

		if (NodeIsReadable(workerNode))
		{
			return workerNode;
		}
	}

	if (!foundAnyNodes)
	{
		ereport(ERROR, (errmsg("there is a shard placement in node group %d but "
							   "there are no nodes in that group", groupId)));
	}

	switch (ReadFromSecondaries)
	{
		case USE_SECONDARY_NODES_NEVER:
		{
			ereport(ERROR, (errmsg("node group %d does not have a primary node",
								   groupId)));
			break;
		}

		case USE_SECONDARY_NODES_ALWAYS:
		{
			ereport(ERROR, (errmsg("node group %d does not have a secondary node",
								   groupId)));
			break;
		}

		default:
		{
			ereport(FATAL, (errmsg("unrecognized value for use_secondary_nodes")));
		}
	}
}


/*
 * ShardPlacementListIncludingOrphanedPlacements returns the list of placements
 * for the given shard from the cache. This list includes placements that are
 * orphaned, because their deletion is postponed to a later point (shardstate = 4).
 *
 * The returned list is deep copied from the cache and thus can be modified
 * and pfree()d freely.
 */
List *
ShardPlacementListIncludingOrphanedPlacements(uint64 shardId)
{
	List *placementList = NIL;

	ShardIdCacheEntry *shardIdEntry = LookupShardIdCacheEntry(shardId);
	CitusTableCacheEntry *tableEntry = shardIdEntry->tableEntry;
	int shardIndex = shardIdEntry->shardIndex;

	/* the offset better be in a valid range */
	Assert(shardIndex < tableEntry->shardIntervalArrayLength);

	GroupShardPlacement *placementArray =
		tableEntry->arrayOfPlacementArrays[shardIndex];
	int numberOfPlacements =
		tableEntry->arrayOfPlacementArrayLengths[shardIndex];

	for (int i = 0; i < numberOfPlacements; i++)
	{
		GroupShardPlacement *groupShardPlacement = &placementArray[i];
		ShardPlacement *shardPlacement = ResolveGroupShardPlacement(groupShardPlacement,
																	tableEntry,
																	shardIndex);

		placementList = lappend(placementList, shardPlacement);
	}

	/* if no shard placements are found, warn the user */
	if (numberOfPlacements == 0)
	{
		ereport(WARNING, (errmsg("could not find any shard placements for shardId "
								 UINT64_FORMAT, shardId)));
	}

	return placementList;
}


/*
 * InitializeTableCacheEntry initializes a shard in cache.  A possible reason
 * for not finding an entry in the cache is that the distributed table's cache
 * entry hasn't been accessed yet. Thus look up the distributed table, and
 * build the cache entry. Afterwards we know that the shard has to be in the
 * cache if it exists. If the shard does *not* exist, this function errors
 * (because LookupShardRelationFromCatalog errors out).
 */
static void
InitializeTableCacheEntry(int64 shardId)
{
	bool missingOk = false;
	Oid relationId = LookupShardRelationFromCatalog(shardId, missingOk);

	/* trigger building the cache for the shard id */
	GetCitusTableCacheEntry(relationId); /* lgtm[cpp/return-value-ignored] */
}


/*
 * RefreshTableCacheEntryIfInvalid checks if the cache entry is still valid and
 * refreshes it in cache when it's not. It returns true if it refreshed the
 * entry in the cache and false if it didn't.
 */
static bool
RefreshTableCacheEntryIfInvalid(ShardIdCacheEntry *shardEntry)
{
	/*
	 * We might have some concurrent metadata changes. In order to get the changes,
	 * we first need to accept the cache invalidation messages.
	 */
	AcceptInvalidationMessages();
	if (shardEntry->tableEntry->isValid)
	{
		return false;
	}
	Oid oldRelationId = shardEntry->tableEntry->relationId;
	Oid currentRelationId = LookupShardRelationFromCatalog(shardEntry->shardId, false);

	/*
	 * The relation OID to which the shard belongs could have changed,
	 * most notably when the extension is dropped and a shard ID is
	 * reused. Reload the cache entries for both old and new relation
	 * ID and then look up the shard entry again.
	 */
	LookupCitusTableCacheEntry(oldRelationId);
	LookupCitusTableCacheEntry(currentRelationId);
	return true;
}


/*
 * LookupShardIdCacheEntry returns the cache entry belonging to a shard, or
 * errors out if that shard is unknown.
 */
static ShardIdCacheEntry *
LookupShardIdCacheEntry(int64 shardId)
{
	bool foundInCache = false;
	bool recheck = false;

	Assert(CitusHasBeenLoaded() && CheckCitusVersion(WARNING));

	InitializeCaches();

	ShardIdCacheEntry *shardEntry =
		hash_search(ShardIdCacheHash, &shardId, HASH_FIND, &foundInCache);

	if (!foundInCache)
	{
		InitializeTableCacheEntry(shardId);
		recheck = true;
	}
	else
	{
		recheck = RefreshTableCacheEntryIfInvalid(shardEntry);
	}

	/*
	 * If we (re-)loaded the table cache, re-search the shard cache - the
	 * shard index might have changed.  If we still can't find the entry, it
	 * can't exist.
	 */
	if (recheck)
	{
		shardEntry = hash_search(ShardIdCacheHash, &shardId, HASH_FIND, &foundInCache);

		if (!foundInCache)
		{
			ereport(ERROR, (errmsg("could not find valid entry for shard "
								   UINT64_FORMAT, shardId)));
		}
	}

	return shardEntry;
}


/*
 * GetCitusTableCacheEntry looks up a pg_dist_partition entry for a
 * relation.
 *
 * Errors out if no relation matching the criteria could be found.
 */
CitusTableCacheEntry *
GetCitusTableCacheEntry(Oid distributedRelationId)
{
	CitusTableCacheEntry *cacheEntry =
		LookupCitusTableCacheEntry(distributedRelationId);

	if (cacheEntry)
	{
		return cacheEntry;
	}
	else
	{
		char *relationName = get_rel_name(distributedRelationId);

		if (relationName == NULL)
		{
			ereport(ERROR, (errmsg("relation with OID %u does not exist",
								   distributedRelationId)));
		}
		else
		{
			ereport(ERROR, (errmsg("relation %s is not distributed", relationName)));
		}
	}
}


/*
 * LookupCitusTableCacheEntry returns the distributed table metadata for the
 * passed relationId. For efficiency it caches lookups. This function returns
 * NULL if the relation isn't a distributed table.
 */
CitusTableCacheEntry *
LookupCitusTableCacheEntry(Oid relationId)
{
	bool foundInCache = false;
	void *hashKey = (void *) &relationId;

	/*
	 * Can't be a distributed relation if the extension hasn't been loaded
	 * yet. As we can't do lookups in nonexistent tables, directly return NULL
	 * here.
	 */
	if (!CitusHasBeenLoaded())
	{
		return NULL;
	}

	InitializeCaches();

	/*
	 * If the version is not known to be compatible, perform a thorough check,
	 * unless such checks are disabled.
	 */
	if (!citusVersionKnownCompatible && EnableVersionChecks)
	{
		bool isCitusTable = IsCitusTableViaCatalog(relationId);
		int reportLevel = DEBUG1;

		/*
		 * If there's a version-mismatch, and we're dealing with a distributed
		 * table, we have to error out as we can't return a valid entry.  We
		 * want to check compatibility in the non-distributed case as well, so
		 * future lookups can use the cache if compatible.
		 */
		if (isCitusTable)
		{
			reportLevel = ERROR;
		}

		if (!CheckCitusVersion(reportLevel))
		{
			/* incompatible, can't access cache, so return before doing so */
			return NULL;
		}
	}

	/*
	 * We might have some concurrent metadata changes. In order to get the changes,
	 * we first need to accept the cache invalidation messages.
	 */
	AcceptInvalidationMessages();
	CitusTableCacheEntrySlot *cacheSlot =
		hash_search(DistTableCacheHash, hashKey, HASH_ENTER, &foundInCache);

	/* return valid matches */
	if (foundInCache)
	{
		if (cacheSlot->isValid)
		{
			return cacheSlot->citusTableMetadata;
		}
		else
		{
			/*
			 * An invalidation was received or we encountered an OOM while building
			 * the cache entry. We need to rebuild it.
			 */

			if (cacheSlot->citusTableMetadata)
			{
				/*
				 * The CitusTableCacheEntry might still be in use. We therefore do
				 * not reset it until the end of the transaction.
				 */
				MemoryContext oldContext =
					MemoryContextSwitchTo(MetadataCacheMemoryContext);

				DistTableCacheExpired = lappend(DistTableCacheExpired,
												cacheSlot->citusTableMetadata);

				MemoryContextSwitchTo(oldContext);
			}
		}
	}

	/* zero out entry, but not the key part */
	memset(((char *) cacheSlot) + sizeof(Oid), 0,
		   sizeof(CitusTableCacheEntrySlot) - sizeof(Oid));

	/*
	 * We disable interrupts while creating the cache entry because loading
	 * shard metadata can take a while, and if statement_timeout is too low,
	 * this will get canceled on each call and we won't be able to run any
	 * queries on the table.
	 */
	HOLD_INTERRUPTS();

	cacheSlot->citusTableMetadata = BuildCitusTableCacheEntry(relationId);

	/*
	 * Mark it as valid only after building the full entry, such that any
	 * error that happened during the build would trigger a rebuild.
	 */
	cacheSlot->isValid = true;

	RESUME_INTERRUPTS();

	return cacheSlot->citusTableMetadata;
}


/*
 * LookupDistObjectCacheEntry returns the metadata for the distributed object
 * identified by (classid, objid, objsubid). For efficiency it caches lookups.
 */
DistObjectCacheEntry *
LookupDistObjectCacheEntry(Oid classid, Oid objid, int32 objsubid)
{
	bool foundInCache = false;
	DistObjectCacheEntryKey hashKey;
	ScanKeyData pgDistObjectKey[3];

	memset(&hashKey, 0, sizeof(DistObjectCacheEntryKey));
	hashKey.classid = classid;
	hashKey.objid = objid;
	hashKey.objsubid = objsubid;

	/*
	 * Can't be a distributed object if the extension hasn't been loaded
	 * yet. As we can't do lookups in nonexistent tables, directly return NULL
	 * here.
	 */
	if (!CitusHasBeenLoaded())
	{
		return NULL;
	}

	InitializeCaches();

	DistObjectCacheEntry *cacheEntry = hash_search(DistObjectCacheHash, &hashKey,
												   HASH_ENTER, &foundInCache);

	/* return valid matches */
	if (foundInCache)
	{
		/*
		 * We might have some concurrent metadata changes. In order to get the changes,
		 * we first need to accept the cache invalidation messages.
		 */
		AcceptInvalidationMessages();

		if (cacheEntry->isValid)
		{
			return cacheEntry;
		}

		/*
		 * This is where we'd free the old entry's out of band data if it had any.
		 * Right now we don't have anything to free.
		 */
	}

	/* zero out the entry, then restore the hash key */
	memset(((char *) cacheEntry), 0, sizeof(DistObjectCacheEntry));
	cacheEntry->key.classid = classid;
	cacheEntry->key.objid = objid;
	cacheEntry->key.objsubid = objsubid;

	Relation pgDistObjectRel = table_open(DistObjectRelationId(), AccessShareLock);
	TupleDesc pgDistObjectTupleDesc = RelationGetDescr(pgDistObjectRel);

	ScanKeyInit(&pgDistObjectKey[0], Anum_pg_dist_object_classid,
				BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(classid));
	ScanKeyInit(&pgDistObjectKey[1], Anum_pg_dist_object_objid,
				BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(objid));
	ScanKeyInit(&pgDistObjectKey[2], Anum_pg_dist_object_objsubid,
				BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(objsubid));

	SysScanDesc pgDistObjectScan = systable_beginscan(pgDistObjectRel,
													  DistObjectPrimaryKeyIndexId(),
													  true, NULL, 3, pgDistObjectKey);
	HeapTuple pgDistObjectTup = systable_getnext(pgDistObjectScan);

	if (HeapTupleIsValid(pgDistObjectTup))
	{
		Datum datumArray[Natts_pg_dist_object];
		bool isNullArray[Natts_pg_dist_object];

		heap_deform_tuple(pgDistObjectTup, pgDistObjectTupleDesc, datumArray,
						  isNullArray);

		cacheEntry->isValid = true;
		cacheEntry->isDistributed = true;

		cacheEntry->distributionArgIndex =
			DatumGetInt32(datumArray[Anum_pg_dist_object_distribution_argument_index -
									 1]);
		cacheEntry->colocationId =
			DatumGetInt32(datumArray[Anum_pg_dist_object_colocationid - 1]);
	}
	else
	{
		cacheEntry->isValid = true;
		cacheEntry->isDistributed = false;
	}

	systable_endscan(pgDistObjectScan);
	relation_close(pgDistObjectRel, AccessShareLock);

	return cacheEntry;
}


/*
 * BuildCitusTableCacheEntry is a helper routine for
 * LookupCitusTableCacheEntry() for building the cache contents.
 * This function returns NULL if the relation isn't a distributed table.
 */
static CitusTableCacheEntry *
BuildCitusTableCacheEntry(Oid relationId)
{
	Relation pgDistPartition = table_open(DistPartitionRelationId(), AccessShareLock);
	HeapTuple distPartitionTuple =
		LookupDistPartitionTuple(pgDistPartition, relationId);

	if (distPartitionTuple == NULL)
	{
		/* not a distributed table, done */
		table_close(pgDistPartition, NoLock);
		return NULL;
	}

	MemoryContext oldContext = NULL;
	Datum datumArray[Natts_pg_dist_partition];
	bool isNullArray[Natts_pg_dist_partition];

	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPartition);
	heap_deform_tuple(distPartitionTuple, tupleDescriptor, datumArray, isNullArray);

	CitusTableCacheEntry *cacheEntry =
		MemoryContextAllocZero(MetadataCacheMemoryContext, sizeof(CitusTableCacheEntry));

	cacheEntry->relationId = relationId;

	cacheEntry->partitionMethod = datumArray[Anum_pg_dist_partition_partmethod - 1];
	Datum partitionKeyDatum = datumArray[Anum_pg_dist_partition_partkey - 1];
	bool partitionKeyIsNull = isNullArray[Anum_pg_dist_partition_partkey - 1];

	/* note that for reference tables partitionKeyIsNull is true */
	if (!partitionKeyIsNull)
	{
		oldContext = MemoryContextSwitchTo(MetadataCacheMemoryContext);

		/* get the string representation of the partition column Var */
		cacheEntry->partitionKeyString = TextDatumGetCString(partitionKeyDatum);

		/* convert the string to a Node and ensure it is a Var */
		Node *partitionNode = stringToNode(cacheEntry->partitionKeyString);
		Assert(IsA(partitionNode, Var));

		cacheEntry->partitionColumn = (Var *) partitionNode;

		MemoryContextSwitchTo(oldContext);
	}
	else
	{
		cacheEntry->partitionKeyString = NULL;
	}

	cacheEntry->colocationId = datumArray[Anum_pg_dist_partition_colocationid - 1];
	if (isNullArray[Anum_pg_dist_partition_colocationid - 1])
	{
		cacheEntry->colocationId = INVALID_COLOCATION_ID;
	}

	Datum replicationModelDatum = datumArray[Anum_pg_dist_partition_repmodel - 1];
	if (isNullArray[Anum_pg_dist_partition_repmodel - 1])
	{
		/*
		 * repmodel is NOT NULL but before ALTER EXTENSION citus UPGRADE the column
		 * doesn't exist
		 */
		cacheEntry->replicationModel = 'c';
	}
	else
	{
		cacheEntry->replicationModel = DatumGetChar(replicationModelDatum);
	}

	heap_freetuple(distPartitionTuple);

	BuildCachedShardList(cacheEntry);

	/* we only need hash functions for hash distributed tables */
	if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH)
	{
		Var *partitionColumn = cacheEntry->partitionColumn;

		TypeCacheEntry *typeEntry = lookup_type_cache(partitionColumn->vartype,
													  TYPECACHE_HASH_PROC_FINFO);

		FmgrInfo *hashFunction = MemoryContextAllocZero(MetadataCacheMemoryContext,
														sizeof(FmgrInfo));

		fmgr_info_copy(hashFunction, &(typeEntry->hash_proc_finfo),
					   MetadataCacheMemoryContext);

		cacheEntry->hashFunction = hashFunction;

		/* check the shard distribution for hash partitioned tables */
		cacheEntry->hasUniformHashDistribution =
			HasUniformHashDistribution(cacheEntry->sortedShardIntervalArray,
									   cacheEntry->shardIntervalArrayLength);
	}
	else
	{
		cacheEntry->hashFunction = NULL;
	}

	oldContext = MemoryContextSwitchTo(MetadataCacheMemoryContext);

	cacheEntry->referencedRelationsViaForeignKey = ReferencedRelationIdList(
		cacheEntry->relationId);
	cacheEntry->referencingRelationsViaForeignKey = ReferencingRelationIdList(
		cacheEntry->relationId);

	MemoryContextSwitchTo(oldContext);

	table_close(pgDistPartition, NoLock);

	cacheEntry->isValid = true;

	return cacheEntry;
}


/*
 * BuildCachedShardList() is a helper routine for BuildCitusTableCacheEntry()
 * building up the list of shards in a distributed relation.
 */
static void
BuildCachedShardList(CitusTableCacheEntry *cacheEntry)
{
	ShardInterval **shardIntervalArray = NULL;
	ShardInterval **sortedShardIntervalArray = NULL;
	FmgrInfo *shardIntervalCompareFunction = NULL;
	FmgrInfo *shardColumnCompareFunction = NULL;
	Oid columnTypeId = InvalidOid;
	int32 columnTypeMod = -1;
	Oid intervalTypeId = InvalidOid;
	int32 intervalTypeMod = -1;

	GetPartitionTypeInputInfo(cacheEntry->partitionKeyString,
							  cacheEntry->partitionMethod,
							  &columnTypeId,
							  &columnTypeMod,
							  &intervalTypeId,
							  &intervalTypeMod);

	List *distShardTupleList = LookupDistShardTuples(cacheEntry->relationId);
	int shardIntervalArrayLength = list_length(distShardTupleList);
	if (shardIntervalArrayLength > 0)
	{
		Relation distShardRelation = table_open(DistShardRelationId(), AccessShareLock);
		TupleDesc distShardTupleDesc = RelationGetDescr(distShardRelation);
		int arrayIndex = 0;

		shardIntervalArray = MemoryContextAllocZero(MetadataCacheMemoryContext,
													shardIntervalArrayLength *
													sizeof(ShardInterval *));

		cacheEntry->arrayOfPlacementArrays =
			MemoryContextAllocZero(MetadataCacheMemoryContext,
								   shardIntervalArrayLength *
								   sizeof(GroupShardPlacement *));
		cacheEntry->arrayOfPlacementArrayLengths =
			MemoryContextAllocZero(MetadataCacheMemoryContext,
								   shardIntervalArrayLength *
								   sizeof(int));

		HeapTuple shardTuple = NULL;
		foreach_ptr(shardTuple, distShardTupleList)
		{
			ShardInterval *shardInterval = TupleToShardInterval(shardTuple,
																distShardTupleDesc,
																intervalTypeId,
																intervalTypeMod);
			MemoryContext oldContext = MemoryContextSwitchTo(MetadataCacheMemoryContext);

			shardIntervalArray[arrayIndex] = CopyShardInterval(shardInterval);

			MemoryContextSwitchTo(oldContext);

			heap_freetuple(shardTuple);

			arrayIndex++;
		}

		table_close(distShardRelation, AccessShareLock);
	}

	/* look up value comparison function */
	if (columnTypeId != InvalidOid)
	{
		/* allocate the comparison function in the cache context */
		MemoryContext oldContext = MemoryContextSwitchTo(MetadataCacheMemoryContext);

		shardColumnCompareFunction = GetFunctionInfo(columnTypeId, BTREE_AM_OID,
													 BTORDER_PROC);
		MemoryContextSwitchTo(oldContext);
	}
	else
	{
		shardColumnCompareFunction = NULL;
	}

	/* look up interval comparison function */
	if (intervalTypeId != InvalidOid)
	{
		/* allocate the comparison function in the cache context */
		MemoryContext oldContext = MemoryContextSwitchTo(MetadataCacheMemoryContext);

		shardIntervalCompareFunction = GetFunctionInfo(intervalTypeId, BTREE_AM_OID,
													   BTORDER_PROC);
		MemoryContextSwitchTo(oldContext);
	}
	else
	{
		shardIntervalCompareFunction = NULL;
	}

	/* reference tables have a single shard which is not initialized */
	if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE)
	{
		cacheEntry->hasUninitializedShardInterval = true;
		cacheEntry->hasOverlappingShardInterval = true;

		/*
		 * Note that during the create_reference_table() call,
		 * the reference table does not have any shards yet.
		 */
		if (shardIntervalArrayLength > 1)
		{
			char *relationName = get_rel_name(cacheEntry->relationId);

			ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
							errmsg("reference table \"%s\" has more than 1 shard",
								   relationName)));
		}

		/* since there is zero or one shard, it is already sorted */
		sortedShardIntervalArray = shardIntervalArray;
	}
	else
	{
		/* sort the interval array */
		sortedShardIntervalArray = SortShardIntervalArray(shardIntervalArray,
														  shardIntervalArrayLength,
														  cacheEntry->partitionColumn->
														  varcollid,
														  shardIntervalCompareFunction);

		/* check if there exists any shard intervals with no min/max values */
		cacheEntry->hasUninitializedShardInterval =
			HasUninitializedShardInterval(sortedShardIntervalArray,
										  shardIntervalArrayLength);

		if (!cacheEntry->hasUninitializedShardInterval)
		{
			cacheEntry->hasOverlappingShardInterval =
				HasOverlappingShardInterval(sortedShardIntervalArray,
											shardIntervalArrayLength,
											cacheEntry->partitionColumn->varcollid,
											shardIntervalCompareFunction);
		}
		else
		{
			cacheEntry->hasOverlappingShardInterval = true;
		}

		ErrorIfInconsistentShardIntervals(cacheEntry);
	}

	cacheEntry->sortedShardIntervalArray = sortedShardIntervalArray;
	cacheEntry->shardIntervalArrayLength = 0;

	/* maintain shardId->(table,ShardInterval) cache */
	for (int shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++)
	{
		ShardInterval *shardInterval = sortedShardIntervalArray[shardIndex];
		int64 shardId = shardInterval->shardId;
		int placementOffset = 0;

		/*
		 * Enable quick lookups of this shard ID by adding it to ShardIdCacheHash
		 * or overwriting the previous values.
		 */
		ShardIdCacheEntry *shardIdCacheEntry =
			hash_search(ShardIdCacheHash, &shardId, HASH_ENTER, NULL);

		shardIdCacheEntry->tableEntry = cacheEntry;
		shardIdCacheEntry->shardIndex = shardIndex;

		/*
		 * We should increment this only after we are sure this hasn't already
		 * been assigned to any other relations. ResetCitusTableCacheEntry()
		 * depends on this.
		 */
		cacheEntry->shardIntervalArrayLength++;

		/* build list of shard placements */
		List *placementList = BuildShardPlacementList(shardId);
		int numberOfPlacements = list_length(placementList);

		/* and copy that list into the cache entry */
		MemoryContext oldContext = MemoryContextSwitchTo(MetadataCacheMemoryContext);
		GroupShardPlacement *placementArray = palloc0(numberOfPlacements *
													  sizeof(GroupShardPlacement));
		GroupShardPlacement *srcPlacement = NULL;
		foreach_ptr(srcPlacement, placementList)
		{
			placementArray[placementOffset] = *srcPlacement;
			placementOffset++;
		}
		MemoryContextSwitchTo(oldContext);

		cacheEntry->arrayOfPlacementArrays[shardIndex] = placementArray;
		cacheEntry->arrayOfPlacementArrayLengths[shardIndex] = numberOfPlacements;

		/* store the shard index in the ShardInterval */
		shardInterval->shardIndex = shardIndex;
	}

	cacheEntry->shardColumnCompareFunction = shardColumnCompareFunction;
	cacheEntry->shardIntervalCompareFunction = shardIntervalCompareFunction;
}


/*
 * ErrorIfInconsistentShardIntervals checks if shard intervals are consistent with
 * our expectations.
 */
void
ErrorIfInconsistentShardIntervals(CitusTableCacheEntry *cacheEntry)
{
	/*
	 * If table is hash-partitioned and has shards, there never should be any
	 * uninitialized shards.  Historically we've not prevented that for range
	 * partitioned tables, but it might be a good idea to start doing so.
	 */
	if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH &&
		cacheEntry->hasUninitializedShardInterval)
	{
		ereport(ERROR, (errmsg("hash partitioned table has uninitialized shards")));
	}
	if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH &&
		cacheEntry->hasOverlappingShardInterval)
	{
		ereport(ERROR, (errmsg("hash partitioned table has overlapping shards")));
	}
}
1621 
1622 
1623 /*
1624  * HasUniformHashDistribution determines whether the given list of sorted shards
1625  * has a uniform hash distribution, as produced by master_create_worker_shards for
1626  * hash partitioned tables.
1627  */
1628 bool
HasUniformHashDistribution(ShardInterval ** shardIntervalArray,int shardIntervalArrayLength)1629 HasUniformHashDistribution(ShardInterval **shardIntervalArray,
1630 						   int shardIntervalArrayLength)
1631 {
1632 	/* if there are no shards, there is no uniform distribution */
1633 	if (shardIntervalArrayLength == 0)
1634 	{
1635 		return false;
1636 	}
1637 
1638 	/* calculate the hash token increment */
1639 	uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardIntervalArrayLength;
1640 
1641 	for (int shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++)
1642 	{
1643 		ShardInterval *shardInterval = shardIntervalArray[shardIndex];
1644 		int32 shardMinHashToken = PG_INT32_MIN + (shardIndex * hashTokenIncrement);
1645 		int32 shardMaxHashToken = shardMinHashToken + (hashTokenIncrement - 1);
1646 
1647 		if (shardIndex == (shardIntervalArrayLength - 1))
1648 		{
1649 			shardMaxHashToken = PG_INT32_MAX;
1650 		}
1651 
1652 		if (DatumGetInt32(shardInterval->minValue) != shardMinHashToken ||
1653 			DatumGetInt32(shardInterval->maxValue) != shardMaxHashToken)
1654 		{
1655 			return false;
1656 		}
1657 	}
1658 
1659 	return true;
1660 }
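

/*
 * Editorial sketch (not compiled): the expected token range for a given shard
 * index, mirroring the arithmetic in HasUniformHashDistribution above. The
 * helper name is hypothetical. For shardCount = 4, hashTokenIncrement is 2^30
 * and the ranges are [-2147483648, -1073741825], [-1073741824, -1],
 * [0, 1073741823] and [1073741824, 2147483647]; the last shard's max is
 * clamped to PG_INT32_MAX to absorb any remainder of the division.
 */
#if 0
static void
ExampleUniformTokenRange(int shardIndex, int shardCount,
						 int32 *minToken, int32 *maxToken)
{
	uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount;

	/* same arithmetic as the loop in HasUniformHashDistribution */
	*minToken = PG_INT32_MIN + (shardIndex * hashTokenIncrement);
	*maxToken = *minToken + (hashTokenIncrement - 1);

	if (shardIndex == shardCount - 1)
	{
		*maxToken = PG_INT32_MAX;
	}
}
#endif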
1661 
1662 
1663 /*
1664  * HasUninitializedShardInterval returns true if any element of the
1665  * sortedShardIntervalArray lacks min/max values. Callers of the function must
1666  * ensure that the input shard interval array is sorted on shardminvalue and that
1667  * uninitialized shard intervals are at the end of the array.
1668  */
1669 bool
1670 HasUninitializedShardInterval(ShardInterval **sortedShardIntervalArray, int shardCount)
1671 {
1672 	bool hasUninitializedShardInterval = false;
1673 
1674 	if (shardCount == 0)
1675 	{
1676 		return hasUninitializedShardInterval;
1677 	}
1678 
1679 	Assert(sortedShardIntervalArray != NULL);
1680 
1681 	/*
1682 	 * Since the shard interval array is sorted, and uninitialized ones stored
1683 	 * in the end of the array, checking the last element is enough.
1684 	 */
1685 	ShardInterval *lastShardInterval = sortedShardIntervalArray[shardCount - 1];
1686 	if (!lastShardInterval->minValueExists || !lastShardInterval->maxValueExists)
1687 	{
1688 		hasUninitializedShardInterval = true;
1689 	}
1690 
1691 	return hasUninitializedShardInterval;
1692 }
1693 
1694 
1695 /*
1696  * HasOverlappingShardInterval determines whether the given list of sorted
1697  * shards has overlapping ranges.
1698  */
1699 bool
1700 HasOverlappingShardInterval(ShardInterval **shardIntervalArray,
1701 							int shardIntervalArrayLength,
1702 							Oid shardIntervalCollation,
1703 							FmgrInfo *shardIntervalSortCompareFunction)
1704 {
1705 	Datum comparisonDatum = 0;
1706 	int comparisonResult = 0;
1707 
1708 	/* zero/a single shard can't overlap */
1709 	if (shardIntervalArrayLength < 2)
1710 	{
1711 		return false;
1712 	}
1713 
1714 	ShardInterval *lastShardInterval = shardIntervalArray[0];
1715 	for (int shardIndex = 1; shardIndex < shardIntervalArrayLength; shardIndex++)
1716 	{
1717 		ShardInterval *curShardInterval = shardIntervalArray[shardIndex];
1718 
1719 		/* only called if !hasUninitializedShardInterval */
1720 		Assert(lastShardInterval->minValueExists && lastShardInterval->maxValueExists);
1721 		Assert(curShardInterval->minValueExists && curShardInterval->maxValueExists);
1722 
1723 		comparisonDatum = FunctionCall2Coll(shardIntervalSortCompareFunction,
1724 											shardIntervalCollation,
1725 											lastShardInterval->maxValue,
1726 											curShardInterval->minValue);
1727 		comparisonResult = DatumGetInt32(comparisonDatum);
1728 
1729 		if (comparisonResult >= 0)
1730 		{
1731 			return true;
1732 		}
1733 
1734 		lastShardInterval = curShardInterval;
1735 	}
1736 
1737 	return false;
1738 }
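

/*
 * Editorial sketch (not compiled): how the comparator call above decides
 * overlap for int4 shard ranges, assuming btint4cmp as the comparison
 * function. Two sorted intervals overlap exactly when the previous shard's
 * max compares greater than or equal to the next shard's min. The helper
 * name is hypothetical.
 */
#if 0
static bool
ExampleIntervalsOverlap(int32 lastMax, int32 curMin)
{
	FmgrInfo compareFunction;

	fmgr_info(F_BTINT4CMP, &compareFunction);

	/* e.g. lastMax = 10, curMin = 11 compares below zero: no overlap */
	Datum comparisonDatum = FunctionCall2Coll(&compareFunction, InvalidOid,
											  Int32GetDatum(lastMax),
											  Int32GetDatum(curMin));

	return DatumGetInt32(comparisonDatum) >= 0;
}
#endif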
1739 
1740 
1741 /*
1742  * CitusHasBeenLoaded returns true if the citus extension has been created
1743  * in the current database and the extension script has been executed. Otherwise,
1744  * it returns false. The result is cached as this is called very frequently.
1745  */
1746 bool
1747 CitusHasBeenLoaded(void)
1748 {
1749 	if (!MetadataCache.extensionLoaded || creating_extension)
1750 	{
1751 		/*
1752 		 * Refresh if we have not determined whether the extension has been
1753 		 * loaded yet, or in case of ALTER EXTENSION since we want to treat
1754 		 * Citus as "not loaded" during ALTER EXTENSION citus.
1755 		 */
1756 		bool extensionLoaded = CitusHasBeenLoadedInternal();
1757 
1758 		if (extensionLoaded && !MetadataCache.extensionLoaded)
1759 		{
1760 			/*
1761 			 * Loaded Citus for the first time in this session, or first time after
1762 			 * CREATE/ALTER EXTENSION citus. Do some initialisation.
1763 			 */
1764 
1765 			/*
1766 			 * Make sure the maintenance daemon is running if it was not already.
1767 			 */
1768 			StartupCitusBackend();
1769 
1770 			/*
1771 			 * InvalidateDistRelationCacheCallback resets state such as extensionLoaded
1772 			 * when it notices changes to pg_dist_partition (which usually indicate
1773 			 * `DROP EXTENSION citus;` has been run)
1774 			 *
1775 			 * Ensure InvalidateDistRelationCacheCallback will notice those changes
1776 			 * by caching pg_dist_partition's oid.
1777 			 *
1778 			 * We skip these checks during upgrade since pg_dist_partition is not
1779 			 * present during early stages of upgrade operation.
1780 			 */
1781 			DistPartitionRelationId();
1782 
1783 			/*
1784 			 * This needs to be initialized so we can receive foreign relation graph
1785 			 * invalidation messages in InvalidateForeignRelationGraphCacheCallback().
1786 			 * See the comments of InvalidateForeignKeyGraph for more context.
1787 			 */
1788 			DistColocationRelationId();
1789 		}
1790 
1791 		MetadataCache.extensionLoaded = extensionLoaded;
1792 	}
1793 
1794 	return MetadataCache.extensionLoaded;
1795 }
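

/*
 * Editorial sketch (not compiled): the usual guard pattern around
 * CitusHasBeenLoaded(). Hooks and utility functions bail out early when the
 * extension is not usable yet, so plain PostgreSQL behavior is preserved.
 * The function name below is hypothetical.
 */
#if 0
static void
ExampleCitusHook(void)
{
	if (!CitusHasBeenLoaded())
	{
		/* extension missing or mid-CREATE/ALTER; use vanilla behavior */
		return;
	}

	/* ... Citus-specific logic ... */
}
#endif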
1796 
1797 
1798 /*
1799  * CitusHasBeenLoadedInternal returns true if the citus extension has been created
1800  * in the current database and the extension script has been executed. Otherwise,
1801  * it returns false.
1802  */
1803 static bool
1804 CitusHasBeenLoadedInternal(void)
1805 {
1806 	if (IsBinaryUpgrade)
1807 	{
1808 		/* never use Citus logic during pg_upgrade */
1809 		return false;
1810 	}
1811 
1812 	Oid citusExtensionOid = get_extension_oid("citus", true);
1813 	if (citusExtensionOid == InvalidOid)
1814 	{
1815 		/* Citus extension does not exist yet */
1816 		return false;
1817 	}
1818 
1819 	if (creating_extension && CurrentExtensionObject == citusExtensionOid)
1820 	{
1821 		/*
1822 		 * We do not use Citus hooks during CREATE/ALTER EXTENSION citus
1823 		 * since the objects used by the C code might not be there yet.
1824 		 */
1825 		return false;
1826 	}
1827 
1828 	/* citus extension exists and has been created */
1829 	return true;
1830 }
1831 
1832 
1833 /*
1834  * CheckCitusVersion checks whether there is a version mismatch between the
1835  * available version and the loaded version or between the installed version
1836  * and the loaded version. Returns true if compatible, false otherwise.
1837  *
1838  * As a side effect, this function also sets citusVersionKnownCompatible global
1839  * variable to true which reduces version check cost of next calls.
1840  */
1841 bool
1842 CheckCitusVersion(int elevel)
1843 {
1844 	if (citusVersionKnownCompatible ||
1845 		!CitusHasBeenLoaded() ||
1846 		!EnableVersionChecks)
1847 	{
1848 		return true;
1849 	}
1850 
1851 	if (CheckAvailableVersion(elevel) && CheckInstalledVersion(elevel))
1852 	{
1853 		citusVersionKnownCompatible = true;
1854 		return true;
1855 	}
1856 	else
1857 	{
1858 		return false;
1859 	}
1860 }
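

/*
 * Editorial sketch (not compiled): UDF entry points in this file call
 * CheckCitusVersion(ERROR) first, which ereports on a version mismatch and is
 * nearly free on subsequent calls thanks to citusVersionKnownCompatible. The
 * UDF name below is hypothetical.
 */
#if 0
Datum
example_citus_udf(PG_FUNCTION_ARGS)
{
	CheckCitusVersion(ERROR);   /* errors out if versions are incompatible */

	PG_RETURN_VOID();
}
#endif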
1861 
1862 
1863 /*
1864  * CheckAvailableVersion compares CITUS_EXTENSIONVERSION and the currently
1865  * available version from the citus.control file. If they are not compatible,
1866  * this function logs an error with the specified elevel and returns false,
1867  * otherwise it returns true.
1868  */
1869 bool
1870 CheckAvailableVersion(int elevel)
1871 {
1872 	if (!EnableVersionChecks)
1873 	{
1874 		return true;
1875 	}
1876 
1877 	char *availableVersion = AvailableExtensionVersion();
1878 
1879 	if (!MajorVersionsCompatible(availableVersion, CITUS_EXTENSIONVERSION))
1880 	{
1881 		ereport(elevel, (errmsg("loaded Citus library version differs from latest "
1882 								"available extension version"),
1883 						 errdetail("Loaded library requires %s, but the latest control "
1884 								   "file specifies %s.", CITUS_MAJORVERSION,
1885 								   availableVersion),
1886 						 errhint("Restart the database to load the latest Citus "
1887 								 "library.")));
1888 		return false;
1889 	}
1890 
1891 	return true;
1892 }
1893 
1894 
1895 /*
1896  * CheckInstalledVersion compares CITUS_EXTENSIONVERSION and the
1897  * extension's current version from the pg_extension catalog table. If they
1898  * are not compatible, this function logs an error with the specified elevel
1899  * and returns false, otherwise it returns true.
1900  */
1901 static bool
1902 CheckInstalledVersion(int elevel)
1903 {
1904 	Assert(CitusHasBeenLoaded());
1905 	Assert(EnableVersionChecks);
1906 
1907 	char *installedVersion = InstalledExtensionVersion();
1908 
1909 	if (!MajorVersionsCompatible(installedVersion, CITUS_EXTENSIONVERSION))
1910 	{
1911 		ereport(elevel, (errmsg("loaded Citus library version differs from installed "
1912 								"extension version"),
1913 						 errdetail("Loaded library requires %s, but the installed "
1914 								   "extension version is %s.", CITUS_MAJORVERSION,
1915 								   installedVersion),
1916 						 errhint("Run ALTER EXTENSION citus UPDATE and try again.")));
1917 		return false;
1918 	}
1919 
1920 	return true;
1921 }
1922 
1923 
1924 /*
1925  * InstalledAndAvailableVersionsSame compares the extension's available version
1926  * and its current version from the pg_extension catalog table. If they are not
1927  * the same it returns false, otherwise it returns true.
1928  */
1929 bool
1930 InstalledAndAvailableVersionsSame(void)
1931 {
1932 	char *installedVersion = InstalledExtensionVersion();
1933 	char *availableVersion = AvailableExtensionVersion();
1934 
1935 	if (strncmp(installedVersion, availableVersion, NAMEDATALEN) == 0)
1936 	{
1937 		return true;
1938 	}
1939 
1940 	return false;
1941 }
1942 
1943 
1944 /*
1945  * MajorVersionsCompatible checks whether both versions are compatible. They
1946  * are compatible if the major and minor version numbers match; the schema
1947  * version is ignored. Returns true if compatible, false otherwise.
1948  */
1949 bool
1950 MajorVersionsCompatible(char *leftVersion, char *rightVersion)
1951 {
1952 	const char schemaVersionSeparator = '-';
1953 
1954 	char *leftSeparatorPosition = strchr(leftVersion, schemaVersionSeparator);
1955 	char *rightSeparatorPosition = strchr(rightVersion, schemaVersionSeparator);
1956 	int leftComparisonLimit = 0;
1957 	int rightComparisonLimit = 0;
1958 
1959 	if (leftSeparatorPosition != NULL)
1960 	{
1961 		leftComparisonLimit = leftSeparatorPosition - leftVersion;
1962 	}
1963 	else
1964 	{
1965 		leftComparisonLimit = strlen(leftVersion);
1966 	}
1967 
1968 	if (rightSeparatorPosition != NULL)
1969 	{
1970 		rightComparisonLimit = rightSeparatorPosition - rightVersion;
1971 	}
1972 	else
1973 	{
1974 		rightComparisonLimit = strlen(rightVersion);
1975 	}
1976 
1977 	/* we can error out early if the hyphens are not in the same position */
1978 	if (leftComparisonLimit != rightComparisonLimit)
1979 	{
1980 		return false;
1981 	}
1982 
1983 	return strncmp(leftVersion, rightVersion, leftComparisonLimit) == 0;
1984 }
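

/*
 * Editorial sketch (not compiled): everything after the '-' schema separator
 * is ignored by the comparison, so "9.5-1" and "9.5-2" are compatible while
 * "9.4-1" and "9.5-1" are not. The helper name is hypothetical.
 */
#if 0
static void
ExampleMajorVersionChecks(void)
{
	char leftVersion[] = "9.5-1";
	char sameMajorMinor[] = "9.5-2";
	char differentMinor[] = "9.4-1";

	Assert(MajorVersionsCompatible(leftVersion, sameMajorMinor));
	Assert(!MajorVersionsCompatible(leftVersion, differentMinor));
}
#endif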
1985 
1986 
1987 /*
1988  * AvailableExtensionVersion returns the Citus version from the citus.control
1989  * file. It also caches the result, so consecutive calls to
1990  * AvailableExtensionVersion will not read the citus.control file again.
1991  */
1992 static char *
1993 AvailableExtensionVersion(void)
1994 {
1995 	LOCAL_FCINFO(fcinfo, 0);
1996 	FmgrInfo flinfo;
1997 
1998 	bool goForward = true;
1999 	bool doCopy = false;
2000 	char *availableExtensionVersion;
2001 
2002 	InitializeCaches();
2003 
2004 	EState *estate = CreateExecutorState();
2005 	ReturnSetInfo *extensionsResultSet = makeNode(ReturnSetInfo);
2006 	extensionsResultSet->econtext = GetPerTupleExprContext(estate);
2007 	extensionsResultSet->allowedModes = SFRM_Materialize;
2008 
2009 	fmgr_info(F_PG_AVAILABLE_EXTENSIONS, &flinfo);
2010 	InitFunctionCallInfoData(*fcinfo, &flinfo, 0, InvalidOid, NULL,
2011 							 (Node *) extensionsResultSet);
2012 
2013 	/* pg_available_extensions returns a result set containing all available extensions */
2014 	(*pg_available_extensions)(fcinfo);
2015 
2016 	TupleTableSlot *tupleTableSlot = MakeSingleTupleTableSlotCompat(
2017 		extensionsResultSet->setDesc,
2018 		&TTSOpsMinimalTuple);
2019 	bool hasTuple = tuplestore_gettupleslot(extensionsResultSet->setResult, goForward,
2020 											doCopy,
2021 											tupleTableSlot);
2022 	while (hasTuple)
2023 	{
2024 		bool isNull = false;
2025 
2026 		Datum extensionNameDatum = slot_getattr(tupleTableSlot, 1, &isNull);
2027 		char *extensionName = NameStr(*DatumGetName(extensionNameDatum));
2028 		if (strcmp(extensionName, "citus") == 0)
2029 		{
2030 			Datum availableVersion = slot_getattr(tupleTableSlot, 2, &isNull);
2031 
2032 			/* we will cache the result of citus version to prevent catalog access */
2033 			MemoryContext oldMemoryContext = MemoryContextSwitchTo(
2034 				MetadataCacheMemoryContext);
2035 
2036 			availableExtensionVersion = text_to_cstring(DatumGetTextPP(availableVersion));
2037 
2038 			MemoryContextSwitchTo(oldMemoryContext);
2039 
2040 			ExecClearTuple(tupleTableSlot);
2041 			ExecDropSingleTupleTableSlot(tupleTableSlot);
2042 
2043 			return availableExtensionVersion;
2044 		}
2045 
2046 		ExecClearTuple(tupleTableSlot);
2047 		hasTuple = tuplestore_gettupleslot(extensionsResultSet->setResult, goForward,
2048 										   doCopy, tupleTableSlot);
2049 	}
2050 
2051 	ExecDropSingleTupleTableSlot(tupleTableSlot);
2052 
2053 	ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
2054 					errmsg("citus extension is not found")));
2055 }
2056 
2057 
2058 /*
2059  * InstalledExtensionVersion returns the Citus version in PostgreSQL pg_extension table.
2060  */
2061 static char *
2062 InstalledExtensionVersion(void)
2063 {
2064 	ScanKeyData entry[1];
2065 	char *installedExtensionVersion = NULL;
2066 
2067 	InitializeCaches();
2068 
2069 	Relation relation = table_open(ExtensionRelationId, AccessShareLock);
2070 
2071 	ScanKeyInit(&entry[0], Anum_pg_extension_extname, BTEqualStrategyNumber, F_NAMEEQ,
2072 				CStringGetDatum("citus"));
2073 
2074 	SysScanDesc scandesc = systable_beginscan(relation, ExtensionNameIndexId, true,
2075 											  NULL, 1, entry);
2076 
2077 	HeapTuple extensionTuple = systable_getnext(scandesc);
2078 
2079 	/* We assume that there can be at most one matching tuple */
2080 	if (HeapTupleIsValid(extensionTuple))
2081 	{
2082 		int extensionIndex = Anum_pg_extension_extversion;
2083 		TupleDesc tupleDescriptor = RelationGetDescr(relation);
2084 		bool isNull = false;
2085 
2086 		Datum installedVersion = heap_getattr(extensionTuple, extensionIndex,
2087 											  tupleDescriptor, &isNull);
2088 
2089 		if (isNull)
2090 		{
2091 			ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
2092 							errmsg("citus extension version is null")));
2093 		}
2094 
2095 		/* we will cache the result of citus version to prevent catalog access */
2096 		MemoryContext oldMemoryContext = MemoryContextSwitchTo(
2097 			MetadataCacheMemoryContext);
2098 
2099 		installedExtensionVersion = text_to_cstring(DatumGetTextPP(installedVersion));
2100 
2101 		MemoryContextSwitchTo(oldMemoryContext);
2102 	}
2103 	else
2104 	{
2105 		ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
2106 						errmsg("citus extension is not loaded")));
2107 	}
2108 
2109 	systable_endscan(scandesc);
2110 
2111 	table_close(relation, AccessShareLock);
2112 
2113 	return installedExtensionVersion;
2114 }
2115 
2116 
2117 /* return oid of pg_dist_shard relation */
2118 Oid
2119 DistShardRelationId(void)
2120 {
2121 	CachedRelationLookup("pg_dist_shard",
2122 						 &MetadataCache.distShardRelationId);
2123 
2124 	return MetadataCache.distShardRelationId;
2125 }
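

/*
 * Editorial sketch (not compiled): the cached OID accessors below all follow
 * the same shape and exist so catalog scans can open the Citus catalogs
 * without a name lookup on every call. A typical caller, with a hypothetical
 * helper name:
 */
#if 0
static void
ExampleScanPgDistShard(void)
{
	Relation pgDistShard = table_open(DistShardRelationId(), AccessShareLock);

	/* ... systable_beginscan() / systable_getnext() over pg_dist_shard ... */

	table_close(pgDistShard, AccessShareLock);
}
#endif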
2126 
2127 
2128 /* return oid of pg_dist_placement relation */
2129 Oid
2130 DistPlacementRelationId(void)
2131 {
2132 	CachedRelationLookup("pg_dist_placement",
2133 						 &MetadataCache.distPlacementRelationId);
2134 
2135 	return MetadataCache.distPlacementRelationId;
2136 }
2137 
2138 
2139 /* return oid of pg_dist_node relation */
2140 Oid
2141 DistNodeRelationId(void)
2142 {
2143 	CachedRelationLookup("pg_dist_node",
2144 						 &MetadataCache.distNodeRelationId);
2145 
2146 	return MetadataCache.distNodeRelationId;
2147 }
2148 
2149 
2150 /* return oid of pg_dist_node's primary key index */
2151 Oid
2152 DistNodeNodeIdIndexId(void)
2153 {
2154 	CachedRelationLookup("pg_dist_node_pkey",
2155 						 &MetadataCache.distNodeNodeIdIndexId);
2156 
2157 	return MetadataCache.distNodeNodeIdIndexId;
2158 }
2159 
2160 
2161 /* return oid of pg_dist_local_group relation */
2162 Oid
2163 DistLocalGroupIdRelationId(void)
2164 {
2165 	CachedRelationLookup("pg_dist_local_group",
2166 						 &MetadataCache.distLocalGroupRelationId);
2167 
2168 	return MetadataCache.distLocalGroupRelationId;
2169 }
2170 
2171 
2172 /* return oid of pg_dist_rebalance_strategy relation */
2173 Oid
2174 DistRebalanceStrategyRelationId(void)
2175 {
2176 	CachedRelationLookup("pg_dist_rebalance_strategy",
2177 						 &MetadataCache.distRebalanceStrategyRelationId);
2178 
2179 	return MetadataCache.distRebalanceStrategyRelationId;
2180 }
2181 
2182 
2183 /* return the oid of citus namespace */
2184 Oid
2185 CitusCatalogNamespaceId(void)
2186 {
2187 	CachedNamespaceLookup("citus", &MetadataCache.citusCatalogNamespaceId);
2188 	return MetadataCache.citusCatalogNamespaceId;
2189 }
2190 
2191 
2192 /* return oid of pg_dist_object relation */
2193 Oid
2194 DistObjectRelationId(void)
2195 {
2196 	CachedRelationNamespaceLookup("pg_dist_object", CitusCatalogNamespaceId(),
2197 								  &MetadataCache.distObjectRelationId);
2198 
2199 	return MetadataCache.distObjectRelationId;
2200 }
2201 
2202 
2203 /* return oid of pg_dist_object_pkey */
2204 Oid
2205 DistObjectPrimaryKeyIndexId(void)
2206 {
2207 	CachedRelationNamespaceLookup("pg_dist_object_pkey",
2208 								  CitusCatalogNamespaceId(),
2209 								  &MetadataCache.distObjectPrimaryKeyIndexId);
2210 
2211 	return MetadataCache.distObjectPrimaryKeyIndexId;
2212 }
2213 
2214 
2215 /* return oid of pg_dist_colocation relation */
2216 Oid
2217 DistColocationRelationId(void)
2218 {
2219 	CachedRelationLookup("pg_dist_colocation",
2220 						 &MetadataCache.distColocationRelationId);
2221 
2222 	return MetadataCache.distColocationRelationId;
2223 }
2224 
2225 
2226 /* return oid of pg_dist_colocation_configuration_index index */
2227 Oid
2228 DistColocationConfigurationIndexId(void)
2229 {
2230 	CachedRelationLookup("pg_dist_colocation_configuration_index",
2231 						 &MetadataCache.distColocationConfigurationIndexId);
2232 
2233 	return MetadataCache.distColocationConfigurationIndexId;
2234 }
2235 
2236 
2237 /* return oid of pg_dist_partition relation */
2238 Oid
2239 DistPartitionRelationId(void)
2240 {
2241 	CachedRelationLookup("pg_dist_partition",
2242 						 &MetadataCache.distPartitionRelationId);
2243 
2244 	return MetadataCache.distPartitionRelationId;
2245 }
2246 
2247 
2248 /* return oid of pg_dist_partition_logical_relid_index index */
2249 Oid
2250 DistPartitionLogicalRelidIndexId(void)
2251 {
2252 	CachedRelationLookup("pg_dist_partition_logical_relid_index",
2253 						 &MetadataCache.distPartitionLogicalRelidIndexId);
2254 
2255 	return MetadataCache.distPartitionLogicalRelidIndexId;
2256 }
2257 
2258 
2259 /* return oid of pg_dist_partition_colocationid_index index */
2260 Oid
2261 DistPartitionColocationidIndexId(void)
2262 {
2263 	CachedRelationLookup("pg_dist_partition_colocationid_index",
2264 						 &MetadataCache.distPartitionColocationidIndexId);
2265 
2266 	return MetadataCache.distPartitionColocationidIndexId;
2267 }
2268 
2269 
2270 /* return oid of pg_dist_shard_logical_relid_index index */
2271 Oid
2272 DistShardLogicalRelidIndexId(void)
2273 {
2274 	CachedRelationLookup("pg_dist_shard_logical_relid_index",
2275 						 &MetadataCache.distShardLogicalRelidIndexId);
2276 
2277 	return MetadataCache.distShardLogicalRelidIndexId;
2278 }
2279 
2280 
2281 /* return oid of pg_dist_shard_shardid_index index */
2282 Oid
2283 DistShardShardidIndexId(void)
2284 {
2285 	CachedRelationLookup("pg_dist_shard_shardid_index",
2286 						 &MetadataCache.distShardShardidIndexId);
2287 
2288 	return MetadataCache.distShardShardidIndexId;
2289 }
2290 
2291 
2292 /* return oid of pg_dist_placement_shardid_index */
2293 Oid
2294 DistPlacementShardidIndexId(void)
2295 {
2296 	CachedRelationLookup("pg_dist_placement_shardid_index",
2297 						 &MetadataCache.distPlacementShardidIndexId);
2298 
2299 	return MetadataCache.distPlacementShardidIndexId;
2300 }
2301 
2302 
2303 /* return oid of pg_dist_placement_placementid_index */
2304 Oid
2305 DistPlacementPlacementidIndexId(void)
2306 {
2307 	CachedRelationLookup("pg_dist_placement_placementid_index",
2308 						 &MetadataCache.distPlacementPlacementidIndexId);
2309 
2310 	return MetadataCache.distPlacementPlacementidIndexId;
2311 }
2312 
2313 
2314 /* return oid of pg_dist_transaction relation */
2315 Oid
2316 DistTransactionRelationId(void)
2317 {
2318 	CachedRelationLookup("pg_dist_transaction",
2319 						 &MetadataCache.distTransactionRelationId);
2320 
2321 	return MetadataCache.distTransactionRelationId;
2322 }
2323 
2324 
2325 /* return oid of pg_dist_transaction_group_index */
2326 Oid
2327 DistTransactionGroupIndexId(void)
2328 {
2329 	CachedRelationLookup("pg_dist_transaction_group_index",
2330 						 &MetadataCache.distTransactionGroupIndexId);
2331 
2332 	return MetadataCache.distTransactionGroupIndexId;
2333 }
2334 
2335 
2336 /* return oid of pg_dist_placement_groupid_index */
2337 Oid
2338 DistPlacementGroupidIndexId(void)
2339 {
2340 	CachedRelationLookup("pg_dist_placement_groupid_index",
2341 						 &MetadataCache.distPlacementGroupidIndexId);
2342 
2343 	return MetadataCache.distPlacementGroupidIndexId;
2344 }
2345 
2346 
2347 /* return oid of the read_intermediate_result(text,citus_copy_format) function */
2348 Oid
2349 CitusReadIntermediateResultFuncId(void)
2350 {
2351 	if (MetadataCache.readIntermediateResultFuncId == InvalidOid)
2352 	{
2353 		List *functionNameList = list_make2(makeString("pg_catalog"),
2354 											makeString("read_intermediate_result"));
2355 		Oid copyFormatTypeOid = CitusCopyFormatTypeId();
2356 		Oid paramOids[2] = { TEXTOID, copyFormatTypeOid };
2357 		bool missingOK = false;
2358 
2359 		MetadataCache.readIntermediateResultFuncId =
2360 			LookupFuncName(functionNameList, 2, paramOids, missingOK);
2361 	}
2362 
2363 	return MetadataCache.readIntermediateResultFuncId;
2364 }
2365 
2366 
2367 /* return oid of the read_intermediate_results(text[],citus_copy_format) function */
2368 Oid
2369 CitusReadIntermediateResultArrayFuncId(void)
2370 {
2371 	if (MetadataCache.readIntermediateResultArrayFuncId == InvalidOid)
2372 	{
2373 		List *functionNameList = list_make2(makeString("pg_catalog"),
2374 											makeString("read_intermediate_results"));
2375 		Oid copyFormatTypeOid = CitusCopyFormatTypeId();
2376 		Oid paramOids[2] = { TEXTARRAYOID, copyFormatTypeOid };
2377 		bool missingOK = false;
2378 
2379 		MetadataCache.readIntermediateResultArrayFuncId =
2380 			LookupFuncName(functionNameList, 2, paramOids, missingOK);
2381 	}
2382 
2383 	return MetadataCache.readIntermediateResultArrayFuncId;
2384 }
2385 
2386 
2387 /* return oid of the citus.copy_format enum type */
2388 Oid
2389 CitusCopyFormatTypeId(void)
2390 {
2391 	if (MetadataCache.copyFormatTypeId == InvalidOid)
2392 	{
2393 		char *typeName = "citus_copy_format";
2394 		MetadataCache.copyFormatTypeId = GetSysCacheOid2Compat(TYPENAMENSP,
2395 															   Anum_pg_enum_oid,
2396 															   PointerGetDatum(typeName),
2397 															   PG_CATALOG_NAMESPACE);
2398 	}
2399 
2400 	return MetadataCache.copyFormatTypeId;
2401 }
2402 
2403 
2404 /* return oid of the 'binary' citus_copy_format enum value */
2405 Oid
2406 BinaryCopyFormatId(void)
2407 {
2408 	if (MetadataCache.binaryCopyFormatId == InvalidOid)
2409 	{
2410 		Oid copyFormatTypeId = CitusCopyFormatTypeId();
2411 		MetadataCache.binaryCopyFormatId = LookupEnumValueId(copyFormatTypeId, "binary");
2412 	}
2413 
2414 	return MetadataCache.binaryCopyFormatId;
2415 }
2416 
2417 
2418 /* return oid of the 'text' citus_copy_format enum value */
2419 Oid
2420 TextCopyFormatId(void)
2421 {
2422 	if (MetadataCache.textCopyFormatId == InvalidOid)
2423 	{
2424 		Oid copyFormatTypeId = CitusCopyFormatTypeId();
2425 		MetadataCache.textCopyFormatId = LookupEnumValueId(copyFormatTypeId, "text");
2426 	}
2427 
2428 	return MetadataCache.textCopyFormatId;
2429 }
2430 
2431 
2432 /* return oid of the citus_extradata_container(internal) function */
2433 Oid
2434 CitusExtraDataContainerFuncId(void)
2435 {
2436 	List *nameList = NIL;
2437 	Oid paramOids[1] = { INTERNALOID };
2438 
2439 	if (MetadataCache.extraDataContainerFuncId == InvalidOid)
2440 	{
2441 		nameList = list_make2(makeString("pg_catalog"),
2442 							  makeString("citus_extradata_container"));
2443 		MetadataCache.extraDataContainerFuncId =
2444 			LookupFuncName(nameList, 1, paramOids, false);
2445 	}
2446 
2447 	return MetadataCache.extraDataContainerFuncId;
2448 }
2449 
2450 
2451 /* return oid of the any_value aggregate function */
2452 Oid
2453 CitusAnyValueFunctionId(void)
2454 {
2455 	if (MetadataCache.anyValueFunctionId == InvalidOid)
2456 	{
2457 		const int argCount = 1;
2458 		MetadataCache.anyValueFunctionId =
2459 			FunctionOid("pg_catalog", "any_value", argCount);
2460 	}
2461 
2462 	return MetadataCache.anyValueFunctionId;
2463 }
2464 
2465 
2466 /*
2467  * PgTableVisibleFuncId returns oid of the pg_table_is_visible function.
2468  */
2469 Oid
2470 PgTableVisibleFuncId(void)
2471 {
2472 	if (MetadataCache.pgTableIsVisibleFuncId == InvalidOid)
2473 	{
2474 		const int argCount = 1;
2475 
2476 		MetadataCache.pgTableIsVisibleFuncId =
2477 			FunctionOid("pg_catalog", "pg_table_is_visible", argCount);
2478 	}
2479 
2480 	return MetadataCache.pgTableIsVisibleFuncId;
2481 }
2482 
2483 
2484 /*
2485  * CitusTableVisibleFuncId returns oid of the citus_table_is_visible function.
2486  */
2487 Oid
2488 CitusTableVisibleFuncId(void)
2489 {
2490 	if (MetadataCache.citusTableIsVisibleFuncId == InvalidOid)
2491 	{
2492 		const int argCount = 1;
2493 
2494 		MetadataCache.citusTableIsVisibleFuncId =
2495 			FunctionOid("pg_catalog", "citus_table_is_visible", argCount);
2496 	}
2497 
2498 	return MetadataCache.citusTableIsVisibleFuncId;
2499 }
2500 
2501 
2502 /*
2503  * JsonbExtractPathFuncId returns oid of the jsonb_extract_path function.
2504  */
2505 Oid
2506 JsonbExtractPathFuncId(void)
2507 {
2508 	if (MetadataCache.jsonbExtractPathFuncId == InvalidOid)
2509 	{
2510 		const int argCount = 2;
2511 
2512 		MetadataCache.jsonbExtractPathFuncId =
2513 			FunctionOid("pg_catalog", "jsonb_extract_path", argCount);
2514 	}
2515 
2516 	return MetadataCache.jsonbExtractPathFuncId;
2517 }
2518 
2519 
2520 /*
2521  * CurrentDatabaseName gets the name of the current database and caches
2522  * the result.
2523  *
2524  * Given that the database name cannot be changed when there is at least
2525  * one session connected to it, we do not need to implement any invalidation
2526  * mechanism.
2527  */
2528 const char *
2529 CurrentDatabaseName(void)
2530 {
2531 	if (!MetadataCache.databaseNameValid)
2532 	{
2533 		char *databaseName = get_database_name(MyDatabaseId);
2534 		if (databaseName == NULL)
2535 		{
2536 			ereport(ERROR, (errmsg("database that is connected to does not exist")));
2537 		}
2538 
2539 		strlcpy(MetadataCache.databaseName, databaseName, NAMEDATALEN);
2540 		MetadataCache.databaseNameValid = true;
2541 	}
2542 
2543 	return MetadataCache.databaseName;
2544 }
2545 
2546 
2547 /*
2548  * CitusExtensionOwner() returns the owner of the 'citus' extension. That user
2549  * is used, among other things, to perform actions a normal user might not be
2550  * allowed to perform.
2551  */
2552 extern Oid
2553 CitusExtensionOwner(void)
2554 {
2555 	ScanKeyData entry[1];
2556 	Form_pg_extension extensionForm = NULL;
2557 
2558 	if (MetadataCache.extensionOwner != InvalidOid)
2559 	{
2560 		return MetadataCache.extensionOwner;
2561 	}
2562 
2563 	Relation relation = table_open(ExtensionRelationId, AccessShareLock);
2564 
2565 	ScanKeyInit(&entry[0],
2566 				Anum_pg_extension_extname,
2567 				BTEqualStrategyNumber, F_NAMEEQ,
2568 				CStringGetDatum("citus"));
2569 
2570 	SysScanDesc scandesc = systable_beginscan(relation, ExtensionNameIndexId, true,
2571 											  NULL, 1, entry);
2572 
2573 	HeapTuple extensionTuple = systable_getnext(scandesc);
2574 
2575 	/* We assume that there can be at most one matching tuple */
2576 	if (HeapTupleIsValid(extensionTuple))
2577 	{
2578 		extensionForm = (Form_pg_extension) GETSTRUCT(extensionTuple);
2579 
2580 		/*
2581 		 * For some operations Citus requires superuser permissions; we use
2582 		 * the extension owner for that. The extension owner is guaranteed to
2583 		 * be a superuser (otherwise C functions can't be created), but it'd
2584 		 * be possible to change the owner. So check that the owner is still
2585 		 * a superuser.
2586 		 */
2587 		if (!superuser_arg(extensionForm->extowner))
2588 		{
2589 			ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
2590 							errmsg("citus extension needs to be owned by superuser")));
2591 		}
2592 		MetadataCache.extensionOwner = extensionForm->extowner;
2593 		Assert(OidIsValid(MetadataCache.extensionOwner));
2594 	}
2595 	else
2596 	{
2597 		ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
2598 						errmsg("citus extension not loaded")));
2599 	}
2600 
2601 	systable_endscan(scandesc);
2602 
2603 	table_close(relation, AccessShareLock);
2604 
2605 	return MetadataCache.extensionOwner;
2606 }
2607 
2608 
2609 /*
2610  * CitusExtensionOwnerName returns the name of the owner of the extension.
2611  */
2612 char *
2613 CitusExtensionOwnerName(void)
2614 {
2615 	Oid superUserId = CitusExtensionOwner();
2616 
2617 	return GetUserNameFromId(superUserId, false);
2618 }
2619 
2620 
2621 /* return the username of the currently active role */
2622 char *
2623 CurrentUserName(void)
2624 {
2625 	Oid userId = GetUserId();
2626 
2627 	return GetUserNameFromId(userId, false);
2628 }
2629 
2630 
2631 /*
2632  * LookupTypeOid returns the Oid of the "{schemaNameString}.{typeNameString}" type, or
2633  * InvalidOid if it does not exist.
2634  */
2635 Oid
2636 LookupTypeOid(char *schemaNameString, char *typeNameString)
2637 {
2638 	Value *schemaName = makeString(schemaNameString);
2639 	Value *typeName = makeString(typeNameString);
2640 	List *qualifiedName = list_make2(schemaName, typeName);
2641 	TypeName *enumTypeName = makeTypeNameFromNameList(qualifiedName);
2642 
2643 
2644 	/* like typenameTypeId, but returns InvalidOid instead of raising an error */
2645 	Type tup = LookupTypeName(NULL, enumTypeName, NULL, false);
2646 	if (tup == NULL)
2647 	{
2648 		return InvalidOid;
2649 	}
2650 
2651 	Oid typeOid = ((Form_pg_type) GETSTRUCT(tup))->oid;
2652 	ReleaseSysCache(tup);
2653 
2654 	return typeOid;
2655 }
2656 
2657 
2658 /*
2659  * LookupStringEnumValueId returns the Oid of the value in "pg_catalog.{enumName}"
2660  * which matches the provided valueName, or InvalidOid if the enum doesn't exist yet.
2661  */
2662 static Oid
2663 LookupStringEnumValueId(char *enumName, char *valueName)
2664 {
2665 	Oid enumTypeId = LookupTypeOid("pg_catalog", enumName);
2666 
2667 	if (enumTypeId == InvalidOid)
2668 	{
2669 		return InvalidOid;
2670 	}
2671 	else
2672 	{
2673 		Oid valueId = LookupEnumValueId(enumTypeId, valueName);
2674 		return valueId;
2675 	}
2676 }
2677 
2678 
2679 /*
2680  * LookupEnumValueId looks up the OID of an enum value.
2681  */
2682 static Oid
2683 LookupEnumValueId(Oid typeId, char *valueName)
2684 {
2685 	Datum typeIdDatum = ObjectIdGetDatum(typeId);
2686 	Datum valueDatum = CStringGetDatum(valueName);
2687 	Datum valueIdDatum = DirectFunctionCall2(enum_in, valueDatum, typeIdDatum);
2688 	Oid valueId = DatumGetObjectId(valueIdDatum);
2689 
2690 	return valueId;
2691 }
2692 
2693 
2694 /* return the Oid of the 'primary' nodeRole enum value */
2695 Oid
2696 PrimaryNodeRoleId(void)
2697 {
2698 	if (!MetadataCache.primaryNodeRoleId)
2699 	{
2700 		MetadataCache.primaryNodeRoleId = LookupStringEnumValueId("noderole", "primary");
2701 	}
2702 
2703 	return MetadataCache.primaryNodeRoleId;
2704 }
2705 
2706 
2707 /* return the Oid of the 'secondary' nodeRole enum value */
2708 Oid
2709 SecondaryNodeRoleId(void)
2710 {
2711 	if (!MetadataCache.secondaryNodeRoleId)
2712 	{
2713 		MetadataCache.secondaryNodeRoleId = LookupStringEnumValueId("noderole",
2714 																	"secondary");
2715 	}
2716 
2717 	return MetadataCache.secondaryNodeRoleId;
2718 }
2719 
2720 
2721 /*
2722  * citus_dist_partition_cache_invalidate is a trigger function that performs
2723  * relcache invalidations when the contents of pg_dist_partition are changed
2724  * on the SQL level.
2725  *
2726  * NB: We decided there is little point in checking permissions here, there
2727  * are much easier ways to waste CPU than causing cache invalidations.
2728  */
2729 Datum
2730 citus_dist_partition_cache_invalidate(PG_FUNCTION_ARGS)
2731 {
2732 	CheckCitusVersion(ERROR);
2733 
2734 	TriggerData *triggerData = (TriggerData *) fcinfo->context;
2735 	Oid oldLogicalRelationId = InvalidOid;
2736 	Oid newLogicalRelationId = InvalidOid;
2737 
2738 	if (!CALLED_AS_TRIGGER(fcinfo))
2739 	{
2740 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
2741 						errmsg("must be called as trigger")));
2742 	}
2743 
2744 	if (RelationGetRelid(triggerData->tg_relation) != DistPartitionRelationId())
2745 	{
2746 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
2747 						errmsg("triggered on incorrect relation")));
2748 	}
2749 
2750 	HeapTuple newTuple = triggerData->tg_newtuple;
2751 	HeapTuple oldTuple = triggerData->tg_trigtuple;
2752 
2753 	/* collect logicalrelid for OLD and NEW tuple */
2754 	if (oldTuple != NULL)
2755 	{
2756 		Form_pg_dist_partition distPart = (Form_pg_dist_partition) GETSTRUCT(oldTuple);
2757 
2758 		oldLogicalRelationId = distPart->logicalrelid;
2759 	}
2760 
2761 	if (newTuple != NULL)
2762 	{
2763 		Form_pg_dist_partition distPart = (Form_pg_dist_partition) GETSTRUCT(newTuple);
2764 
2765 		newLogicalRelationId = distPart->logicalrelid;
2766 	}
2767 
2768 	/*
2769 	 * Invalidate relcache for the relevant relation(s). In theory
2770 	 * logicalrelid should never change, but it doesn't hurt to be
2771 	 * paranoid.
2772 	 */
2773 	if (oldLogicalRelationId != InvalidOid &&
2774 		oldLogicalRelationId != newLogicalRelationId)
2775 	{
2776 		CitusInvalidateRelcacheByRelid(oldLogicalRelationId);
2777 	}
2778 
2779 	if (newLogicalRelationId != InvalidOid)
2780 	{
2781 		CitusInvalidateRelcacheByRelid(newLogicalRelationId);
2782 	}
2783 
2784 	PG_RETURN_DATUM(PointerGetDatum(NULL));
2785 }
2786 
2787 
2788 /*
2789  * master_dist_partition_cache_invalidate is a wrapper function for old UDF name.
2790  */
2791 Datum
2792 master_dist_partition_cache_invalidate(PG_FUNCTION_ARGS)
2793 {
2794 	return citus_dist_partition_cache_invalidate(fcinfo);
2795 }
2796 
2797 
2798 /*
2799  * citus_dist_shard_cache_invalidate is a trigger function that performs
2800  * relcache invalidations when the contents of pg_dist_shard are changed
2801  * on the SQL level.
2802  *
2803  * NB: We decided there is little point in checking permissions here, there
2804  * are much easier ways to waste CPU than causing cache invalidations.
2805  */
2806 Datum
2807 citus_dist_shard_cache_invalidate(PG_FUNCTION_ARGS)
2808 {
2809 	CheckCitusVersion(ERROR);
2810 
2811 	TriggerData *triggerData = (TriggerData *) fcinfo->context;
2812 	Oid oldLogicalRelationId = InvalidOid;
2813 	Oid newLogicalRelationId = InvalidOid;
2814 
2815 	if (!CALLED_AS_TRIGGER(fcinfo))
2816 	{
2817 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
2818 						errmsg("must be called as trigger")));
2819 	}
2820 
2821 	if (RelationGetRelid(triggerData->tg_relation) != DistShardRelationId())
2822 	{
2823 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
2824 						errmsg("triggered on incorrect relation")));
2825 	}
2826 
2827 	HeapTuple newTuple = triggerData->tg_newtuple;
2828 	HeapTuple oldTuple = triggerData->tg_trigtuple;
2829 
2830 	/* collect logicalrelid for OLD and NEW tuple */
2831 	if (oldTuple != NULL)
2832 	{
2833 		Form_pg_dist_shard distShard = (Form_pg_dist_shard) GETSTRUCT(oldTuple);
2834 
2835 		oldLogicalRelationId = distShard->logicalrelid;
2836 	}
2837 
2838 	if (newTuple != NULL)
2839 	{
2840 		Form_pg_dist_shard distShard = (Form_pg_dist_shard) GETSTRUCT(newTuple);
2841 
2842 		newLogicalRelationId = distShard->logicalrelid;
2843 	}
2844 
2845 	/*
2846 	 * Invalidate relcache for the relevant relation(s). In theory
2847 	 * logicalrelid should never change, but it doesn't hurt to be
2848 	 * paranoid.
2849 	 */
2850 	if (oldLogicalRelationId != InvalidOid &&
2851 		oldLogicalRelationId != newLogicalRelationId)
2852 	{
2853 		CitusInvalidateRelcacheByRelid(oldLogicalRelationId);
2854 	}
2855 
2856 	if (newLogicalRelationId != InvalidOid)
2857 	{
2858 		CitusInvalidateRelcacheByRelid(newLogicalRelationId);
2859 	}
2860 
2861 	PG_RETURN_DATUM(PointerGetDatum(NULL));
2862 }
2863 
2864 
2865 /*
2866  * master_dist_shard_cache_invalidate is a wrapper function for old UDF name.
2867  */
2868 Datum
2869 master_dist_shard_cache_invalidate(PG_FUNCTION_ARGS)
2870 {
2871 	return citus_dist_shard_cache_invalidate(fcinfo);
2872 }
2873 
2874 
2875 /*
2876  * citus_dist_placement_cache_invalidate is a trigger function that performs
2877  * relcache invalidations when the contents of pg_dist_placement are
2878  * changed on the SQL level.
2879  *
2880  * NB: We decided there is little point in checking permissions here, there
2881  * are much easier ways to waste CPU than causing cache invalidations.
2882  */
2883 Datum
2884 citus_dist_placement_cache_invalidate(PG_FUNCTION_ARGS)
2885 {
2886 	CheckCitusVersion(ERROR);
2887 
2888 	TriggerData *triggerData = (TriggerData *) fcinfo->context;
2889 	Oid oldShardId = InvalidOid;
2890 	Oid newShardId = InvalidOid;
2891 
2892 	if (!CALLED_AS_TRIGGER(fcinfo))
2893 	{
2894 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
2895 						errmsg("must be called as trigger")));
2896 	}
2897 
2898 	/*
2899 	 * Before 7.0-2 this trigger was defined on pg_dist_shard_placement;
2900 	 * ignore the trigger in that scenario.
2901 	 */
2902 	Oid pgDistShardPlacementId = get_relname_relid("pg_dist_shard_placement",
2903 												   PG_CATALOG_NAMESPACE);
2904 	if (RelationGetRelid(triggerData->tg_relation) == pgDistShardPlacementId)
2905 	{
2906 		PG_RETURN_DATUM(PointerGetDatum(NULL));
2907 	}
2908 
2909 	if (RelationGetRelid(triggerData->tg_relation) != DistPlacementRelationId())
2910 	{
2911 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
2912 						errmsg("triggered on incorrect relation")));
2913 	}
2914 
2915 	HeapTuple newTuple = triggerData->tg_newtuple;
2916 	HeapTuple oldTuple = triggerData->tg_trigtuple;
2917 
2918 	/* collect shardid for OLD and NEW tuple */
2919 	if (oldTuple != NULL)
2920 	{
2921 		Form_pg_dist_placement distPlacement =
2922 			(Form_pg_dist_placement) GETSTRUCT(oldTuple);
2923 
2924 		oldShardId = distPlacement->shardid;
2925 	}
2926 
2927 	if (newTuple != NULL)
2928 	{
2929 		Form_pg_dist_placement distPlacement =
2930 			(Form_pg_dist_placement) GETSTRUCT(newTuple);
2931 
2932 		newShardId = distPlacement->shardid;
2933 	}
2934 
2935 	/*
2936 	 * Invalidate relcache for the relevant relation(s). In theory shardId
2937 	 * should never change, but it doesn't hurt to be paranoid.
2938 	 */
2939 	if (oldShardId != InvalidOid &&
2940 		oldShardId != newShardId)
2941 	{
2942 		CitusInvalidateRelcacheByShardId(oldShardId);
2943 	}
2944 
2945 	if (newShardId != InvalidOid)
2946 	{
2947 		CitusInvalidateRelcacheByShardId(newShardId);
2948 	}
2949 
2950 	PG_RETURN_DATUM(PointerGetDatum(NULL));
2951 }
2952 
2953 
2954 /*
2955  * master_dist_placement_cache_invalidate is a wrapper function for old UDF name.
2956  */
2957 Datum
2958 master_dist_placement_cache_invalidate(PG_FUNCTION_ARGS)
2959 {
2960 	return citus_dist_placement_cache_invalidate(fcinfo);
2961 }
2962 
2963 
2964 /*
2965  * citus_dist_node_cache_invalidate is a trigger function that performs
2966  * relcache invalidations when the contents of pg_dist_node are changed
2967  * on the SQL level.
2968  *
2969  * NB: We decided there is little point in checking permissions here, there
2970  * are much easier ways to waste CPU than causing cache invalidations.
2971  */
2972 Datum
2973 citus_dist_node_cache_invalidate(PG_FUNCTION_ARGS)
2974 {
2975 	CheckCitusVersion(ERROR);
2976 
2977 	if (!CALLED_AS_TRIGGER(fcinfo))
2978 	{
2979 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
2980 						errmsg("must be called as trigger")));
2981 	}
2982 
2983 	CitusInvalidateRelcacheByRelid(DistNodeRelationId());
2984 
2985 	PG_RETURN_DATUM(PointerGetDatum(NULL));
2986 }
2987 
2988 
2989 /*
2990  * master_dist_node_cache_invalidate is a wrapper function for old UDF name.
2991  */
2992 Datum
2993 master_dist_node_cache_invalidate(PG_FUNCTION_ARGS)
2994 {
2995 	return citus_dist_node_cache_invalidate(fcinfo);
2996 }
2997 
2998 
2999 /*
3000  * citus_conninfo_cache_invalidate is a trigger function that performs
3001  * relcache invalidations when the contents of pg_dist_authinfo are changed
3002  * on the SQL level.
3003  *
3004  * NB: We decided there is little point in checking permissions here, there
3005  * are much easier ways to waste CPU than causing cache invalidations.
3006  */
3007 Datum
3008 citus_conninfo_cache_invalidate(PG_FUNCTION_ARGS)
3009 {
3010 	CheckCitusVersion(ERROR);
3011 
3012 	if (!CALLED_AS_TRIGGER(fcinfo))
3013 	{
3014 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
3015 						errmsg("must be called as trigger")));
3016 	}
3017 
3018 	/* no-op in community edition */
3019 
3020 	PG_RETURN_DATUM(PointerGetDatum(NULL));
3021 }
3022 
3023 
3024 /*
3025  * master_dist_authinfo_cache_invalidate is a wrapper function for old UDF name.
3026  */
3027 Datum
3028 master_dist_authinfo_cache_invalidate(PG_FUNCTION_ARGS)
3029 {
3030 	return citus_conninfo_cache_invalidate(fcinfo);
3031 }
3032 
3033 
3034 /*
3035  * citus_dist_local_group_cache_invalidate is a trigger function that performs
3036  * relcache invalidations when the contents of pg_dist_local_group are changed
3037  * on the SQL level.
3038  *
3039  * NB: We decided there is little point in checking permissions here, there
3040  * are much easier ways to waste CPU than causing cache invalidations.
3041  */
3042 Datum
3043 citus_dist_local_group_cache_invalidate(PG_FUNCTION_ARGS)
3044 {
3045 	CheckCitusVersion(ERROR);
3046 
3047 	if (!CALLED_AS_TRIGGER(fcinfo))
3048 	{
3049 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
3050 						errmsg("must be called as trigger")));
3051 	}
3052 
3053 	CitusInvalidateRelcacheByRelid(DistLocalGroupIdRelationId());
3054 
3055 	PG_RETURN_DATUM(PointerGetDatum(NULL));
3056 }
3057 
3058 
3059 /*
3060  * master_dist_local_group_cache_invalidate is a wrapper function for old UDF name.
3061  */
3062 Datum
3063 master_dist_local_group_cache_invalidate(PG_FUNCTION_ARGS)
3064 {
3065 	return citus_dist_local_group_cache_invalidate(fcinfo);
3066 }
3067 
3068 
3069 /*
3070  * citus_dist_object_cache_invalidate is a trigger function that performs relcache
3071  * invalidation when the contents of pg_dist_object are changed on the SQL
3072  * level.
3073  *
3074  * NB: We decided there is little point in checking permissions here, there
3075  * are much easier ways to waste CPU than causing cache invalidations.
3076  */
3077 Datum
3078 citus_dist_object_cache_invalidate(PG_FUNCTION_ARGS)
3079 {
3080 	CheckCitusVersion(ERROR);
3081 
3082 	if (!CALLED_AS_TRIGGER(fcinfo))
3083 	{
3084 		ereport(ERROR, (errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
3085 						errmsg("must be called as trigger")));
3086 	}
3087 
3088 	CitusInvalidateRelcacheByRelid(DistObjectRelationId());
3089 
3090 	PG_RETURN_DATUM(PointerGetDatum(NULL));
3091 }
3092 
3093 
3094 /*
3095  * master_dist_object_cache_invalidate is a wrapper function for old UDF name.
3096  */
3097 Datum
3098 master_dist_object_cache_invalidate(PG_FUNCTION_ARGS)
3099 {
3100 	return citus_dist_object_cache_invalidate(fcinfo);
3101 }
3102 
3103 
3104 /*
3105  * InitializeCaches() registers invalidation handlers for metadata_cache.c's
3106  * caches.
3107  */
3108 static void
3109 InitializeCaches(void)
3110 {
3111 	static bool performedInitialization = false;
3112 
3113 	if (!performedInitialization)
3114 	{
3115 		MetadataCacheMemoryContext = NULL;
3116 
3117 		/*
3118 		 * If the allocation or initialization of either the dist table
3119 		 * cache or the shard cache fails due to an exception
3120 		 * (caused by OOM or any other reason),
3121 		 * we reset the flag and delete the metadata cache memory
3122 		 * context to reclaim partially allocated memory.
3123 		 *
3124 		 * The command will still fail since we re-throw the exception.
3125 		 */
3126 		PG_TRY();
3127 		{
3128 			/* set first, to avoid recursion dangers */
3129 			performedInitialization = true;
3130 
3131 			/* make sure we've initialized CacheMemoryContext */
3132 			if (CacheMemoryContext == NULL)
3133 			{
3134 				CreateCacheMemoryContext();
3135 			}
3136 
3137 			MetadataCacheMemoryContext = AllocSetContextCreate(
3138 				CacheMemoryContext,
3139 				"MetadataCacheMemoryContext",
3140 				ALLOCSET_DEFAULT_SIZES);
3141 
3142 			InitializeDistCache();
3143 			RegisterForeignKeyGraphCacheCallbacks();
3144 			RegisterWorkerNodeCacheCallbacks();
3145 			RegisterLocalGroupIdCacheCallbacks();
3146 			RegisterCitusTableCacheEntryReleaseCallbacks();
3147 		}
3148 		PG_CATCH();
3149 		{
3150 			performedInitialization = false;
3151 
3152 			if (MetadataCacheMemoryContext != NULL)
3153 			{
3154 				MemoryContextDelete(MetadataCacheMemoryContext);
3155 			}
3156 
3157 			MetadataCacheMemoryContext = NULL;
3158 			DistTableCacheHash = NULL;
3159 			DistTableCacheExpired = NIL;
3160 			ShardIdCacheHash = NULL;
3161 
3162 			PG_RE_THROW();
3163 		}
3164 		PG_END_TRY();
3165 	}
3166 }
3167 
3168 
3169 /* initialize the infrastructure for the metadata cache */
3170 static void
3171 InitializeDistCache(void)
3172 {
3173 	/* build initial scan keys, copied for every relation scan */
3174 	memset(&DistPartitionScanKey, 0, sizeof(DistPartitionScanKey));
3175 
3176 	fmgr_info_cxt(F_OIDEQ,
3177 				  &DistPartitionScanKey[0].sk_func,
3178 				  MetadataCacheMemoryContext);
3179 	DistPartitionScanKey[0].sk_strategy = BTEqualStrategyNumber;
3180 	DistPartitionScanKey[0].sk_subtype = InvalidOid;
3181 	DistPartitionScanKey[0].sk_collation = InvalidOid;
3182 	DistPartitionScanKey[0].sk_attno = Anum_pg_dist_partition_logicalrelid;
3183 
3184 	memset(&DistShardScanKey, 0, sizeof(DistShardScanKey));
3185 
3186 	fmgr_info_cxt(F_OIDEQ,
3187 				  &DistShardScanKey[0].sk_func,
3188 				  MetadataCacheMemoryContext);
3189 	DistShardScanKey[0].sk_strategy = BTEqualStrategyNumber;
3190 	DistShardScanKey[0].sk_subtype = InvalidOid;
3191 	DistShardScanKey[0].sk_collation = InvalidOid;
3192 	DistShardScanKey[0].sk_attno = Anum_pg_dist_shard_logicalrelid;
3193 
3194 	CreateDistTableCache();
3195 	CreateShardIdCache();
3196 
3197 	InitializeDistObjectCache();
3198 
3199 	/* Watch for invalidation events. */
3200 	CacheRegisterRelcacheCallback(InvalidateDistRelationCacheCallback,
3201 								  (Datum) 0);
3202 }
3203 
3204 
3205 static void
3206 InitializeDistObjectCache(void)
3207 {
3208 	/* build initial scan keys, copied for every relation scan */
3209 	memset(&DistObjectScanKey, 0, sizeof(DistObjectScanKey));
3210 
3211 	fmgr_info_cxt(F_OIDEQ,
3212 				  &DistObjectScanKey[0].sk_func,
3213 				  MetadataCacheMemoryContext);
3214 	DistObjectScanKey[0].sk_strategy = BTEqualStrategyNumber;
3215 	DistObjectScanKey[0].sk_subtype = InvalidOid;
3216 	DistObjectScanKey[0].sk_collation = InvalidOid;
3217 	DistObjectScanKey[0].sk_attno = Anum_pg_dist_object_classid;
3218 
3219 	fmgr_info_cxt(F_OIDEQ,
3220 				  &DistObjectScanKey[1].sk_func,
3221 				  MetadataCacheMemoryContext);
3222 	DistObjectScanKey[1].sk_strategy = BTEqualStrategyNumber;
3223 	DistObjectScanKey[1].sk_subtype = InvalidOid;
3224 	DistObjectScanKey[1].sk_collation = InvalidOid;
3225 	DistObjectScanKey[1].sk_attno = Anum_pg_dist_object_objid;
3226 
3227 	fmgr_info_cxt(F_INT4EQ,
3228 				  &DistObjectScanKey[2].sk_func,
3229 				  MetadataCacheMemoryContext);
3230 	DistObjectScanKey[2].sk_strategy = BTEqualStrategyNumber;
3231 	DistObjectScanKey[2].sk_subtype = InvalidOid;
3232 	DistObjectScanKey[2].sk_collation = InvalidOid;
3233 	DistObjectScanKey[2].sk_attno = Anum_pg_dist_object_objsubid;
3234 
3235 	CreateDistObjectCache();
3236 }
3237 
3238 
3239 /*
3240  * GetWorkerNodeHash returns the worker node data as a hash with the nodename and
3241  * nodeport as a key.
3242  *
3243  * The hash is returned from the cache; if the cache is not (yet) valid, it is
3244  * first rebuilt.
3245  */
3246 HTAB *
3247 GetWorkerNodeHash(void)
3248 {
3249 	PrepareWorkerNodeCache();
3250 
3251 	return WorkerNodeHash;
3252 }
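

/*
 * Editorial sketch (not compiled): callers typically iterate the returned
 * hash with the standard dynahash sequential scan API. The helper name is
 * hypothetical.
 */
#if 0
static void
ExampleIterateWorkerNodes(void)
{
	HTAB *workerNodeHash = GetWorkerNodeHash();
	HASH_SEQ_STATUS status;
	WorkerNode *workerNode = NULL;

	hash_seq_init(&status, workerNodeHash);
	while ((workerNode = hash_seq_search(&status)) != NULL)
	{
		elog(DEBUG4, "worker %s:%u in group %d",
			 workerNode->workerName, workerNode->workerPort,
			 workerNode->groupId);
	}
}
#endif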
3253 
3254 
3255 /*
3256  * PrepareWorkerNodeCache makes sure the worker node data from pg_dist_node is cached,
3257  * if it is not already cached.
3258  */
3259 static void
3260 PrepareWorkerNodeCache(void)
3261 {
3262 	InitializeCaches(); /* ensure relevant callbacks are registered */
3263 
3264 	/*
3265 	 * Simulate a SELECT from pg_dist_node to ensure pg_dist_node doesn't change
3266 	 * while our caller is using WorkerNodeHash.
3267 	 */
3268 	LockRelationOid(DistNodeRelationId(), AccessShareLock);
3269 
3270 	/*
3271 	 * We might have some concurrent metadata changes. In order to get the changes,
3272 	 * we first need to accept the cache invalidation messages.
3273 	 */
3274 	AcceptInvalidationMessages();
3275 
3276 	if (!workerNodeHashValid)
3277 	{
3278 		InitializeWorkerNodeCache();
3279 
3280 		workerNodeHashValid = true;
3281 	}
3282 }
3283 
3284 
3285 /*
3286  * InitializeWorkerNodeCache initializes the infrastructure for the worker node cache.
3287  * The function reads the worker nodes from the metadata table, adds them to the hash and
3288  * finally registers an invalidation callback.
3289  */
3290 static void
3291 InitializeWorkerNodeCache(void)
3292 {
3293 	HASHCTL info;
3294 	long maxTableSize = (long) MaxWorkerNodesTracked;
3295 	bool includeNodesFromOtherClusters = false;
3296 	int workerNodeIndex = 0;
3297 
3298 	InitializeCaches();
3299 
3300 	/*
3301 	 * Create the hash that holds the worker nodes. The key is the combination of
3302 	 * nodename and nodeport, instead of the unique nodeid because worker nodes are
3303 	 * searched by the nodename and nodeport in every physical plan creation.
3304 	 */
3305 	memset(&info, 0, sizeof(info));
3306 	info.keysize = sizeof(uint32) + WORKER_LENGTH + sizeof(uint32);
3307 	info.entrysize = sizeof(WorkerNode);
3308 	info.hcxt = MetadataCacheMemoryContext;
3309 	info.hash = WorkerNodeHashCode;
3310 	info.match = WorkerNodeCompare;
3311 	int hashFlags = HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT | HASH_COMPARE;
3312 
3313 	HTAB *newWorkerNodeHash = hash_create("Worker Node Hash", maxTableSize, &info,
3314 										  hashFlags);
3315 
3316 	/* read the list from pg_dist_node */
3317 	List *workerNodeList = ReadDistNode(includeNodesFromOtherClusters);
3318 
3319 	int newWorkerNodeCount = list_length(workerNodeList);
3320 	WorkerNode **newWorkerNodeArray = MemoryContextAlloc(MetadataCacheMemoryContext,
3321 														 sizeof(WorkerNode *) *
3322 														 newWorkerNodeCount);
3323 
3324 	/* iterate over the worker node list */
3325 	WorkerNode *currentNode = NULL;
3326 	foreach_ptr(currentNode, workerNodeList)
3327 	{
3328 		bool handleFound = false;
3329 
3330 		/* search for the worker node in the hash, and then insert the values */
3331 		void *hashKey = (void *) currentNode;
3332 		WorkerNode *workerNode = (WorkerNode *) hash_search(newWorkerNodeHash, hashKey,
3333 															HASH_ENTER, &handleFound);
3334 
3335 		/* fill the newly allocated workerNode in the cache */
3336 		strlcpy(workerNode->workerName, currentNode->workerName, WORKER_LENGTH);
3337 		workerNode->workerPort = currentNode->workerPort;
3338 		workerNode->groupId = currentNode->groupId;
3339 		workerNode->nodeId = currentNode->nodeId;
3340 		strlcpy(workerNode->workerRack, currentNode->workerRack, WORKER_LENGTH);
3341 		workerNode->hasMetadata = currentNode->hasMetadata;
3342 		workerNode->metadataSynced = currentNode->metadataSynced;
3343 		workerNode->isActive = currentNode->isActive;
3344 		workerNode->nodeRole = currentNode->nodeRole;
3345 		workerNode->shouldHaveShards = currentNode->shouldHaveShards;
3346 		strlcpy(workerNode->nodeCluster, currentNode->nodeCluster, NAMEDATALEN);
3347 
3348 		newWorkerNodeArray[workerNodeIndex++] = workerNode;
3349 
3350 		if (handleFound)
3351 		{
3352 			ereport(WARNING, (errmsg("multiple lines for worker node: \"%s:%u\"",
3353 									 workerNode->workerName,
3354 									 workerNode->workerPort)));
3355 		}
3356 
3357 		/* we do not need the currentNode anymore */
3358 		pfree(currentNode);
3359 	}
3360 
3361 	/* now, safe to destroy the old hash */
3362 	hash_destroy(WorkerNodeHash);
3363 
3364 	if (WorkerNodeArray != NULL)
3365 	{
3366 		pfree(WorkerNodeArray);
3367 	}
3368 
3369 	WorkerNodeCount = newWorkerNodeCount;
3370 	WorkerNodeArray = newWorkerNodeArray;
3371 	WorkerNodeHash = newWorkerNodeHash;
3372 }


/*
 * RegisterForeignKeyGraphCacheCallbacks registers callbacks required for
 * the foreign key graph cache.
 */
static void
RegisterForeignKeyGraphCacheCallbacks(void)
{
	/* Watch for invalidation events. */
	CacheRegisterRelcacheCallback(InvalidateForeignRelationGraphCacheCallback,
								  (Datum) 0);
}


/*
 * RegisterWorkerNodeCacheCallbacks registers the callbacks required for the
 * worker node cache.  It's separate from InitializeWorkerNodeCache so the
 * callback can be registered early, before the metadata tables exist.
 */
static void
RegisterWorkerNodeCacheCallbacks(void)
{
	/* Watch for invalidation events. */
	CacheRegisterRelcacheCallback(InvalidateNodeRelationCacheCallback,
								  (Datum) 0);
}


/*
 * RegisterCitusTableCacheEntryReleaseCallbacks registers callbacks to release
 * cache entries. Data should be locked by callers to avoid staleness.
 */
static void
RegisterCitusTableCacheEntryReleaseCallbacks(void)
{
	RegisterResourceReleaseCallback(CitusTableCacheEntryReleaseCallback, NULL);
}


/*
 * GetLocalGroupId returns the group identifier of the local node. The function
 * assumes that pg_dist_local_group has exactly one row, whose single column
 * holds the group ID. If the table is temporarily empty (as happens during
 * upgrades), the function returns GROUP_ID_UPGRADING instead.
 */
int32
GetLocalGroupId(void)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 0;
	int32 groupId = 0;

	InitializeCaches();

	/*
	 * If the group ID is already set, there is no need to read the heap again.
	 */
	if (LocalGroupId != -1)
	{
		return LocalGroupId;
	}

	Oid localGroupTableOid = DistLocalGroupIdRelationId();
	if (localGroupTableOid == InvalidOid)
	{
		return 0;
	}

	Relation pgDistLocalGroupId = table_open(localGroupTableOid, AccessShareLock);

	SysScanDesc scanDescriptor = systable_beginscan(pgDistLocalGroupId,
													InvalidOid, false,
													NULL, scanKeyCount, scanKey);

	TupleDesc tupleDescriptor = RelationGetDescr(pgDistLocalGroupId);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);

	if (HeapTupleIsValid(heapTuple))
	{
		bool isNull = false;
		Datum groupIdDatum = heap_getattr(heapTuple,
										  Anum_pg_dist_local_groupid,
										  tupleDescriptor, &isNull);

		groupId = DatumGetInt32(groupIdDatum);

		/* set the local cache variable */
		LocalGroupId = groupId;
	}
	else
	{
		/*
		 * An upgrade is in progress. When upgrading postgres, pg_dist_local_group
		 * is temporarily empty before citus_finish_pg_upgrade() finishes execution.
		 */
		groupId = GROUP_ID_UPGRADING;
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistLocalGroupId, AccessShareLock);

	return groupId;
}
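

/*
 * For illustration, a common use of GetLocalGroupId is telling whether the
 * local node is the coordinator, whose group ID is 0 (COORDINATOR_GROUP_ID).
 * The helper below is a hypothetical sketch, not part of this file's API.
 *
 *    static bool
 *    LocalNodeIsCoordinator(void)
 *    {
 *        return GetLocalGroupId() == COORDINATOR_GROUP_ID;
 *    }
 */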


/*
 * RegisterLocalGroupIdCacheCallbacks registers the callbacks required to
 * maintain LocalGroupId at a consistent value. It's separate from
 * GetLocalGroupId so the callback can be registered early, before metadata
 * tables exist.
 */
static void
RegisterLocalGroupIdCacheCallbacks(void)
{
	/* Watch for invalidation events. */
	CacheRegisterRelcacheCallback(InvalidateLocalGroupIdRelationCacheCallback,
								  (Datum) 0);
}


/*
 * WorkerNodeHashCode computes the hash code for a worker node from the node's
 * host name and port number. Nodes that only differ by their rack locations
 * hash to the same value.
 */
static uint32
WorkerNodeHashCode(const void *key, Size keySize)
{
	const WorkerNode *worker = (const WorkerNode *) key;
	const char *workerName = worker->workerName;
	const uint32 *workerPort = &(worker->workerPort);

	/* standard hash function outlined in Effective Java, Item 8 */
	uint32 result = 17;
	result = 37 * result + string_hash(workerName, WORKER_LENGTH);
	result = 37 * result + tag_hash(workerPort, sizeof(uint32));
	return result;
}


/*
 * ResetCitusTableCacheEntry frees any out-of-band memory used by a cache
 * entry. If the entry holds shard interval data, that data and finally the
 * entry itself are freed as well.
 */
static void
ResetCitusTableCacheEntry(CitusTableCacheEntry *cacheEntry)
{
	if (cacheEntry->partitionKeyString != NULL)
	{
		pfree(cacheEntry->partitionKeyString);
		cacheEntry->partitionKeyString = NULL;
	}

	if (cacheEntry->shardIntervalCompareFunction != NULL)
	{
		pfree(cacheEntry->shardIntervalCompareFunction);
		cacheEntry->shardIntervalCompareFunction = NULL;
	}

	if (cacheEntry->hashFunction)
	{
		pfree(cacheEntry->hashFunction);
		cacheEntry->hashFunction = NULL;
	}

	if (cacheEntry->partitionColumn != NULL)
	{
		pfree(cacheEntry->partitionColumn);
		cacheEntry->partitionColumn = NULL;
	}

	if (cacheEntry->shardIntervalArrayLength == 0)
	{
		return;
	}

	/* clean up ShardIdCacheHash */
	RemoveStaleShardIdCacheEntries(cacheEntry);

	for (int shardIndex = 0; shardIndex < cacheEntry->shardIntervalArrayLength;
		 shardIndex++)
	{
		ShardInterval *shardInterval = cacheEntry->sortedShardIntervalArray[shardIndex];
		GroupShardPlacement *placementArray =
			cacheEntry->arrayOfPlacementArrays[shardIndex];
		bool valueByVal = shardInterval->valueByVal;

		/* delete the shard's placements */
		if (placementArray != NULL)
		{
			pfree(placementArray);
		}

		/* delete data pointed to by ShardInterval */
		if (!valueByVal)
		{
			if (shardInterval->minValueExists)
			{
				pfree(DatumGetPointer(shardInterval->minValue));
			}

			if (shardInterval->maxValueExists)
			{
				pfree(DatumGetPointer(shardInterval->maxValue));
			}
		}

		/* and finally the ShardInterval itself */
		pfree(shardInterval);
	}

	if (cacheEntry->sortedShardIntervalArray)
	{
		pfree(cacheEntry->sortedShardIntervalArray);
		cacheEntry->sortedShardIntervalArray = NULL;
	}
	if (cacheEntry->arrayOfPlacementArrayLengths)
	{
		pfree(cacheEntry->arrayOfPlacementArrayLengths);
		cacheEntry->arrayOfPlacementArrayLengths = NULL;
	}
	if (cacheEntry->arrayOfPlacementArrays)
	{
		pfree(cacheEntry->arrayOfPlacementArrays);
		cacheEntry->arrayOfPlacementArrays = NULL;
	}
	if (cacheEntry->referencedRelationsViaForeignKey)
	{
		list_free(cacheEntry->referencedRelationsViaForeignKey);
		cacheEntry->referencedRelationsViaForeignKey = NIL;
	}
	if (cacheEntry->referencingRelationsViaForeignKey)
	{
		list_free(cacheEntry->referencingRelationsViaForeignKey);
		cacheEntry->referencingRelationsViaForeignKey = NIL;
	}

	cacheEntry->shardIntervalArrayLength = 0;
	cacheEntry->hasUninitializedShardInterval = false;
	cacheEntry->hasUniformHashDistribution = false;
	cacheEntry->hasOverlappingShardInterval = false;

	pfree(cacheEntry);
}


/*
 * RemoveStaleShardIdCacheEntries removes all shard ID cache entries belonging
 * to the given table entry. If a shard ID belongs to a different (newer) table
 * entry, we leave it in place.
 */
static void
RemoveStaleShardIdCacheEntries(CitusTableCacheEntry *invalidatedTableEntry)
{
	int shardIndex = 0;
	int shardCount = invalidatedTableEntry->shardIntervalArrayLength;

	for (shardIndex = 0; shardIndex < shardCount; shardIndex++)
	{
		ShardInterval *shardInterval =
			invalidatedTableEntry->sortedShardIntervalArray[shardIndex];
		int64 shardId = shardInterval->shardId;
		bool foundInCache = false;

		ShardIdCacheEntry *shardIdCacheEntry =
			hash_search(ShardIdCacheHash, &shardId, HASH_FIND, &foundInCache);

		if (foundInCache && shardIdCacheEntry->tableEntry == invalidatedTableEntry)
		{
			hash_search(ShardIdCacheHash, &shardId, HASH_REMOVE, &foundInCache);
		}
	}
}


/*
 * InvalidateForeignRelationGraphCacheCallback invalidates the foreign key
 * relationship graph and all distributed table cache entries.
 */
static void
InvalidateForeignRelationGraphCacheCallback(Datum argument, Oid relationId)
{
	if (relationId == MetadataCache.distColocationRelationId)
	{
		SetForeignConstraintRelationshipGraphInvalid();
		InvalidateDistTableCache();
	}
}


/*
 * InvalidateForeignKeyGraph is used to invalidate the cached foreign key
 * graph (see ForeignKeyRelationGraph @ utils/foreign_key_relationship.c).
 *
 * To invalidate the foreign key graph, we hack around relcache invalidation
 * callbacks. Given that there is no metadata table associated with the foreign
 * key graph cache, we use pg_dist_colocation, which is never invalidated for
 * other purposes.
 *
 * We acknowledge that it is not a very intuitive way of implementing this cache
 * invalidation, but it seems acceptable for now. If this becomes problematic, we
 * could try using a magic oid where we're sure that no relation would ever use
 * that oid.
 */
void
InvalidateForeignKeyGraph(void)
{
	if (!CitusHasBeenLoaded())
	{
		/*
		 * We should not try to invalidate the foreign key graph
		 * if citus is not loaded.
		 */
		return;
	}

	CitusInvalidateRelcacheByRelid(DistColocationRelationId());

	/* bump command counter to force invalidation to take effect */
	CommandCounterIncrement();
}
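

/*
 * For illustration, command processing that changes foreign key relationships
 * between distributed tables would call InvalidateForeignKeyGraph after
 * updating the catalogs, so that later planning rebuilds the graph. The
 * function below is a hypothetical sketch of such a caller.
 *
 *    static void
 *    PostprocessAddForeignKeyCommand(void)
 *    {
 *        ... record the new constraint in the metadata ...
 *        InvalidateForeignKeyGraph();
 *    }
 */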


/*
 * InvalidateDistRelationCacheCallback invalidates cache entries when a
 * relation is updated, or invalidates the entire cache when relationId is
 * InvalidOid.
 */
static void
InvalidateDistRelationCacheCallback(Datum argument, Oid relationId)
{
	/* invalidate either entire cache or a specific entry */
	if (relationId == InvalidOid)
	{
		InvalidateDistTableCache();
		InvalidateDistObjectCache();
	}
	else
	{
		void *hashKey = (void *) &relationId;
		bool foundInCache = false;

		CitusTableCacheEntrySlot *cacheSlot =
			hash_search(DistTableCacheHash, hashKey, HASH_FIND, &foundInCache);
		if (foundInCache)
		{
			InvalidateCitusTableCacheEntrySlot(cacheSlot);
		}

		/*
		 * If pg_dist_partition is being invalidated, drop all state. This
		 * happens rarely, but most importantly during DROP EXTENSION citus;
		 */
		if (relationId == MetadataCache.distPartitionRelationId)
		{
			InvalidateMetadataSystemCache();
		}

		if (relationId == MetadataCache.distObjectRelationId)
		{
			InvalidateDistObjectCache();
		}
	}
}


/*
 * InvalidateCitusTableCacheEntrySlot marks a CitusTableCacheEntrySlot as invalid,
 * meaning it needs to be rebuilt and the citusTableMetadata (if any) should be
 * released.
 */
static void
InvalidateCitusTableCacheEntrySlot(CitusTableCacheEntrySlot *cacheSlot)
{
	/* recheck whether this is a distributed table */
	cacheSlot->isValid = false;

	if (cacheSlot->citusTableMetadata != NULL)
	{
		/* reload the metadata */
		cacheSlot->citusTableMetadata->isValid = false;
	}
}


/*
 * InvalidateDistTableCache marks all DistTableCacheHash entries invalid.
 */
static void
InvalidateDistTableCache(void)
{
	CitusTableCacheEntrySlot *cacheSlot = NULL;
	HASH_SEQ_STATUS status;

	hash_seq_init(&status, DistTableCacheHash);

	while ((cacheSlot = (CitusTableCacheEntrySlot *) hash_seq_search(&status)) != NULL)
	{
		InvalidateCitusTableCacheEntrySlot(cacheSlot);
	}
}


/*
 * InvalidateDistObjectCache marks all DistObjectCacheHash entries invalid.
 */
static void
InvalidateDistObjectCache(void)
{
	DistObjectCacheEntry *cacheEntry = NULL;
	HASH_SEQ_STATUS status;

	hash_seq_init(&status, DistObjectCacheHash);

	while ((cacheEntry = (DistObjectCacheEntry *) hash_seq_search(&status)) != NULL)
	{
		cacheEntry->isValid = false;
	}
}


/*
 * FlushDistTableCache flushes the entire distributed relation cache, frees
 * all entries, and recreates the cache.
 */
void
FlushDistTableCache(void)
{
	CitusTableCacheEntrySlot *cacheSlot = NULL;
	HASH_SEQ_STATUS status;

	hash_seq_init(&status, DistTableCacheHash);

	while ((cacheSlot = (CitusTableCacheEntrySlot *) hash_seq_search(&status)) != NULL)
	{
		ResetCitusTableCacheEntry(cacheSlot->citusTableMetadata);
	}

	hash_destroy(DistTableCacheHash);
	hash_destroy(ShardIdCacheHash);
	CreateDistTableCache();
	CreateShardIdCache();
}


/* CreateDistTableCache initializes the per-table hash table */
static void
CreateDistTableCache(void)
{
	HASHCTL info;
	MemSet(&info, 0, sizeof(info));
	info.keysize = sizeof(Oid);
	info.entrysize = sizeof(CitusTableCacheEntrySlot);
	info.hash = tag_hash;
	info.hcxt = MetadataCacheMemoryContext;
	DistTableCacheHash =
		hash_create("Distributed Relation Cache", 32, &info,
					HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT);
}
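

/*
 * For illustration, a lookup against DistTableCacheHash keys on the relation
 * OID and can create the slot on demand; this is a hypothetical sketch, and
 * relationId is an Oid supplied by the caller.
 *
 *    bool foundInCache = false;
 *    CitusTableCacheEntrySlot *cacheSlot =
 *        hash_search(DistTableCacheHash, &relationId, HASH_ENTER,
 *                    &foundInCache);
 *    if (!foundInCache)
 *    {
 *        cacheSlot->citusTableMetadata = NULL;
 *        cacheSlot->isValid = false;
 *    }
 */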


/* CreateShardIdCache initializes the shard ID mapping */
static void
CreateShardIdCache(void)
{
	HASHCTL info;
	MemSet(&info, 0, sizeof(info));
	info.keysize = sizeof(int64);
	info.entrysize = sizeof(ShardIdCacheEntry);
	info.hash = tag_hash;
	info.hcxt = MetadataCacheMemoryContext;
	ShardIdCacheHash =
		hash_create("Shard Id Cache", 128, &info,
					HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT);
}


/* CreateDistObjectCache initializes the per-object hash table */
static void
CreateDistObjectCache(void)
{
	HASHCTL info;
	MemSet(&info, 0, sizeof(info));
	info.keysize = sizeof(DistObjectCacheEntryKey);
	info.entrysize = sizeof(DistObjectCacheEntry);
	info.hash = tag_hash;
	info.hcxt = MetadataCacheMemoryContext;
	DistObjectCacheHash =
		hash_create("Distributed Object Cache", 32, &info,
					HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT);
}


/*
 * InvalidateMetadataSystemCache resets all the cached OIDs and the extensionLoaded flag,
 * and invalidates the worker node, ConnParams, and local group ID caches.
 */
void
InvalidateMetadataSystemCache(void)
{
	InvalidateConnParamsHashEntries();

	memset(&MetadataCache, 0, sizeof(MetadataCache));
	workerNodeHashValid = false;
	LocalGroupId = -1;
}


/*
 * AllCitusTableIds returns the relation IDs of all Citus tables.
 */
List *
AllCitusTableIds(void)
{
	return CitusTableTypeIdList(ANY_CITUS_TABLE_TYPE);
}


/*
 * CitusTableTypeIdList scans pg_dist_partition and returns a list of OIDs of
 * the tables matching the given citusTableType. It performs a sequential
 * scan; since the function is not expected to be called frequently, skipping
 * the index scan is acceptable. If this ever becomes a performance
 * bottleneck, the function can be modified to use an index scan.
 */
List *
CitusTableTypeIdList(CitusTableType citusTableType)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 0;
	List *relationIdList = NIL;

	Relation pgDistPartition = table_open(DistPartitionRelationId(), AccessShareLock);

	SysScanDesc scanDescriptor = systable_beginscan(pgDistPartition,
													InvalidOid, false,
													NULL, scanKeyCount, scanKey);

	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPartition);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	while (HeapTupleIsValid(heapTuple))
	{
		bool isNull = false;

		Datum partMethodDatum =
			heap_getattr(heapTuple, Anum_pg_dist_partition_partmethod,
						 tupleDescriptor, &isNull);
		Datum replicationModelDatum =
			heap_getattr(heapTuple, Anum_pg_dist_partition_repmodel,
						 tupleDescriptor, &isNull);

		char partitionMethod = DatumGetChar(partMethodDatum);
		char replicationModel = DatumGetChar(replicationModelDatum);

		if (IsCitusTableTypeInternal(partitionMethod, replicationModel, citusTableType))
		{
			Datum relationIdDatum = heap_getattr(heapTuple,
												 Anum_pg_dist_partition_logicalrelid,
												 tupleDescriptor, &isNull);

			Oid relationId = DatumGetObjectId(relationIdDatum);

			relationIdList = lappend_oid(relationIdList, relationId);
		}

		heapTuple = systable_getnext(scanDescriptor);
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistPartition, AccessShareLock);

	return relationIdList;
}
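

/*
 * For illustration, iterating over all reference tables in the cluster might
 * look as follows; the loop body is elided.
 *
 *    List *referenceTableIdList = CitusTableTypeIdList(REFERENCE_TABLE);
 *    Oid referenceTableId = InvalidOid;
 *    foreach_oid(referenceTableId, referenceTableIdList)
 *    {
 *        ...
 *    }
 */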


/*
 * ClusterHasReferenceTable returns true if the cluster has
 * any reference table.
 */
bool
ClusterHasReferenceTable(void)
{
	return list_length(CitusTableTypeIdList(REFERENCE_TABLE)) > 0;
}


/*
 * InvalidateNodeRelationCacheCallback marks the worker node hash invalid when
 * any change happens on the pg_dist_node table, so that subsequent accesses
 * rebuild the hash from pg_dist_node from scratch.
 */
static void
InvalidateNodeRelationCacheCallback(Datum argument, Oid relationId)
{
	if (relationId == InvalidOid || relationId == MetadataCache.distNodeRelationId)
	{
		workerNodeHashValid = false;
	}
}


/*
 * InvalidateLocalGroupIdRelationCacheCallback sets the LocalGroupId to
 * the default value.
 */
static void
InvalidateLocalGroupIdRelationCacheCallback(Datum argument, Oid relationId)
{
	/* when invalidation happens simply set the LocalGroupId to the default value */
	if (relationId == InvalidOid || relationId == MetadataCache.distLocalGroupRelationId)
	{
		LocalGroupId = -1;
	}
}


/*
 * CitusTableCacheFlushInvalidatedEntries frees invalidated cache entries.
 * Invalidated entries aren't freed immediately, because callers expect their
 * lifetime to extend beyond the invalidation itself.
 */
void
CitusTableCacheFlushInvalidatedEntries(void)
{
	if (DistTableCacheHash != NULL && DistTableCacheExpired != NIL)
	{
		CitusTableCacheEntry *cacheEntry = NULL;
		foreach_ptr(cacheEntry, DistTableCacheExpired)
		{
			ResetCitusTableCacheEntry(cacheEntry);
		}
		list_free(DistTableCacheExpired);
		DistTableCacheExpired = NIL;
	}
}


/*
 * CitusTableCacheEntryReleaseCallback frees invalidated cache entries.
 */
static void
CitusTableCacheEntryReleaseCallback(ResourceReleasePhase phase, bool isCommit,
									bool isTopLevel, void *arg)
{
	if (isTopLevel && phase == RESOURCE_RELEASE_LOCKS)
	{
		CitusTableCacheFlushInvalidatedEntries();
	}
}


/*
 * LookupDistPartitionTuple searches pg_dist_partition for relationId's entry
 * and returns that or, if no matching entry was found, NULL.
 */
static HeapTuple
LookupDistPartitionTuple(Relation pgDistPartition, Oid relationId)
{
	HeapTuple distPartitionTuple = NULL;
	ScanKeyData scanKey[1];

	/* copy scankey to local copy, it will be modified during the scan */
	scanKey[0] = DistPartitionScanKey[0];

	/* set scan arguments */
	scanKey[0].sk_argument = ObjectIdGetDatum(relationId);

	SysScanDesc scanDescriptor = systable_beginscan(pgDistPartition,
													DistPartitionLogicalRelidIndexId(),
													true, NULL, 1, scanKey);

	HeapTuple currentPartitionTuple = systable_getnext(scanDescriptor);
	if (HeapTupleIsValid(currentPartitionTuple))
	{
		distPartitionTuple = heap_copytuple(currentPartitionTuple);
	}

	systable_endscan(scanDescriptor);

	return distPartitionTuple;
}


/*
 * LookupDistShardTuples returns a list of all pg_dist_shard tuples for the
 * specified relation.
 */
List *
LookupDistShardTuples(Oid relationId)
{
	List *distShardTupleList = NIL;
	ScanKeyData scanKey[1];

	Relation pgDistShard = table_open(DistShardRelationId(), AccessShareLock);

	/* copy scankey to local copy, it will be modified during the scan */
	scanKey[0] = DistShardScanKey[0];

	/* set scan arguments */
	scanKey[0].sk_argument = ObjectIdGetDatum(relationId);

	SysScanDesc scanDescriptor = systable_beginscan(pgDistShard,
													DistShardLogicalRelidIndexId(), true,
													NULL, 1, scanKey);

	HeapTuple currentShardTuple = systable_getnext(scanDescriptor);
	while (HeapTupleIsValid(currentShardTuple))
	{
		HeapTuple shardTupleCopy = heap_copytuple(currentShardTuple);
		distShardTupleList = lappend(distShardTupleList, shardTupleCopy);

		currentShardTuple = systable_getnext(scanDescriptor);
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistShard, AccessShareLock);

	return distShardTupleList;
}
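

/*
 * For illustration, a hypothetical caller can combine LookupDistShardTuples
 * with TupleToShardInterval (defined below) to materialize shard intervals;
 * relationId, intervalTypeId and intervalTypeMod would come from the caller,
 * the latter two typically via GetPartitionTypeInputInfo.
 *
 *    Relation pgDistShard = table_open(DistShardRelationId(), AccessShareLock);
 *    TupleDesc tupleDescriptor = RelationGetDescr(pgDistShard);
 *    List *shardTupleList = LookupDistShardTuples(relationId);
 *
 *    HeapTuple shardTuple = NULL;
 *    foreach_ptr(shardTuple, shardTupleList)
 *    {
 *        ShardInterval *interval =
 *            TupleToShardInterval(shardTuple, tupleDescriptor,
 *                                 intervalTypeId, intervalTypeMod);
 *        ...
 *    }
 *
 *    table_close(pgDistShard, AccessShareLock);
 */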


/*
 * LookupShardRelationFromCatalog returns the logical relation oid a shard belongs to.
 *
 * Errors out if the shardId does not exist and missingOk is false.
 * Returns InvalidOid if the shardId does not exist and missingOk is true.
 */
Oid
LookupShardRelationFromCatalog(int64 shardId, bool missingOk)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 1;
	Form_pg_dist_shard shardForm = NULL;
	Relation pgDistShard = table_open(DistShardRelationId(), AccessShareLock);
	Oid relationId = InvalidOid;

	ScanKeyInit(&scanKey[0], Anum_pg_dist_shard_shardid,
				BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistShard,
													DistShardShardidIndexId(), true,
													NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (!HeapTupleIsValid(heapTuple) && !missingOk)
	{
		ereport(ERROR, (errmsg("could not find valid entry for shard "
							   UINT64_FORMAT, shardId)));
	}

	if (!HeapTupleIsValid(heapTuple))
	{
		relationId = InvalidOid;
	}
	else
	{
		shardForm = (Form_pg_dist_shard) GETSTRUCT(heapTuple);
		relationId = shardForm->logicalrelid;
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistShard, NoLock);

	return relationId;
}
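

/*
 * For illustration, resolving a shard while tolerating a concurrent drop uses
 * the missingOk flag; shardId is a hypothetical input.
 *
 *    bool missingOk = true;
 *    Oid relationId = LookupShardRelationFromCatalog(shardId, missingOk);
 *    if (!OidIsValid(relationId))
 *    {
 *        return;
 *    }
 */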


/*
 * ShardExists returns whether the given shard ID exists in pg_dist_shard.
 */
bool
ShardExists(int64 shardId)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 1;
	Relation pgDistShard = table_open(DistShardRelationId(), AccessShareLock);
	bool shardExists = false;

	ScanKeyInit(&scanKey[0], Anum_pg_dist_shard_shardid,
				BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistShard,
													DistShardShardidIndexId(), true,
													NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (HeapTupleIsValid(heapTuple))
	{
		shardExists = true;
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistShard, NoLock);

	return shardExists;
}


/*
 * GetPartitionTypeInputInfo populates output parameters with the interval type
 * identifier and modifier for the specified partition key/method combination.
 */
static void
GetPartitionTypeInputInfo(char *partitionKeyString, char partitionMethod,
						  Oid *columnTypeId, int32 *columnTypeMod,
						  Oid *intervalTypeId, int32 *intervalTypeMod)
{
	*columnTypeId = InvalidOid;
	*columnTypeMod = -1;
	*intervalTypeId = InvalidOid;
	*intervalTypeMod = -1;

	switch (partitionMethod)
	{
		case DISTRIBUTE_BY_APPEND:
		case DISTRIBUTE_BY_RANGE:
		case DISTRIBUTE_BY_HASH:
		{
			Node *partitionNode = stringToNode(partitionKeyString);
			Var *partitionColumn = (Var *) partitionNode;
			Assert(IsA(partitionNode, Var));

			GetIntervalTypeInfo(partitionMethod, partitionColumn,
								intervalTypeId, intervalTypeMod);

			*columnTypeId = partitionColumn->vartype;
			*columnTypeMod = partitionColumn->vartypmod;
			break;
		}

		case DISTRIBUTE_BY_NONE:
		{
			break;
		}

		default:
		{
			ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
							errmsg("unsupported table partition type: %c",
								   partitionMethod)));
		}
	}
}


/*
 * GetIntervalTypeInfo gets type id and type mod of the min/max values
 * of shard intervals for a distributed table with given partition method
 * and partition column.
 */
void
GetIntervalTypeInfo(char partitionMethod, Var *partitionColumn,
					Oid *intervalTypeId, int32 *intervalTypeMod)
{
	*intervalTypeId = InvalidOid;
	*intervalTypeMod = -1;

	switch (partitionMethod)
	{
		case DISTRIBUTE_BY_APPEND:
		case DISTRIBUTE_BY_RANGE:
		{
			*intervalTypeId = partitionColumn->vartype;
			*intervalTypeMod = partitionColumn->vartypmod;
			break;
		}

		case DISTRIBUTE_BY_HASH:
		{
			*intervalTypeId = INT4OID;
			break;
		}

		default:
		{
			break;
		}
	}
}
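

/*
 * For illustration: for hash-distributed tables the interval type is always
 * int4 (shard bounds are hashed values), whereas for range- and append-
 * distributed tables it is the distribution column's own type. A hypothetical
 * caller could assert the former as follows, given some partitionColumn Var.
 *
 *    Oid intervalTypeId = InvalidOid;
 *    int32 intervalTypeMod = -1;
 *
 *    GetIntervalTypeInfo(DISTRIBUTE_BY_HASH, partitionColumn,
 *                        &intervalTypeId, &intervalTypeMod);
 *    Assert(intervalTypeId == INT4OID);
 */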


/*
 * TupleToShardInterval transforms the specified pg_dist_shard tuple into a new
 * ShardInterval using the provided descriptor and partition type information.
 */
ShardInterval *
TupleToShardInterval(HeapTuple heapTuple, TupleDesc tupleDescriptor,
					 Oid intervalTypeId, int32 intervalTypeMod)
{
	Datum datumArray[Natts_pg_dist_shard];
	bool isNullArray[Natts_pg_dist_shard];

	/*
	 * We use heap_deform_tuple() instead of heap_getattr() so that the tuple
	 * is expanded to contain values for attributes that were added later via
	 * ALTER TABLE ... ADD COLUMN.
	 */
	heap_deform_tuple(heapTuple, tupleDescriptor, datumArray, isNullArray);

	ShardInterval *shardInterval =
		DeformedDistShardTupleToShardInterval(datumArray, isNullArray,
											  intervalTypeId, intervalTypeMod);

	return shardInterval;
}


/*
 * DeformedDistShardTupleToShardInterval transforms the specified deformed
 * pg_dist_shard tuple into a new ShardInterval.
 */
ShardInterval *
DeformedDistShardTupleToShardInterval(Datum *datumArray, bool *isNullArray,
									  Oid intervalTypeId, int32 intervalTypeMod)
{
	Oid inputFunctionId = InvalidOid;
	Oid typeIoParam = InvalidOid;
	Datum minValue = 0;
	Datum maxValue = 0;
	bool minValueExists = false;
	bool maxValueExists = false;
	int16 intervalTypeLen = 0;
	bool intervalByVal = false;
	char intervalAlign = '0';
	char intervalDelim = '0';

	Oid relationId =
		DatumGetObjectId(datumArray[Anum_pg_dist_shard_logicalrelid - 1]);
	int64 shardId = DatumGetInt64(datumArray[Anum_pg_dist_shard_shardid - 1]);
	char storageType = DatumGetChar(datumArray[Anum_pg_dist_shard_shardstorage - 1]);
	Datum minValueTextDatum = datumArray[Anum_pg_dist_shard_shardminvalue - 1];
	Datum maxValueTextDatum = datumArray[Anum_pg_dist_shard_shardmaxvalue - 1];

	bool minValueNull = isNullArray[Anum_pg_dist_shard_shardminvalue - 1];
	bool maxValueNull = isNullArray[Anum_pg_dist_shard_shardmaxvalue - 1];

	if (!minValueNull && !maxValueNull)
	{
		char *minValueString = TextDatumGetCString(minValueTextDatum);
		char *maxValueString = TextDatumGetCString(maxValueTextDatum);

		/* TODO: move this up the call stack to avoid per-tuple invocation? */
		get_type_io_data(intervalTypeId, IOFunc_input, &intervalTypeLen,
						 &intervalByVal,
						 &intervalAlign, &intervalDelim, &typeIoParam,
						 &inputFunctionId);

		/* finally convert min/max values to their actual types */
		minValue = OidInputFunctionCall(inputFunctionId, minValueString,
										typeIoParam, intervalTypeMod);
		maxValue = OidInputFunctionCall(inputFunctionId, maxValueString,
										typeIoParam, intervalTypeMod);

		minValueExists = true;
		maxValueExists = true;
	}

	ShardInterval *shardInterval = CitusMakeNode(ShardInterval);
	shardInterval->relationId = relationId;
	shardInterval->storageType = storageType;
	shardInterval->valueTypeId = intervalTypeId;
	shardInterval->valueTypeLen = intervalTypeLen;
	shardInterval->valueByVal = intervalByVal;
	shardInterval->minValueExists = minValueExists;
	shardInterval->maxValueExists = maxValueExists;
	shardInterval->minValue = minValue;
	shardInterval->maxValue = maxValue;
	shardInterval->shardId = shardId;

	return shardInterval;
}
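

/*
 * As a worked example of the min/max conversion above: for a hash-distributed
 * table, shardminvalue and shardmaxvalue store int32 hash-range bounds as
 * text. With 4 shards uniformly dividing the int32 space, the first shard's
 * tuple would carry minvalue "-2147483648" and maxvalue "-1073741825", which
 * the int4 input function converts into the corresponding int4 Datums here.
 */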


/*
 * CachedNamespaceLookup performs a cached lookup for the namespace (schema), with the
 * result cached in cachedOid.
 */
static void
CachedNamespaceLookup(const char *nspname, Oid *cachedOid)
{
	/* force callbacks to be registered, so we always get notified upon changes */
	InitializeCaches();

	if (*cachedOid == InvalidOid)
	{
		*cachedOid = get_namespace_oid(nspname, true);

		if (*cachedOid == InvalidOid)
		{
			ereport(ERROR, (errmsg(
								"cache lookup failed for namespace %s, called too early?",
								nspname)));
		}
	}
}


/*
 * CachedRelationLookup performs a cached lookup for the relation
 * relationName, with the result cached in *cachedOid.
 */
static void
CachedRelationLookup(const char *relationName, Oid *cachedOid)
{
	CachedRelationNamespaceLookup(relationName, PG_CATALOG_NAMESPACE, cachedOid);
}


/*
 * CachedRelationNamespaceLookup performs a cached lookup for the relation
 * relationName in the given namespace, with the result cached in *cachedOid.
 */
static void
CachedRelationNamespaceLookup(const char *relationName, Oid relnamespace,
							  Oid *cachedOid)
{
	/* force callbacks to be registered, so we always get notified upon changes */
	InitializeCaches();

	if (*cachedOid == InvalidOid)
	{
		*cachedOid = get_relname_relid(relationName, relnamespace);

		if (*cachedOid == InvalidOid)
		{
			ereport(ERROR, (errmsg(
								"cache lookup failed for %s, called too early?",
								relationName)));
		}
	}
}


/*
 * RelationExists returns whether a relation with the given OID exists.
 */
bool
RelationExists(Oid relationId)
{
	HeapTuple relTuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relationId));

	bool relationExists = HeapTupleIsValid(relTuple);
	if (relationExists)
	{
		ReleaseSysCache(relTuple);
	}

	return relationExists;
}


/*
 * CitusInvalidateRelcacheByRelid registers a relcache invalidation for a
 * non-shared relation.
 *
 * We ignore the case that there's no corresponding pg_class entry - that
 * happens if we register a relcache invalidation (e.g. for a
 * pg_dist_partition deletion) after the relation has been dropped. That's ok,
 * because in those cases we're guaranteed to already have registered an
 * invalidation for the target relation.
 */
void
CitusInvalidateRelcacheByRelid(Oid relationId)
{
	HeapTuple classTuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relationId));

	if (HeapTupleIsValid(classTuple))
	{
		CacheInvalidateRelcacheByTuple(classTuple);
		ReleaseSysCache(classTuple);
	}
}


/*
 * CitusInvalidateRelcacheByShardId registers a relcache invalidation for the
 * distributed relation associated with the shard.
 */
void
CitusInvalidateRelcacheByShardId(int64 shardId)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 1;
	Form_pg_dist_shard shardForm = NULL;
	Relation pgDistShard = table_open(DistShardRelationId(), AccessShareLock);

	/*
	 * Load the shard to find the associated relation ID. We can't use
	 * LoadShardInterval directly because it would error out if the shard no
	 * longer exists, which we need to tolerate here. Lower overhead is also
	 * desirable.
	 */

	ScanKeyInit(&scanKey[0], Anum_pg_dist_shard_shardid,
				BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistShard,
													DistShardShardidIndexId(), true,
													NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (HeapTupleIsValid(heapTuple))
	{
		shardForm = (Form_pg_dist_shard) GETSTRUCT(heapTuple);
		CitusInvalidateRelcacheByRelid(shardForm->logicalrelid);
	}
	else
	{
		/*
		 * Couldn't find the associated relation. That can primarily happen in
		 * two cases:
		 *
		 * 1) A placement row is inserted before the shard row. That's fine,
		 *	  since we don't need invalidations via placements in that case.
		 *
		 * 2) The shard has been deleted, but some placements were
		 *    unreachable, and the user is manually deleting the rows. Not
		 *    much point in WARNING or ERRORing in that case either, there's
		 *    nothing to invalidate.
		 *
		 * Hence we just emit a DEBUG5 message.
		 */
		ereport(DEBUG5, (errmsg(
							 "could not find distributed relation to invalidate for "
							 "shard " INT64_FORMAT, shardId)));
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistShard, NoLock);

	/* bump command counter, to force invalidation to take effect */
	CommandCounterIncrement();
}


/*
 * DistNodeMetadata returns the single metadata jsonb object stored in
 * pg_dist_node_metadata.
 */
Datum
DistNodeMetadata(void)
{
	Datum metadata = 0;
	ScanKeyData scanKey[1];
	const int scanKeyCount = 0;

	Oid metadataTableOid = get_relname_relid("pg_dist_node_metadata",
											 PG_CATALOG_NAMESPACE);
	if (metadataTableOid == InvalidOid)
	{
		ereport(ERROR, (errmsg("pg_dist_node_metadata was not found")));
	}

	Relation pgDistNodeMetadata = table_open(metadataTableOid, AccessShareLock);
	SysScanDesc scanDescriptor = systable_beginscan(pgDistNodeMetadata,
													InvalidOid, false,
													NULL, scanKeyCount, scanKey);
	TupleDesc tupleDescriptor = RelationGetDescr(pgDistNodeMetadata);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (HeapTupleIsValid(heapTuple))
	{
		bool isNull = false;
		metadata = heap_getattr(heapTuple, Anum_pg_dist_node_metadata_metadata,
								tupleDescriptor, &isNull);
		Assert(!isNull);
	}
	else
	{
		ereport(ERROR, (errmsg(
							"could not find any entries in pg_dist_node_metadata")));
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistNodeMetadata, AccessShareLock);

	return metadata;
}
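

/*
 * For illustration, the returned Datum holds a jsonb value, so a caller can
 * detoast and inspect it with the standard jsonb accessors (utils/jsonb.h):
 *
 *    Datum metadataDatum = DistNodeMetadata();
 *    Jsonb *metadataJsonb = DatumGetJsonbP(metadataDatum);
 */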


/*
 * role_exists is a check constraint which ensures that roles referenced in the
 * pg_dist_authinfo catalog actually exist (at least at the time of insertion).
 */
Datum
role_exists(PG_FUNCTION_ARGS)
{
	Name roleName = PG_GETARG_NAME(0);
	bool roleExists = SearchSysCacheExists1(AUTHNAME, NameGetDatum(roleName));

	PG_RETURN_BOOL(roleExists);
}


/*
 * authinfo_valid is a check constraint which errors on all rows, intended for
 * use in prohibiting writes to pg_dist_authinfo in Citus Community.
 */
Datum
authinfo_valid(PG_FUNCTION_ARGS)
{
	ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
					errmsg("cannot write to pg_dist_authinfo"),
					errdetail(
						"Citus Community Edition does not support the use of "
						"custom authentication options."),
					errhint(
						"To learn more about using advanced authentication schemes "
						"with Citus, please contact us at "
						"https://citusdata.com/about/contact_us")));
}


/*
 * poolinfo_valid is a check constraint which errors on all rows, intended for
 * use in prohibiting writes to pg_dist_poolinfo in Citus Community.
 */
Datum
poolinfo_valid(PG_FUNCTION_ARGS)
{
	ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
					errmsg("cannot write to pg_dist_poolinfo"),
					errdetail(
						"Citus Community Edition does not support the use of "
						"pooler options."),
					errhint("To learn more about using advanced pooling schemes "
							"with Citus, please contact us at "
							"https://citusdata.com/about/contact_us")));
}