1 /*-------------------------------------------------------------------------
2  *
3  * metadata_utility.c
4  *    Routines for reading and modifying master node's metadata.
5  *
6  * Copyright (c) Citus Data, Inc.
7  *
8  * $Id$
9  *
10  *-------------------------------------------------------------------------
11  */
12 
13 #include <sys/statvfs.h>
14 
15 #include "postgres.h"
16 #include "funcapi.h"
17 #include "libpq-fe.h"
18 #include "miscadmin.h"
19 
20 #include "distributed/pg_version_constants.h"
21 
22 #include "access/genam.h"
23 #include "access/htup_details.h"
24 #include "access/sysattr.h"
25 #include "access/xact.h"
26 #include "catalog/dependency.h"
27 #include "catalog/indexing.h"
28 #include "catalog/pg_constraint.h"
29 #include "catalog/pg_extension.h"
30 #include "catalog/pg_namespace.h"
31 #include "catalog/pg_type.h"
32 #include "commands/extension.h"
33 #include "distributed/colocation_utils.h"
34 #include "distributed/connection_management.h"
35 #include "distributed/citus_nodes.h"
36 #include "distributed/citus_safe_lib.h"
37 #include "distributed/listutils.h"
38 #include "distributed/lock_graph.h"
39 #include "distributed/metadata_utility.h"
40 #include "distributed/coordinator_protocol.h"
41 #include "distributed/metadata_cache.h"
42 #include "distributed/multi_join_order.h"
43 #include "distributed/multi_logical_optimizer.h"
44 #include "distributed/multi_partitioning_utils.h"
45 #include "distributed/multi_physical_planner.h"
46 #include "distributed/pg_dist_colocation.h"
47 #include "distributed/pg_dist_partition.h"
48 #include "distributed/pg_dist_shard.h"
49 #include "distributed/pg_dist_placement.h"
50 #include "distributed/reference_table_utils.h"
51 #include "distributed/relay_utility.h"
52 #include "distributed/resource_lock.h"
53 #include "distributed/remote_commands.h"
54 #include "distributed/tuplestore.h"
55 #include "distributed/worker_manager.h"
56 #include "distributed/worker_protocol.h"
57 #include "distributed/version_compat.h"
58 #include "nodes/makefuncs.h"
59 #include "parser/scansup.h"
60 #include "storage/lmgr.h"
61 #include "utils/acl.h"
62 #include "utils/builtins.h"
63 #include "utils/datum.h"
64 #include "utils/fmgroids.h"
65 #include "utils/inval.h"
66 #include "utils/lsyscache.h"
67 #include "utils/rel.h"
68 #include "utils/syscache.h"
69 
70 #define DISK_SPACE_FIELDS 2
71 
72 /* Local functions forward declarations */
73 static uint64 * AllocateUint64(uint64 value);
74 static void RecordDistributedRelationDependencies(Oid distributedRelationId);
75 static GroupShardPlacement * TupleToGroupShardPlacement(TupleDesc tupleDesc,
76 														HeapTuple heapTuple);
77 static bool DistributedTableSize(Oid relationId, SizeQueryType sizeQueryType,
78 								 bool failOnError, uint64 *tableSize);
79 static bool DistributedTableSizeOnWorker(WorkerNode *workerNode, Oid relationId,
80 										 SizeQueryType sizeQueryType, bool failOnError,
81 										 uint64 *tableSize);
82 static List * ShardIntervalsOnWorkerGroup(WorkerNode *workerNode, Oid relationId);
83 static char * GenerateShardStatisticsQueryForShardList(List *shardIntervalList, bool
84 													   useShardMinMaxQuery);
85 static char * GetWorkerPartitionedSizeUDFNameBySizeQueryType(SizeQueryType sizeQueryType);
86 static char * GetSizeQueryBySizeQueryType(SizeQueryType sizeQueryType);
87 static char * GenerateAllShardStatisticsQueryForNode(WorkerNode *workerNode,
88 													 List *citusTableIds, bool
89 													 useShardMinMaxQuery);
90 static List * GenerateShardStatisticsQueryList(List *workerNodeList, List *citusTableIds,
91 											   bool useShardMinMaxQuery);
92 static void ErrorIfNotSuitableToGetSize(Oid relationId);
93 static List * OpenConnectionToNodes(List *workerNodeList);
94 static void ReceiveShardNameAndSizeResults(List *connectionList,
95 										   Tuplestorestate *tupleStore,
96 										   TupleDesc tupleDescriptor);
97 static void AppendShardSizeMinMaxQuery(StringInfo selectQuery, uint64 shardId,
98 									   ShardInterval *
99 									   shardInterval, char *shardName,
100 									   char *quotedShardName);
101 static void AppendShardSizeQuery(StringInfo selectQuery, ShardInterval *shardInterval,
102 								 char *quotedShardName);
103 
104 static HeapTuple CreateDiskSpaceTuple(TupleDesc tupleDesc, uint64 availableBytes,
105 									  uint64 totalBytes);
106 static bool GetLocalDiskSpaceStats(uint64 *availableBytes, uint64 *totalBytes);
107 
108 /* exports for SQL callable functions */
109 PG_FUNCTION_INFO_V1(citus_local_disk_space_stats);
110 PG_FUNCTION_INFO_V1(citus_table_size);
111 PG_FUNCTION_INFO_V1(citus_total_relation_size);
112 PG_FUNCTION_INFO_V1(citus_relation_size);
113 PG_FUNCTION_INFO_V1(citus_shard_sizes);
114 
115 
116 /*
117  * CreateDiskSpaceTuple creates a tuple that is used as the return value
118  * for citus_local_disk_space_stats.
119  */
120 static HeapTuple
CreateDiskSpaceTuple(TupleDesc tupleDescriptor, uint64 availableBytes, uint64 totalBytes)
122 {
123 	Datum values[DISK_SPACE_FIELDS];
124 	bool isNulls[DISK_SPACE_FIELDS];
125 
126 	/* form heap tuple for remote disk space statistics */
127 	memset(values, 0, sizeof(values));
128 	memset(isNulls, false, sizeof(isNulls));
129 
130 	values[0] = UInt64GetDatum(availableBytes);
131 	values[1] = UInt64GetDatum(totalBytes);
132 
133 	HeapTuple diskSpaceTuple = heap_form_tuple(tupleDescriptor, values, isNulls);
134 
135 	return diskSpaceTuple;
136 }
137 
138 
139 /*
140  * citus_local_disk_space_stats returns total disk space and available disk
141  * space for the disk that contains PGDATA.
142  */
143 Datum
citus_local_disk_space_stats(PG_FUNCTION_ARGS)
145 {
146 	uint64 availableBytes = 0;
147 	uint64 totalBytes = 0;
148 
149 	if (!GetLocalDiskSpaceStats(&availableBytes, &totalBytes))
150 	{
151 		ereport(WARNING, (errmsg("could not get disk space")));
152 	}
153 
154 	TupleDesc tupleDescriptor = NULL;
155 
156 	TypeFuncClass resultTypeClass = get_call_result_type(fcinfo, NULL,
157 														 &tupleDescriptor);
158 	if (resultTypeClass != TYPEFUNC_COMPOSITE)
159 	{
160 		ereport(ERROR, (errmsg("return type must be a row type")));
161 	}
162 
163 	HeapTuple diskSpaceTuple = CreateDiskSpaceTuple(tupleDescriptor, availableBytes,
164 													totalBytes);
165 
166 	PG_RETURN_DATUM(HeapTupleGetDatum(diskSpaceTuple));
167 }
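
/*
 * Illustrative usage (a sketch, not part of this file's logic): the UDF above
 * can be queried directly, e.g.
 *
 *   SELECT available_disk_size, total_disk_size
 *   FROM pg_catalog.citus_local_disk_space_stats();
 *
 * It returns a single row with both values in bytes; when the statistics cannot
 * be collected, only a WARNING is raised and both values come back as 0.
 */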
168 
169 
170 /*
171  * GetLocalDiskSpaceStats returns total and available disk space for the disk containing
172  * PGDATA (not considering tablespaces, quota).
173  */
174 static bool
GetLocalDiskSpaceStats(uint64 *availableBytes, uint64 *totalBytes)
176 {
177 	struct statvfs buffer;
178 	if (statvfs(DataDir, &buffer) != 0)
179 	{
180 		return false;
181 	}
182 
183 	/*
184 	 * f_bfree: number of free blocks
185 	 * f_frsize: fragment size, same as f_bsize usually
186 	 * f_blocks: Size of fs in f_frsize units
187 	 */
188 	*availableBytes = buffer.f_bfree * buffer.f_frsize;
189 	*totalBytes = buffer.f_blocks * buffer.f_frsize;
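
	/*
	 * Worked example (hypothetical numbers): with f_frsize = 4096 and
	 * f_blocks = 26214400, *totalBytes becomes 4096 * 26214400 =
	 * 107374182400 bytes, i.e. exactly 100 GiB.
	 */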
190 
191 	return true;
192 }
193 
194 
195 /*
196  * GetNodeDiskSpaceStatsForConnection fetches the disk space statistics for the node
197  * that is on the given connection, or returns false if unsuccessful.
198  */
199 bool
GetNodeDiskSpaceStatsForConnection(MultiConnection *connection, uint64 *availableBytes,
201 								   uint64 *totalBytes)
202 {
203 	PGresult *result = NULL;
204 
205 	char *sizeQuery = "SELECT available_disk_size, total_disk_size "
206 					  "FROM pg_catalog.citus_local_disk_space_stats()";
207 
208 
209 	int queryResult = ExecuteOptionalRemoteCommand(connection, sizeQuery, &result);
210 	if (queryResult != RESPONSE_OKAY || !IsResponseOK(result) || PQntuples(result) != 1)
211 	{
212 		ereport(WARNING, (errcode(ERRCODE_CONNECTION_FAILURE),
213 						  errmsg("cannot get the disk space statistics for node %s:%d",
214 								 connection->hostname, connection->port)));
215 
216 		PQclear(result);
217 		ForgetResults(connection);
218 
219 		return false;
220 	}
221 
222 	char *availableBytesString = PQgetvalue(result, 0, 0);
223 	char *totalBytesString = PQgetvalue(result, 0, 1);
224 
225 	*availableBytes = SafeStringToUint64(availableBytesString);
226 	*totalBytes = SafeStringToUint64(totalBytesString);
227 
228 	PQclear(result);
229 	ForgetResults(connection);
230 
231 	return true;
232 }
233 
234 
235 /*
236  * citus_shard_sizes returns all shard names and their sizes.
237  */
238 Datum
citus_shard_sizes(PG_FUNCTION_ARGS)
240 {
241 	CheckCitusVersion(ERROR);
242 
243 	List *allCitusTableIds = AllCitusTableIds();
244 
245 	/* we don't need a distributed transaction here */
246 	bool useDistributedTransaction = false;
247 
248 	/* we only want the shard sizes here so useShardMinMaxQuery parameter is false */
249 	bool useShardMinMaxQuery = false;
250 	List *connectionList = SendShardStatisticsQueriesInParallel(allCitusTableIds,
251 																useDistributedTransaction,
252 																useShardMinMaxQuery);
253 
254 	TupleDesc tupleDescriptor = NULL;
255 	Tuplestorestate *tupleStore = SetupTuplestore(fcinfo, &tupleDescriptor);
256 
257 	ReceiveShardNameAndSizeResults(connectionList, tupleStore, tupleDescriptor);
258 
259 	/* clean up and return the tuplestore */
260 	tuplestore_donestoring(tupleStore);
261 
262 	PG_RETURN_VOID();
263 }
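
/*
 * For illustration only (the shard name below is hypothetical):
 *
 *   SELECT * FROM citus_shard_sizes();
 *
 * returns the qualified shard names known to the workers together with their
 * sizes in bytes, e.g. a row like ('public.events_102008', 425984).
 */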
264 
265 
266 /*
267  * citus_total_relation_size accepts a table name and returns a distributed table
268  * and its indexes' total relation size.
269  */
270 Datum
citus_total_relation_size(PG_FUNCTION_ARGS)
272 {
273 	CheckCitusVersion(ERROR);
274 
275 	Oid relationId = PG_GETARG_OID(0);
276 	bool failOnError = PG_GETARG_BOOL(1);
277 
278 	SizeQueryType sizeQueryType = TOTAL_RELATION_SIZE;
279 
280 	if (CStoreTable(relationId))
281 	{
282 		sizeQueryType = CSTORE_TABLE_SIZE;
283 	}
284 
285 	uint64 tableSize = 0;
286 
287 	if (!DistributedTableSize(relationId, sizeQueryType, failOnError, &tableSize))
288 	{
289 		Assert(!failOnError);
290 		PG_RETURN_NULL();
291 	}
292 
293 	PG_RETURN_INT64(tableSize);
294 }
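
/*
 * Usage sketch (the table name is hypothetical and this assumes the SQL
 * wrapper provides a default for the fail_on_error argument):
 *
 *   SELECT pg_size_pretty(citus_total_relation_size('github_events'));
 *
 * citus_table_size and citus_relation_size below follow the same pattern, but
 * with pg_table_size and pg_relation_size semantics respectively.
 */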
295 
296 
297 /*
298  * citus_table_size accepts a table name and returns a distributed table's total
299  * relation size.
300  */
301 Datum
citus_table_size(PG_FUNCTION_ARGS)
303 {
304 	CheckCitusVersion(ERROR);
305 
306 	Oid relationId = PG_GETARG_OID(0);
307 	bool failOnError = true;
308 	SizeQueryType sizeQueryType = TABLE_SIZE;
309 
310 	if (CStoreTable(relationId))
311 	{
312 		sizeQueryType = CSTORE_TABLE_SIZE;
313 	}
314 
315 	uint64 tableSize = 0;
316 
317 	if (!DistributedTableSize(relationId, sizeQueryType, failOnError, &tableSize))
318 	{
319 		Assert(!failOnError);
320 		PG_RETURN_NULL();
321 	}
322 
323 	PG_RETURN_INT64(tableSize);
324 }
325 
326 
327 /*
 * citus_relation_size accepts a table name and returns the size of that
 * relation's 'main' fork.
330  */
331 Datum
citus_relation_size(PG_FUNCTION_ARGS)
333 {
334 	CheckCitusVersion(ERROR);
335 
336 	Oid relationId = PG_GETARG_OID(0);
337 	bool failOnError = true;
338 	SizeQueryType sizeQueryType = RELATION_SIZE;
339 
340 	if (CStoreTable(relationId))
341 	{
342 		sizeQueryType = CSTORE_TABLE_SIZE;
343 	}
344 
345 	uint64 relationSize = 0;
346 
347 	if (!DistributedTableSize(relationId, sizeQueryType, failOnError, &relationSize))
348 	{
349 		Assert(!failOnError);
350 		PG_RETURN_NULL();
351 	}
352 
353 	PG_RETURN_INT64(relationSize);
354 }
355 
356 
357 /*
358  * SendShardStatisticsQueriesInParallel generates query lists for obtaining shard
359  * statistics and then sends the commands in parallel by opening connections
360  * to available nodes. It returns the connection list.
361  */
362 List *
SendShardStatisticsQueriesInParallel(List *citusTableIds, bool useDistributedTransaction,
									 bool useShardMinMaxQuery)
366 {
367 	List *workerNodeList = ActivePrimaryNodeList(NoLock);
368 
369 	List *shardSizesQueryList = GenerateShardStatisticsQueryList(workerNodeList,
370 																 citusTableIds,
371 																 useShardMinMaxQuery);
372 
373 	List *connectionList = OpenConnectionToNodes(workerNodeList);
374 	FinishConnectionListEstablishment(connectionList);
375 
376 	if (useDistributedTransaction)
377 	{
378 		/*
		 * For now, when we want to include shard min and max values, we also
		 * want to update the entries in pg_dist_placement and pg_dist_shard
		 * with the latest statistics. In order to detect distributed deadlocks,
		 * we assign a distributed transaction ID to the current transaction.
383 		 */
384 		UseCoordinatedTransaction();
385 	}
386 
387 	/* send commands in parallel */
388 	for (int i = 0; i < list_length(connectionList); i++)
389 	{
390 		MultiConnection *connection = (MultiConnection *) list_nth(connectionList, i);
391 		char *shardSizesQuery = (char *) list_nth(shardSizesQueryList, i);
392 
393 		if (useDistributedTransaction)
394 		{
395 			/* run the size query in a distributed transaction */
396 			RemoteTransactionBeginIfNecessary(connection);
397 		}
398 
399 		int querySent = SendRemoteCommand(connection, shardSizesQuery);
400 
401 		if (querySent == 0)
402 		{
403 			ReportConnectionError(connection, WARNING);
404 		}
405 	}
406 	return connectionList;
407 }
408 
409 
410 /*
411  * OpenConnectionToNodes opens a single connection per node
412  * for the given workerNodeList.
413  */
414 static List *
OpenConnectionToNodes(List *workerNodeList)
416 {
417 	List *connectionList = NIL;
418 	WorkerNode *workerNode = NULL;
419 	foreach_ptr(workerNode, workerNodeList)
420 	{
421 		const char *nodeName = workerNode->workerName;
422 		int nodePort = workerNode->workerPort;
423 		int connectionFlags = 0;
424 
425 		MultiConnection *connection = StartNodeConnection(connectionFlags, nodeName,
426 														  nodePort);
427 
428 		connectionList = lappend(connectionList, connection);
429 	}
430 	return connectionList;
431 }
432 
433 
434 /*
435  * GenerateShardStatisticsQueryList generates a query per node that will return:
 * - all shard_name, shard_size pairs from the node (if useShardMinMaxQuery is false)
 * - all shard_id, shard_minvalue, shard_maxvalue, shard_size quadruples from the node (if true)
438  */
439 static List *
GenerateShardStatisticsQueryList(List *workerNodeList, List *citusTableIds,
								 bool useShardMinMaxQuery)
442 {
443 	List *shardStatisticsQueryList = NIL;
444 	WorkerNode *workerNode = NULL;
445 	foreach_ptr(workerNode, workerNodeList)
446 	{
447 		char *shardStatisticsQuery = GenerateAllShardStatisticsQueryForNode(workerNode,
448 																			citusTableIds,
449 																			useShardMinMaxQuery);
450 		shardStatisticsQueryList = lappend(shardStatisticsQueryList,
451 										   shardStatisticsQuery);
452 	}
453 	return shardStatisticsQueryList;
454 }
455 
456 
457 /*
458  * ReceiveShardNameAndSizeResults receives shard name and size results from the given
459  * connection list.
460  */
461 static void
ReceiveShardNameAndSizeResults(List *connectionList, Tuplestorestate *tupleStore,
463 							   TupleDesc tupleDescriptor)
464 {
465 	MultiConnection *connection = NULL;
466 	foreach_ptr(connection, connectionList)
467 	{
468 		bool raiseInterrupts = true;
469 		Datum values[SHARD_SIZES_COLUMN_COUNT];
470 		bool isNulls[SHARD_SIZES_COLUMN_COUNT];
471 
472 		if (PQstatus(connection->pgConn) != CONNECTION_OK)
473 		{
474 			continue;
475 		}
476 
477 		PGresult *result = GetRemoteCommandResult(connection, raiseInterrupts);
478 		if (!IsResponseOK(result))
479 		{
480 			ReportResultError(connection, result, WARNING);
481 			continue;
482 		}
483 
484 		int64 rowCount = PQntuples(result);
485 		int64 colCount = PQnfields(result);
486 
		/* this is not expected, but guard against an unexpected column count */
488 		if (colCount != SHARD_SIZES_COLUMN_COUNT)
489 		{
490 			ereport(WARNING, (errmsg("unexpected number of columns from "
491 									 "citus_shard_sizes")));
492 			continue;
493 		}
494 
495 		for (int64 rowIndex = 0; rowIndex < rowCount; rowIndex++)
496 		{
497 			memset(values, 0, sizeof(values));
498 			memset(isNulls, false, sizeof(isNulls));
499 
500 			char *tableName = PQgetvalue(result, rowIndex, 0);
501 			Datum resultStringDatum = CStringGetDatum(tableName);
502 			Datum textDatum = DirectFunctionCall1(textin, resultStringDatum);
503 
504 			values[0] = textDatum;
505 			values[1] = ParseIntField(result, rowIndex, 1);
506 
507 			tuplestore_putvalues(tupleStore, tupleDescriptor, values, isNulls);
508 		}
509 
510 		PQclear(result);
511 		ForgetResults(connection);
512 	}
513 }
514 
515 
516 /*
 * DistributedTableSize is a helper function for all citus size functions.
 * It first checks whether the table is distributed and whether a size query
 * can be run on it. A connection to each node is then established to get the
 * size of the table.
520  */
521 static bool
DistributedTableSize(Oid relationId, SizeQueryType sizeQueryType, bool failOnError,
523 					 uint64 *tableSize)
524 {
525 	int logLevel = WARNING;
526 
527 	if (failOnError)
528 	{
529 		logLevel = ERROR;
530 	}
531 
532 	uint64 sumOfSizes = 0;
533 
534 	if (XactModificationLevel == XACT_MODIFICATION_DATA)
535 	{
536 		ereport(logLevel, (errcode(ERRCODE_ACTIVE_SQL_TRANSACTION),
537 						   errmsg("citus size functions cannot be called in transaction "
538 								  "blocks which contain multi-shard data "
539 								  "modifications")));
540 
541 		return false;
542 	}
543 
544 	Relation relation = try_relation_open(relationId, AccessShareLock);
545 
546 	if (relation == NULL)
547 	{
548 		ereport(logLevel,
549 				(errmsg("could not compute table size: relation does not exist")));
550 
551 		return false;
552 	}
553 
554 	ErrorIfNotSuitableToGetSize(relationId);
555 
556 	table_close(relation, AccessShareLock);
557 
558 	List *workerNodeList = ActiveReadableNodeList();
559 	WorkerNode *workerNode = NULL;
560 	foreach_ptr(workerNode, workerNodeList)
561 	{
562 		uint64 relationSizeOnNode = 0;
563 
564 		bool gotSize = DistributedTableSizeOnWorker(workerNode, relationId, sizeQueryType,
565 													failOnError, &relationSizeOnNode);
566 		if (!gotSize)
567 		{
568 			return false;
569 		}
570 
571 		sumOfSizes += relationSizeOnNode;
572 	}
573 
574 	*tableSize = sumOfSizes;
575 
576 	return true;
577 }
578 
579 
580 /*
 * DistributedTableSizeOnWorker takes a workerNode and a relationId and calculates
 * the size of that relation on the given workerNode by summing up the size of
 * each of its shard placements there.
584  */
585 static bool
DistributedTableSizeOnWorker(WorkerNode *workerNode, Oid relationId,
587 							 SizeQueryType sizeQueryType,
588 							 bool failOnError, uint64 *tableSize)
589 {
590 	int logLevel = WARNING;
591 
592 	if (failOnError)
593 	{
594 		logLevel = ERROR;
595 	}
596 
597 	char *workerNodeName = workerNode->workerName;
598 	uint32 workerNodePort = workerNode->workerPort;
599 	uint32 connectionFlag = 0;
600 	PGresult *result = NULL;
601 
602 	List *shardIntervalsOnNode = ShardIntervalsOnWorkerGroup(workerNode, relationId);
603 
604 	/*
	 * We pass false here because the optimized path would also count child
	 * tables, and citus size functions should not include them, just like the
	 * corresponding PostgreSQL functions.
607 	 */
608 	bool optimizePartitionCalculations = false;
609 	StringInfo tableSizeQuery = GenerateSizeQueryOnMultiplePlacements(
610 		shardIntervalsOnNode,
611 		sizeQueryType,
612 		optimizePartitionCalculations);
613 
614 	MultiConnection *connection = GetNodeConnection(connectionFlag, workerNodeName,
615 													workerNodePort);
616 	int queryResult = ExecuteOptionalRemoteCommand(connection, tableSizeQuery->data,
617 												   &result);
618 
619 	if (queryResult != 0)
620 	{
621 		ereport(logLevel, (errcode(ERRCODE_CONNECTION_FAILURE),
622 						   errmsg("could not connect to %s:%d to get size of "
623 								  "table \"%s\"",
624 								  workerNodeName, workerNodePort,
625 								  get_rel_name(relationId))));
626 
627 		return false;
628 	}
629 
630 	List *sizeList = ReadFirstColumnAsText(result);
631 	if (list_length(sizeList) != 1)
632 	{
633 		PQclear(result);
634 		ClearResults(connection, failOnError);
635 
636 		ereport(logLevel, (errcode(ERRCODE_CONNECTION_FAILURE),
637 						   errmsg("cannot parse size of table \"%s\" from %s:%d",
638 								  get_rel_name(relationId), workerNodeName,
639 								  workerNodePort)));
640 
641 		return false;
642 	}
643 
644 	StringInfo tableSizeStringInfo = (StringInfo) linitial(sizeList);
645 	char *tableSizeString = tableSizeStringInfo->data;
646 
647 	if (strlen(tableSizeString) > 0)
648 	{
649 		*tableSize = SafeStringToUint64(tableSizeString);
650 	}
651 	else
652 	{
653 		/*
		 * This means the shard was moved or dropped while citus_total_relation_size
		 * was being executed, in which case we get an empty string as the table
		 * size. We treat that as zero to prevent unnecessary errors.
657 		 */
658 		*tableSize = 0;
659 	}
660 
661 	PQclear(result);
662 	ClearResults(connection, failOnError);
663 
664 	return true;
665 }
666 
667 
668 /*
669  * GroupShardPlacementsForTableOnGroup accepts a relationId and a group and returns a list
 * of GroupShardPlacements representing all of the placements for the table which reside
671  * on the group.
672  */
673 List *
GroupShardPlacementsForTableOnGroup(Oid relationId, int32 groupId)
675 {
676 	CitusTableCacheEntry *distTableCacheEntry = GetCitusTableCacheEntry(relationId);
677 	List *resultList = NIL;
678 
679 	int shardIntervalArrayLength = distTableCacheEntry->shardIntervalArrayLength;
680 
681 	for (int shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++)
682 	{
683 		GroupShardPlacement *placementArray =
684 			distTableCacheEntry->arrayOfPlacementArrays[shardIndex];
685 		int numberOfPlacements =
686 			distTableCacheEntry->arrayOfPlacementArrayLengths[shardIndex];
687 
688 		for (int placementIndex = 0; placementIndex < numberOfPlacements;
689 			 placementIndex++)
690 		{
691 			if (placementArray[placementIndex].groupId == groupId)
692 			{
693 				GroupShardPlacement *placement = palloc0(sizeof(GroupShardPlacement));
694 				*placement = placementArray[placementIndex];
695 				resultList = lappend(resultList, placement);
696 			}
697 		}
698 	}
699 
700 	return resultList;
701 }
702 
703 
704 /*
705  * ShardIntervalsOnWorkerGroup accepts a WorkerNode and returns a list of the shard
706  * intervals of the given table which are placed on the group the node is a part of.
707  */
708 static List *
ShardIntervalsOnWorkerGroup(WorkerNode *workerNode, Oid relationId)
710 {
711 	CitusTableCacheEntry *distTableCacheEntry = GetCitusTableCacheEntry(relationId);
712 	List *shardIntervalList = NIL;
713 	int shardIntervalArrayLength = distTableCacheEntry->shardIntervalArrayLength;
714 
715 	for (int shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++)
716 	{
717 		GroupShardPlacement *placementArray =
718 			distTableCacheEntry->arrayOfPlacementArrays[shardIndex];
719 		int numberOfPlacements =
720 			distTableCacheEntry->arrayOfPlacementArrayLengths[shardIndex];
721 
722 		for (int placementIndex = 0; placementIndex < numberOfPlacements;
723 			 placementIndex++)
724 		{
725 			GroupShardPlacement *placement = &placementArray[placementIndex];
726 
727 			if (placement->groupId == workerNode->groupId)
728 			{
729 				ShardInterval *cachedShardInterval =
730 					distTableCacheEntry->sortedShardIntervalArray[shardIndex];
731 				ShardInterval *shardInterval = CopyShardInterval(cachedShardInterval);
732 				shardIntervalList = lappend(shardIntervalList, shardInterval);
733 			}
734 		}
735 	}
736 
737 	return shardIntervalList;
738 }
739 
740 
741 /*
 * GenerateSizeQueryOnMultiplePlacements generates a SELECT query that sums up
 * the sizes of the given shards. The different size functions supported by PG
 * are also supported by this function; depending on the sizeQueryType enum
 * parameter, the generated query calls one of pg_relation_size,
 * pg_total_relation_size, pg_table_size or cstore_table_size.
 * If the parameter optimizePartitionCalculations is true, this function uses
 * the UDFs named worker_partitioned_*_size for partitioned tables; the UDF to
 * be called is again determined by sizeQueryType.
751  */
752 StringInfo
GenerateSizeQueryOnMultiplePlacements(List *shardIntervalList,
754 									  SizeQueryType sizeQueryType,
755 									  bool optimizePartitionCalculations)
756 {
757 	StringInfo selectQuery = makeStringInfo();
758 
759 	appendStringInfo(selectQuery, "SELECT ");
760 
761 	ShardInterval *shardInterval = NULL;
762 	foreach_ptr(shardInterval, shardIntervalList)
763 	{
764 		if (optimizePartitionCalculations && PartitionTable(shardInterval->relationId))
765 		{
766 			/*
			 * Skip child tables of a partitioned table, as they are already counted
			 * by the worker_partitioned_*_size UDFs when optimizePartitionCalculations
			 * is true. We don't expect this case to happen, since child tables are
			 * never passed to this function; they are all eliminated in
			 * ColocatedNonPartitionShardIntervalList. Therefore we currently cannot
			 * cover this branch with a test; it is kept for possible future usages.
773 			 */
774 			continue;
775 		}
776 		uint64 shardId = shardInterval->shardId;
777 		Oid schemaId = get_rel_namespace(shardInterval->relationId);
778 		char *schemaName = get_namespace_name(schemaId);
779 		char *shardName = get_rel_name(shardInterval->relationId);
780 		AppendShardIdToName(&shardName, shardId);
781 
782 		char *shardQualifiedName = quote_qualified_identifier(schemaName, shardName);
783 		char *quotedShardName = quote_literal_cstr(shardQualifiedName);
784 
785 		if (optimizePartitionCalculations && PartitionedTable(shardInterval->relationId))
786 		{
787 			appendStringInfo(selectQuery, GetWorkerPartitionedSizeUDFNameBySizeQueryType(
788 								 sizeQueryType), quotedShardName);
789 		}
790 		else
791 		{
792 			appendStringInfo(selectQuery, GetSizeQueryBySizeQueryType(sizeQueryType),
793 							 quotedShardName);
794 		}
795 
796 		appendStringInfo(selectQuery, " + ");
797 	}
798 
799 	/*
	 * Add 0 as the last term so that the query stays valid when the shard list
	 * is empty and no special handling of the trailing " + " is needed.
802 	 */
803 	appendStringInfo(selectQuery, "0;");
804 
805 	return selectQuery;
806 }
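
/*
 * For illustration (shard names are hypothetical, and PG_TOTAL_RELATION_SIZE_FUNCTION
 * is assumed to expand to pg_total_relation_size(%s)): for two shard placements of
 * public.events with sizeQueryType == TOTAL_RELATION_SIZE, the generated query
 * looks roughly like
 *
 *   SELECT pg_total_relation_size('public.events_102008')
 *        + pg_total_relation_size('public.events_102011') + 0;
 */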
807 
808 
809 /*
810  * GetWorkerPartitionedSizeUDFNameBySizeQueryType returns the corresponding worker
811  * partitioned size query for given query type.
812  * Errors out for an invalid query type.
813  * Currently this function is only called with the type TOTAL_RELATION_SIZE.
814  * The others are added for possible future usages. Since they are not used anywhere,
815  * currently we can't cover them with tests.
816  */
817 static char *
GetWorkerPartitionedSizeUDFNameBySizeQueryType(SizeQueryType sizeQueryType)
819 {
820 	switch (sizeQueryType)
821 	{
822 		case RELATION_SIZE:
823 		{
824 			return WORKER_PARTITIONED_RELATION_SIZE_FUNCTION;
825 		}
826 
827 		case TOTAL_RELATION_SIZE:
828 		{
829 			return WORKER_PARTITIONED_RELATION_TOTAL_SIZE_FUNCTION;
830 		}
831 
832 		case TABLE_SIZE:
833 		{
834 			return WORKER_PARTITIONED_TABLE_SIZE_FUNCTION;
835 		}
836 
837 		default:
838 		{
839 			elog(ERROR, "Size query type couldn't be found.");
840 		}
841 	}
842 }
843 
844 
845 /*
846  * GetSizeQueryBySizeQueryType returns the corresponding size query for given query type.
847  * Errors out for an invalid query type.
848  */
849 static char *
GetSizeQueryBySizeQueryType(SizeQueryType sizeQueryType)
851 {
852 	switch (sizeQueryType)
853 	{
854 		case RELATION_SIZE:
855 		{
856 			return PG_RELATION_SIZE_FUNCTION;
857 		}
858 
859 		case TOTAL_RELATION_SIZE:
860 		{
861 			return PG_TOTAL_RELATION_SIZE_FUNCTION;
862 		}
863 
864 		case CSTORE_TABLE_SIZE:
865 		{
866 			return CSTORE_TABLE_SIZE_FUNCTION;
867 		}
868 
869 		case TABLE_SIZE:
870 		{
871 			return PG_TABLE_SIZE_FUNCTION;
872 		}
873 
874 		default:
875 		{
876 			elog(ERROR, "Size query type couldn't be found.");
877 		}
878 	}
879 }
880 
881 
882 /*
883  * GenerateAllShardStatisticsQueryForNode generates a query that returns:
884  * - all shard_name, shard_size pairs for the given node (if useShardMinMaxQuery is false)
 * - all shard_id, shard_minvalue, shard_maxvalue, shard_size quadruples (if true)
886  */
887 static char *
GenerateAllShardStatisticsQueryForNode(WorkerNode *workerNode, List *citusTableIds,
									   bool useShardMinMaxQuery)
890 {
891 	StringInfo allShardStatisticsQuery = makeStringInfo();
892 
893 	Oid relationId = InvalidOid;
894 	foreach_oid(relationId, citusTableIds)
895 	{
896 		/*
		 * Ensure the table still exists by trying to acquire a lock on it.
		 * If try_relation_open() returns NULL, the table does not exist
		 * anymore, so we skip it.
900 		 */
901 		Relation relation = try_relation_open(relationId, AccessShareLock);
902 		if (relation != NULL)
903 		{
904 			List *shardIntervalsOnNode = ShardIntervalsOnWorkerGroup(workerNode,
905 																	 relationId);
906 			char *shardStatisticsQuery =
907 				GenerateShardStatisticsQueryForShardList(shardIntervalsOnNode,
908 														 useShardMinMaxQuery);
909 			appendStringInfoString(allShardStatisticsQuery, shardStatisticsQuery);
910 			relation_close(relation, AccessShareLock);
911 		}
912 	}
913 
914 	/* Add a dummy entry so that UNION ALL doesn't complain */
915 	if (useShardMinMaxQuery)
916 	{
		/* 0 for shard_id, NULL for shard_minvalue, NULL for shard_maxvalue, 0 for shard_size */
918 		appendStringInfo(allShardStatisticsQuery,
919 						 "SELECT 0::bigint, NULL::text, NULL::text, 0::bigint;");
920 	}
921 	else
922 	{
923 		/* NULL for shard_name, 0 for shard_size */
924 		appendStringInfo(allShardStatisticsQuery, "SELECT NULL::text, 0::bigint;");
925 	}
926 	return allShardStatisticsQuery->data;
927 }
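
/*
 * For illustration (the shard name is hypothetical): with useShardMinMaxQuery
 * set to false and a single shard on the node, the query built above looks
 * roughly like
 *
 *   SELECT 'public.events_102008' AS shard_name,
 *          pg_relation_size('public.events_102008') UNION ALL
 *   SELECT NULL::text, 0::bigint;
 *
 * where the final SELECT is the dummy entry that keeps UNION ALL valid even
 * when the node has no shards of interest.
 */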
928 
929 
930 /*
931  * GenerateShardStatisticsQueryForShardList generates one of the two types of queries:
 * - SELECT shard_name, shard_size (if useShardMinMaxQuery is false)
933  * - SELECT shard_id, shard_minvalue, shard_maxvalue, shard_size (if true)
934  */
935 static char *
GenerateShardStatisticsQueryForShardList(List *shardIntervalList,
										 bool useShardMinMaxQuery)
938 {
939 	StringInfo selectQuery = makeStringInfo();
940 
941 	ShardInterval *shardInterval = NULL;
942 	foreach_ptr(shardInterval, shardIntervalList)
943 	{
944 		uint64 shardId = shardInterval->shardId;
945 		Oid schemaId = get_rel_namespace(shardInterval->relationId);
946 		char *schemaName = get_namespace_name(schemaId);
947 		char *shardName = get_rel_name(shardInterval->relationId);
948 		AppendShardIdToName(&shardName, shardId);
949 
950 		char *shardQualifiedName = quote_qualified_identifier(schemaName, shardName);
951 		char *quotedShardName = quote_literal_cstr(shardQualifiedName);
952 
953 		if (useShardMinMaxQuery)
954 		{
955 			AppendShardSizeMinMaxQuery(selectQuery, shardId, shardInterval, shardName,
956 									   quotedShardName);
957 		}
958 		else
959 		{
960 			AppendShardSizeQuery(selectQuery, shardInterval, quotedShardName);
961 		}
962 		appendStringInfo(selectQuery, " UNION ALL ");
963 	}
964 
965 	return selectQuery->data;
966 }
967 
968 
969 /*
970  * AppendShardSizeMinMaxQuery appends a query in the following form to selectQuery
971  * SELECT shard_id, shard_minvalue, shard_maxvalue, shard_size
972  */
973 static void
AppendShardSizeMinMaxQuery(StringInfo selectQuery, uint64 shardId,
975 						   ShardInterval *shardInterval, char *shardName,
976 						   char *quotedShardName)
977 {
978 	if (IsCitusTableType(shardInterval->relationId, APPEND_DISTRIBUTED))
979 	{
980 		/* fill in the partition column name */
981 		const uint32 unusedTableId = 1;
982 		Var *partitionColumn = PartitionColumn(shardInterval->relationId,
983 											   unusedTableId);
984 		char *partitionColumnName = get_attname(shardInterval->relationId,
985 												partitionColumn->varattno, false);
986 		appendStringInfo(selectQuery,
987 						 "SELECT " UINT64_FORMAT
988 						 " AS shard_id, min(%s)::text AS shard_minvalue, max(%s)::text AS shard_maxvalue, pg_relation_size(%s) AS shard_size FROM %s ",
989 						 shardId, partitionColumnName,
990 						 partitionColumnName,
991 						 quotedShardName, shardName);
992 	}
993 	else
994 	{
995 		/* we don't need to update min/max for non-append distributed tables because they don't change */
996 		appendStringInfo(selectQuery,
997 						 "SELECT " UINT64_FORMAT
998 						 " AS shard_id, NULL::text AS shard_minvalue, NULL::text AS shard_maxvalue, pg_relation_size(%s) AS shard_size ",
999 						 shardId, quotedShardName);
1000 	}
1001 }
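
/*
 * For illustration (shard and column names are hypothetical): for an
 * append-distributed shard events_102008 whose partition column is created_at,
 * the first branch above appends roughly
 *
 *   SELECT 102008 AS shard_id, min(created_at)::text AS shard_minvalue,
 *          max(created_at)::text AS shard_maxvalue,
 *          pg_relation_size('public.events_102008') AS shard_size
 *   FROM events_102008
 *
 * while the second branch reports NULL min/max values and scans no rows.
 */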
1002 
1003 
1004 /*
1005  * AppendShardSizeQuery appends a query in the following form to selectQuery
1006  * SELECT shard_name, shard_size
1007  */
1008 static void
AppendShardSizeQuery(StringInfo selectQuery, ShardInterval *shardInterval,
1010 					 char *quotedShardName)
1011 {
1012 	appendStringInfo(selectQuery, "SELECT %s AS shard_name, ", quotedShardName);
1013 	appendStringInfo(selectQuery, PG_RELATION_SIZE_FUNCTION, quotedShardName);
1014 }
1015 
1016 
1017 /*
1018  * ErrorIfNotSuitableToGetSize determines whether the table is suitable to find
 * its size with internal functions.
1020  */
1021 static void
ErrorIfNotSuitableToGetSize(Oid relationId)
1023 {
1024 	if (!IsCitusTable(relationId))
1025 	{
1026 		char *relationName = get_rel_name(relationId);
1027 		char *escapedQueryString = quote_literal_cstr(relationName);
1028 		ereport(ERROR, (errcode(ERRCODE_INVALID_TABLE_DEFINITION),
1029 						errmsg("cannot calculate the size because relation %s is not "
1030 							   "distributed", escapedQueryString)));
1031 	}
1032 }
1033 
1034 
1035 /*
1036  * CompareShardPlacementsByWorker compares two shard placements by their
1037  * worker node name and port.
1038  */
1039 int
CompareShardPlacementsByWorker(const void *leftElement, const void *rightElement)
1041 {
1042 	const ShardPlacement *leftPlacement = *((const ShardPlacement **) leftElement);
1043 	const ShardPlacement *rightPlacement = *((const ShardPlacement **) rightElement);
1044 
1045 	int nodeNameCmp = strncmp(leftPlacement->nodeName, rightPlacement->nodeName,
1046 							  WORKER_LENGTH);
1047 	if (nodeNameCmp != 0)
1048 	{
1049 		return nodeNameCmp;
1050 	}
1051 	else if (leftPlacement->nodePort > rightPlacement->nodePort)
1052 	{
1053 		return 1;
1054 	}
1055 	else if (leftPlacement->nodePort < rightPlacement->nodePort)
1056 	{
1057 		return -1;
1058 	}
1059 
1060 	return 0;
1061 }
1062 
1063 
1064 /*
1065  * CompareShardPlacementsByGroupId compares two shard placements by their
1066  * group id.
1067  */
1068 int
CompareShardPlacementsByGroupId(const void *leftElement, const void *rightElement)
1070 {
1071 	const ShardPlacement *leftPlacement = *((const ShardPlacement **) leftElement);
1072 	const ShardPlacement *rightPlacement = *((const ShardPlacement **) rightElement);
1073 
1074 	if (leftPlacement->groupId > rightPlacement->groupId)
1075 	{
1076 		return 1;
1077 	}
1078 	else if (leftPlacement->groupId < rightPlacement->groupId)
1079 	{
1080 		return -1;
1081 	}
1082 	else
1083 	{
1084 		return 0;
1085 	}
1086 }
1087 
1088 
1089 /*
1090  * TableShardReplicationFactor returns the current replication factor of the
 * given relation by looking into shard placements. It errors out if different
 * shards have different numbers of placements. It also
1093  * errors out if the table does not have any shards.
1094  */
1095 uint32
TableShardReplicationFactor(Oid relationId)
1097 {
1098 	uint32 replicationCount = 0;
1099 
1100 	List *shardIntervalList = LoadShardIntervalList(relationId);
1101 	ShardInterval *shardInterval = NULL;
1102 	foreach_ptr(shardInterval, shardIntervalList)
1103 	{
1104 		uint64 shardId = shardInterval->shardId;
1105 
1106 		List *shardPlacementList = ShardPlacementListWithoutOrphanedPlacements(shardId);
1107 		uint32 shardPlacementCount = list_length(shardPlacementList);
1108 
1109 		/*
1110 		 * Get the replication count of the first shard in the list, and error
		 * out if there is a shard with a different replication count.
1112 		 */
1113 		if (replicationCount == 0)
1114 		{
1115 			replicationCount = shardPlacementCount;
1116 		}
1117 		else if (replicationCount != shardPlacementCount)
1118 		{
1119 			char *relationName = get_rel_name(relationId);
1120 			ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1121 							errmsg("cannot find the replication factor of the "
1122 								   "table %s", relationName),
1123 							errdetail("The shard " UINT64_FORMAT
1124 									  " has different shards replication counts from "
1125 									  "other shards.", shardId)));
1126 		}
1127 	}
1128 
1129 	/* error out if the table does not have any shards */
1130 	if (replicationCount == 0)
1131 	{
1132 		char *relationName = get_rel_name(relationId);
1133 		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
1134 						errmsg("cannot find the replication factor of the "
1135 							   "table %s", relationName),
1136 						errdetail("The table %s does not have any shards.",
1137 								  relationName)));
1138 	}
1139 
1140 	return replicationCount;
1141 }
1142 
1143 
1144 /*
 * LoadShardIntervalList returns a list of shard intervals for a given
 * distributed table. The function returns an empty list if no shards can be
 * found for the given relation.
 * Since LoadShardIntervalList relies on sortedShardIntervalArray, it returns
 * a shard interval list whose elements are sorted on shardminvalue. Shard
 * intervals with uninitialized shard min/max values are placed at the end of
 * the list.
1151  */
1152 List *
LoadShardIntervalList(Oid relationId)
1154 {
1155 	CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
1156 	List *shardList = NIL;
1157 
1158 	for (int i = 0; i < cacheEntry->shardIntervalArrayLength; i++)
1159 	{
1160 		ShardInterval *newShardInterval =
1161 			CopyShardInterval(cacheEntry->sortedShardIntervalArray[i]);
1162 		shardList = lappend(shardList, newShardInterval);
1163 	}
1164 
1165 	return shardList;
1166 }
1167 
1168 
1169 /*
1170  * LoadShardIntervalWithLongestShardName is a utility function that returns
 * the shard interval with the largest shardId for the given relationId. Note
1172  * that largest shardId implies longest shard name.
1173  */
1174 ShardInterval *
LoadShardIntervalWithLongestShardName(Oid relationId)
1176 {
1177 	CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
1178 	int shardIntervalCount = cacheEntry->shardIntervalArrayLength;
1179 
1180 	int maxShardIndex = shardIntervalCount - 1;
1181 	uint64 largestShardId = INVALID_SHARD_ID;
1182 
1183 	for (int shardIndex = 0; shardIndex <= maxShardIndex; ++shardIndex)
1184 	{
1185 		ShardInterval *currentShardInterval =
1186 			cacheEntry->sortedShardIntervalArray[shardIndex];
1187 
1188 		if (largestShardId < currentShardInterval->shardId)
1189 		{
1190 			largestShardId = currentShardInterval->shardId;
1191 		}
1192 	}
1193 
1194 	return LoadShardInterval(largestShardId);
1195 }
1196 
1197 
1198 /*
1199  * ShardIntervalCount returns number of shard intervals for a given distributed table.
1200  * The function returns 0 if no shards can be found for the given relation id.
1201  */
1202 int
ShardIntervalCount(Oid relationId)
1204 {
1205 	CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
1206 
1207 	return cacheEntry->shardIntervalArrayLength;
1208 }
1209 
1210 
1211 /*
1212  * LoadShardList reads list of shards for given relationId from pg_dist_shard,
1213  * and returns the list of found shardIds.
1214  * Since LoadShardList relies on sortedShardIntervalArray, it returns a shard
1215  * list whose elements are sorted on shardminvalue. Shards with uninitialized
 * shard min/max values are placed at the end of the list.
1217  */
1218 List *
LoadShardList(Oid relationId)
1220 {
1221 	CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId);
1222 	List *shardList = NIL;
1223 
1224 	for (int i = 0; i < cacheEntry->shardIntervalArrayLength; i++)
1225 	{
1226 		ShardInterval *currentShardInterval = cacheEntry->sortedShardIntervalArray[i];
1227 		uint64 *shardIdPointer = AllocateUint64(currentShardInterval->shardId);
1228 
1229 		shardList = lappend(shardList, shardIdPointer);
1230 	}
1231 
1232 	return shardList;
1233 }
1234 
1235 
/* AllocateUint64 allocates eight bytes and copies the given value into those bytes. */
1237 static uint64 *
AllocateUint64(uint64 value)
1239 {
1240 	uint64 *allocatedValue = (uint64 *) palloc0(sizeof(uint64));
1241 	Assert(sizeof(uint64) >= 8);
1242 
1243 	(*allocatedValue) = value;
1244 
1245 	return allocatedValue;
1246 }
1247 
1248 
1249 /*
1250  * CopyShardInterval creates a copy of the specified source ShardInterval.
1251  */
1252 ShardInterval *
CopyShardInterval(ShardInterval *srcInterval)
1254 {
1255 	ShardInterval *destInterval = palloc0(sizeof(ShardInterval));
1256 
1257 	destInterval->type = srcInterval->type;
1258 	destInterval->relationId = srcInterval->relationId;
1259 	destInterval->storageType = srcInterval->storageType;
1260 	destInterval->valueTypeId = srcInterval->valueTypeId;
1261 	destInterval->valueTypeLen = srcInterval->valueTypeLen;
1262 	destInterval->valueByVal = srcInterval->valueByVal;
1263 	destInterval->minValueExists = srcInterval->minValueExists;
1264 	destInterval->maxValueExists = srcInterval->maxValueExists;
1265 	destInterval->shardId = srcInterval->shardId;
1266 	destInterval->shardIndex = srcInterval->shardIndex;
1267 
1268 	destInterval->minValue = 0;
1269 	if (destInterval->minValueExists)
1270 	{
1271 		destInterval->minValue = datumCopy(srcInterval->minValue,
1272 										   srcInterval->valueByVal,
1273 										   srcInterval->valueTypeLen);
1274 	}
1275 
1276 	destInterval->maxValue = 0;
1277 	if (destInterval->maxValueExists)
1278 	{
1279 		destInterval->maxValue = datumCopy(srcInterval->maxValue,
1280 										   srcInterval->valueByVal,
1281 										   srcInterval->valueTypeLen);
1282 	}
1283 
1284 	return destInterval;
1285 }
1286 
1287 
1288 /*
1289  * ShardLength finds shard placements for the given shardId, extracts the length
1290  * of an active shard, and returns the shard's length. This function errors
1291  * out if we cannot find any active shard placements for the given shardId.
1292  */
1293 uint64
ShardLength(uint64 shardId)
1295 {
1296 	uint64 shardLength = 0;
1297 
1298 	List *shardPlacementList = ActiveShardPlacementList(shardId);
1299 	if (shardPlacementList == NIL)
1300 	{
1301 		ereport(ERROR, (errmsg("could not find length of shard " UINT64_FORMAT, shardId),
1302 						errdetail("Could not find any shard placements for the shard.")));
1303 	}
1304 	else
1305 	{
1306 		ShardPlacement *shardPlacement = (ShardPlacement *) linitial(shardPlacementList);
1307 		shardLength = shardPlacement->shardLength;
1308 	}
1309 
1310 	return shardLength;
1311 }
1312 
1313 
1314 /*
1315  * NodeGroupHasLivePlacements returns true if there is any placement
1316  * on the given node group which is not a SHARD_STATE_TO_DELETE placement.
1317  */
1318 bool
NodeGroupHasLivePlacements(int32 groupId)
1320 {
1321 	List *shardPlacements = AllShardPlacementsOnNodeGroup(groupId);
1322 	GroupShardPlacement *placement = NULL;
1323 	foreach_ptr(placement, shardPlacements)
1324 	{
1325 		if (placement->shardState != SHARD_STATE_TO_DELETE)
1326 		{
1327 			return true;
1328 		}
1329 	}
1330 	return false;
1331 }
1332 
1333 
1334 /*
1335  * NodeGroupHasShardPlacements returns whether any active shards are placed on the group
1336  */
1337 bool
NodeGroupHasShardPlacements(int32 groupId, bool onlyConsiderActivePlacements)
1339 {
1340 	const int scanKeyCount = (onlyConsiderActivePlacements ? 2 : 1);
1341 	const bool indexOK = false;
1342 
1343 
1344 	ScanKeyData scanKey[2];
1345 
1346 	Relation pgPlacement = table_open(DistPlacementRelationId(),
1347 									  AccessShareLock);
1348 
1349 	ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_groupid,
1350 				BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(groupId));
1351 	if (onlyConsiderActivePlacements)
1352 	{
1353 		ScanKeyInit(&scanKey[1], Anum_pg_dist_placement_shardstate,
1354 					BTEqualStrategyNumber, F_INT4EQ,
1355 					Int32GetDatum(SHARD_STATE_ACTIVE));
1356 	}
1357 
1358 	SysScanDesc scanDescriptor = systable_beginscan(pgPlacement,
1359 													DistPlacementGroupidIndexId(),
1360 													indexOK,
1361 													NULL, scanKeyCount, scanKey);
1362 
1363 	HeapTuple heapTuple = systable_getnext(scanDescriptor);
1364 	bool hasActivePlacements = HeapTupleIsValid(heapTuple);
1365 
1366 	systable_endscan(scanDescriptor);
1367 	table_close(pgPlacement, NoLock);
1368 
1369 	return hasActivePlacements;
1370 }
1371 
1372 
1373 /*
1374  * ActiveShardPlacementListOnGroup returns a list of active shard placements
1375  * that are sitting on group with groupId for given shardId.
1376  */
1377 List *
ActiveShardPlacementListOnGroup(uint64 shardId, int32 groupId)
1379 {
1380 	List *activeShardPlacementListOnGroup = NIL;
1381 
1382 	List *activePlacementList = ActiveShardPlacementList(shardId);
1383 	ShardPlacement *shardPlacement = NULL;
1384 	foreach_ptr(shardPlacement, activePlacementList)
1385 	{
1386 		if (shardPlacement->groupId == groupId)
1387 		{
1388 			activeShardPlacementListOnGroup = lappend(activeShardPlacementListOnGroup,
1389 													  shardPlacement);
1390 		}
1391 	}
1392 
1393 	return activeShardPlacementListOnGroup;
1394 }
1395 
1396 
1397 /*
1398  * ActiveShardPlacementList finds shard placements for the given shardId from
1399  * system catalogs, chooses placements that are in active state, and returns
1400  * these shard placements in a new list.
1401  */
1402 List *
ActiveShardPlacementList(uint64 shardId)
1404 {
1405 	List *activePlacementList = NIL;
1406 	List *shardPlacementList =
1407 		ShardPlacementListIncludingOrphanedPlacements(shardId);
1408 
1409 	ShardPlacement *shardPlacement = NULL;
1410 	foreach_ptr(shardPlacement, shardPlacementList)
1411 	{
1412 		if (shardPlacement->shardState == SHARD_STATE_ACTIVE)
1413 		{
1414 			activePlacementList = lappend(activePlacementList, shardPlacement);
1415 		}
1416 	}
1417 
1418 	return SortList(activePlacementList, CompareShardPlacementsByWorker);
1419 }
1420 
1421 
1422 /*
 * ShardPlacementListWithoutOrphanedPlacements returns shard placements excluding
1424  * the ones that are orphaned, because they are marked to be deleted at a later
1425  * point (shardstate = 4).
1426  */
1427 List *
ShardPlacementListWithoutOrphanedPlacements(uint64 shardId)
1429 {
1430 	List *activePlacementList = NIL;
1431 	List *shardPlacementList =
1432 		ShardPlacementListIncludingOrphanedPlacements(shardId);
1433 
1434 	ShardPlacement *shardPlacement = NULL;
1435 	foreach_ptr(shardPlacement, shardPlacementList)
1436 	{
1437 		if (shardPlacement->shardState != SHARD_STATE_TO_DELETE)
1438 		{
1439 			activePlacementList = lappend(activePlacementList, shardPlacement);
1440 		}
1441 	}
1442 
1443 	return SortList(activePlacementList, CompareShardPlacementsByWorker);
1444 }
1445 
1446 
1447 /*
1448  * ActiveShardPlacement finds a shard placement for the given shardId from
1449  * system catalog, chooses a placement that is in active state and returns
1450  * that shard placement. If this function cannot find a healthy shard placement
1451  * and missingOk is set to false it errors out.
1452  */
1453 ShardPlacement *
ActiveShardPlacement(uint64 shardId, bool missingOk)
1455 {
1456 	List *activePlacementList = ActiveShardPlacementList(shardId);
1457 	ShardPlacement *shardPlacement = NULL;
1458 
1459 	if (list_length(activePlacementList) == 0)
1460 	{
1461 		if (!missingOk)
1462 		{
1463 			ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1464 							errmsg("could not find any healthy placement for shard "
1465 								   UINT64_FORMAT, shardId)));
1466 		}
1467 
1468 		return shardPlacement;
1469 	}
1470 
1471 	shardPlacement = (ShardPlacement *) linitial(activePlacementList);
1472 
1473 	return shardPlacement;
1474 }
1475 
1476 
1477 /*
1478  * BuildShardPlacementList finds shard placements for the given shardId from
1479  * system catalogs, converts these placements to their in-memory
1480  * representation, and returns the converted shard placements in a new list.
1481  *
 * This should probably only be called from metadata_cache.c. Resides here
1483  * because it shares code with other routines in this file.
1484  */
1485 List *
BuildShardPlacementList(int64 shardId)
1487 {
1488 	List *shardPlacementList = NIL;
1489 	ScanKeyData scanKey[1];
1490 	int scanKeyCount = 1;
1491 	bool indexOK = true;
1492 
1493 	Relation pgPlacement = table_open(DistPlacementRelationId(), AccessShareLock);
1494 
1495 	ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_shardid,
1496 				BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId));
1497 
1498 	SysScanDesc scanDescriptor = systable_beginscan(pgPlacement,
1499 													DistPlacementShardidIndexId(),
1500 													indexOK,
1501 													NULL, scanKeyCount, scanKey);
1502 
1503 	HeapTuple heapTuple = systable_getnext(scanDescriptor);
1504 	while (HeapTupleIsValid(heapTuple))
1505 	{
1506 		TupleDesc tupleDescriptor = RelationGetDescr(pgPlacement);
1507 
1508 		GroupShardPlacement *placement =
1509 			TupleToGroupShardPlacement(tupleDescriptor, heapTuple);
1510 
1511 		shardPlacementList = lappend(shardPlacementList, placement);
1512 
1513 		heapTuple = systable_getnext(scanDescriptor);
1514 	}
1515 
1516 	systable_endscan(scanDescriptor);
1517 	table_close(pgPlacement, NoLock);
1518 
1519 	return shardPlacementList;
1520 }
1521 
1522 
1523 /*
 * AllShardPlacementsOnNodeGroup finds shard placements for the given groupId
1525  * from system catalogs, converts these placements to their in-memory
1526  * representation, and returns the converted shard placements in a new list.
1527  */
1528 List *
AllShardPlacementsOnNodeGroup(int32 groupId)
1530 {
1531 	List *shardPlacementList = NIL;
1532 	ScanKeyData scanKey[1];
1533 	int scanKeyCount = 1;
1534 	bool indexOK = true;
1535 
1536 	Relation pgPlacement = table_open(DistPlacementRelationId(), AccessShareLock);
1537 
1538 	ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_groupid,
1539 				BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(groupId));
1540 
1541 	SysScanDesc scanDescriptor = systable_beginscan(pgPlacement,
1542 													DistPlacementGroupidIndexId(),
1543 													indexOK,
1544 													NULL, scanKeyCount, scanKey);
1545 
1546 	HeapTuple heapTuple = systable_getnext(scanDescriptor);
1547 	while (HeapTupleIsValid(heapTuple))
1548 	{
1549 		TupleDesc tupleDescriptor = RelationGetDescr(pgPlacement);
1550 
1551 		GroupShardPlacement *placement =
1552 			TupleToGroupShardPlacement(tupleDescriptor, heapTuple);
1553 
1554 		shardPlacementList = lappend(shardPlacementList, placement);
1555 
1556 		heapTuple = systable_getnext(scanDescriptor);
1557 	}
1558 
1559 	systable_endscan(scanDescriptor);
1560 	table_close(pgPlacement, NoLock);
1561 
1562 	return shardPlacementList;
1563 }
1564 
1565 
1566 /*
1567  * AllShardPlacementsWithShardPlacementState finds shard placements with the given
1568  * shardState from system catalogs, converts these placements to their in-memory
1569  * representation, and returns the converted shard placements in a new list.
1570  */
1571 List *
AllShardPlacementsWithShardPlacementState(ShardState shardState)
1573 {
1574 	List *shardPlacementList = NIL;
1575 	ScanKeyData scanKey[1];
1576 	int scanKeyCount = 1;
1577 
1578 	Relation pgPlacement = table_open(DistPlacementRelationId(), AccessShareLock);
1579 
1580 	ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_shardstate,
1581 				BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(shardState));
1582 
1583 	SysScanDesc scanDescriptor = systable_beginscan(pgPlacement, InvalidOid, false,
1584 													NULL, scanKeyCount, scanKey);
1585 
1586 	HeapTuple heapTuple = systable_getnext(scanDescriptor);
1587 	while (HeapTupleIsValid(heapTuple))
1588 	{
1589 		TupleDesc tupleDescriptor = RelationGetDescr(pgPlacement);
1590 
1591 		GroupShardPlacement *placement =
1592 			TupleToGroupShardPlacement(tupleDescriptor, heapTuple);
1593 
1594 		shardPlacementList = lappend(shardPlacementList, placement);
1595 
1596 		heapTuple = systable_getnext(scanDescriptor);
1597 	}
1598 
1599 	systable_endscan(scanDescriptor);
1600 	table_close(pgPlacement, NoLock);
1601 
1602 	return shardPlacementList;
1603 }
1604 
1605 
1606 /*
1607  * TupleToGroupShardPlacement takes in a heap tuple from pg_dist_placement,
1608  * and converts this tuple to in-memory struct. The function assumes the
1609  * caller already has locks on the tuple, and doesn't perform any locking.
1610  */
1611 static GroupShardPlacement *
TupleToGroupShardPlacement(TupleDesc tupleDescriptor, HeapTuple heapTuple)
1613 {
1614 	bool isNullArray[Natts_pg_dist_placement];
1615 	Datum datumArray[Natts_pg_dist_placement];
1616 
1617 	if (HeapTupleHeaderGetNatts(heapTuple->t_data) != Natts_pg_dist_placement ||
1618 		HeapTupleHasNulls(heapTuple))
1619 	{
1620 		ereport(ERROR, (errmsg("unexpected null in pg_dist_placement tuple")));
1621 	}
1622 
1623 	/*
1624 	 * We use heap_deform_tuple() instead of heap_getattr() to expand tuple
1625 	 * to contain missing values when ALTER TABLE ADD COLUMN happens.
1626 	 */
1627 	heap_deform_tuple(heapTuple, tupleDescriptor, datumArray, isNullArray);
1628 
1629 	GroupShardPlacement *shardPlacement = CitusMakeNode(GroupShardPlacement);
1630 	shardPlacement->placementId = DatumGetInt64(
1631 		datumArray[Anum_pg_dist_placement_placementid - 1]);
1632 	shardPlacement->shardId = DatumGetInt64(
1633 		datumArray[Anum_pg_dist_placement_shardid - 1]);
1634 	shardPlacement->shardLength = DatumGetInt64(
1635 		datumArray[Anum_pg_dist_placement_shardlength - 1]);
1636 	shardPlacement->shardState = DatumGetUInt32(
1637 		datumArray[Anum_pg_dist_placement_shardstate - 1]);
1638 	shardPlacement->groupId = DatumGetInt32(
1639 		datumArray[Anum_pg_dist_placement_groupid - 1]);
1640 
1641 	return shardPlacement;
1642 }


/*
 * InsertShardRow opens the shard system catalog, and inserts a new row with the
 * given values into that system catalog. Note that we allow the user to pass in
 * null min/max values in case they are creating an empty shard.
 */
void
InsertShardRow(Oid relationId, uint64 shardId, char storageType,
			   text *shardMinValue, text *shardMaxValue)
{
	Datum values[Natts_pg_dist_shard];
	bool isNulls[Natts_pg_dist_shard];

	/* form new shard tuple */
	memset(values, 0, sizeof(values));
	memset(isNulls, false, sizeof(isNulls));

	values[Anum_pg_dist_shard_logicalrelid - 1] = ObjectIdGetDatum(relationId);
	values[Anum_pg_dist_shard_shardid - 1] = Int64GetDatum(shardId);
	values[Anum_pg_dist_shard_shardstorage - 1] = CharGetDatum(storageType);

	/* dropped shardalias column must also be set; it is still part of the tuple */
	isNulls[Anum_pg_dist_shard_shardalias_DROPPED - 1] = true;

	/* check if shard min/max values are null */
	if (shardMinValue != NULL && shardMaxValue != NULL)
	{
		values[Anum_pg_dist_shard_shardminvalue - 1] = PointerGetDatum(shardMinValue);
		values[Anum_pg_dist_shard_shardmaxvalue - 1] = PointerGetDatum(shardMaxValue);
	}
	else
	{
		isNulls[Anum_pg_dist_shard_shardminvalue - 1] = true;
		isNulls[Anum_pg_dist_shard_shardmaxvalue - 1] = true;
	}

	/* open shard relation and insert new tuple */
	Relation pgDistShard = table_open(DistShardRelationId(), RowExclusiveLock);

	TupleDesc tupleDescriptor = RelationGetDescr(pgDistShard);
	HeapTuple heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls);

	CatalogTupleInsert(pgDistShard, heapTuple);

	/* invalidate previous cache entry and close relation */
	CitusInvalidateRelcacheByRelid(relationId);

	CommandCounterIncrement();
	table_close(pgDistShard, NoLock);
}
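
/*
 * Usage sketch (illustrative only, not part of the original file): registering
 * a hash-distributed shard whose hash range is [-2147483648, -1]. The
 * relationId and shardId values are assumed to come from the caller; the
 * min/max values are passed as text, matching pg_dist_shard's column types.
 *
 *     text *minValue = cstring_to_text("-2147483648");
 *     text *maxValue = cstring_to_text("-1");
 *     InsertShardRow(relationId, shardId, SHARD_STORAGE_TABLE, minValue, maxValue);
 */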


/*
 * InsertShardPlacementRow opens the shard placement system catalog, and inserts
 * a new row with the given values into that system catalog. If placementId is
 * INVALID_PLACEMENT_ID, a new placement id will be assigned. The function then
 * returns the placement id of the added shard placement.
 */
uint64
InsertShardPlacementRow(uint64 shardId, uint64 placementId,
						char shardState, uint64 shardLength,
						int32 groupId)
{
	Datum values[Natts_pg_dist_placement];
	bool isNulls[Natts_pg_dist_placement];

	/* form new shard placement tuple */
	memset(values, 0, sizeof(values));
	memset(isNulls, false, sizeof(isNulls));

	if (placementId == INVALID_PLACEMENT_ID)
	{
		placementId = master_get_new_placementid(NULL);
	}
	values[Anum_pg_dist_placement_placementid - 1] = Int64GetDatum(placementId);
	values[Anum_pg_dist_placement_shardid - 1] = Int64GetDatum(shardId);
	values[Anum_pg_dist_placement_shardstate - 1] = CharGetDatum(shardState);
	values[Anum_pg_dist_placement_shardlength - 1] = Int64GetDatum(shardLength);
	values[Anum_pg_dist_placement_groupid - 1] = Int32GetDatum(groupId);

	/* open shard placement relation and insert new tuple */
	Relation pgDistPlacement = table_open(DistPlacementRelationId(), RowExclusiveLock);

	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPlacement);
	HeapTuple heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls);

	CatalogTupleInsert(pgDistPlacement, heapTuple);

	CitusInvalidateRelcacheByShardId(shardId);

	CommandCounterIncrement();
	table_close(pgDistPlacement, NoLock);

	return placementId;
}
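
/*
 * Usage sketch (illustrative only, not part of the original file): adding an
 * active placement of an existing shard on group 2, letting the function pick
 * a fresh placement id. The shard length of 0 is an assumption for a placement
 * whose size is not yet known.
 *
 *     uint64 newPlacementId =
 *         InsertShardPlacementRow(shardId, INVALID_PLACEMENT_ID,
 *                                 SHARD_STATE_ACTIVE, 0, 2);
 */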


/*
 * InsertIntoPgDistPartition inserts a new tuple into pg_dist_partition.
 */
void
InsertIntoPgDistPartition(Oid relationId, char distributionMethod,
						  Var *distributionColumn, uint32 colocationId,
						  char replicationModel)
{
	char *distributionColumnString = NULL;

	Datum newValues[Natts_pg_dist_partition];
	bool newNulls[Natts_pg_dist_partition];

	/* open system catalog and insert new tuple */
	Relation pgDistPartition = table_open(DistPartitionRelationId(), RowExclusiveLock);

	/* form new tuple for pg_dist_partition */
	memset(newValues, 0, sizeof(newValues));
	memset(newNulls, false, sizeof(newNulls));

	newValues[Anum_pg_dist_partition_logicalrelid - 1] =
		ObjectIdGetDatum(relationId);
	newValues[Anum_pg_dist_partition_partmethod - 1] =
		CharGetDatum(distributionMethod);
	newValues[Anum_pg_dist_partition_colocationid - 1] = UInt32GetDatum(colocationId);
	newValues[Anum_pg_dist_partition_repmodel - 1] = CharGetDatum(replicationModel);

	/* set partkey column to NULL for reference tables */
	if (distributionMethod != DISTRIBUTE_BY_NONE)
	{
		distributionColumnString = nodeToString((Node *) distributionColumn);

		newValues[Anum_pg_dist_partition_partkey - 1] =
			CStringGetTextDatum(distributionColumnString);
	}
	else
	{
		newValues[Anum_pg_dist_partition_partkey - 1] = PointerGetDatum(NULL);
		newNulls[Anum_pg_dist_partition_partkey - 1] = true;
	}

	HeapTuple newTuple = heap_form_tuple(RelationGetDescr(pgDistPartition), newValues,
										 newNulls);

	/* finally insert tuple, build index entries & register cache invalidation */
	CatalogTupleInsert(pgDistPartition, newTuple);

	CitusInvalidateRelcacheByRelid(relationId);

	RecordDistributedRelationDependencies(relationId);

	CommandCounterIncrement();
	table_close(pgDistPartition, NoLock);
}
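
/*
 * Usage sketch (illustrative only, not part of the original file): registering
 * a reference table, which has no distribution column, hence the NULL partkey.
 * The colocationId value is assumed to come from the colocation machinery.
 *
 *     InsertIntoPgDistPartition(relationId, DISTRIBUTE_BY_NONE,
 *                               NULL, colocationId, REPLICATION_MODEL_2PC);
 */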


/*
 * RecordDistributedRelationDependencies creates the dependency entries
 * necessary for a distributed relation in addition to the preexisting ones
 * for a normal relation.
 *
 * We create one dependency from the (now distributed) relation to the citus
 * extension to prevent the extension from being dropped while distributed
 * tables exist. Furthermore a dependency from pg_dist_partition's
 * distribution clause to the underlying columns is created, but it's marked
 * as being owned by the relation itself. That means the entire table can be
 * dropped, but the column itself can't. Neither can the type of the
 * distribution column be changed (c.f. ATExecAlterColumnType).
 */
static void
RecordDistributedRelationDependencies(Oid distributedRelationId)
{
	ObjectAddress relationAddr = { 0, 0, 0 };
	ObjectAddress citusExtensionAddr = { 0, 0, 0 };

	relationAddr.classId = RelationRelationId;
	relationAddr.objectId = distributedRelationId;
	relationAddr.objectSubId = 0;

	citusExtensionAddr.classId = ExtensionRelationId;
	citusExtensionAddr.objectId = get_extension_oid("citus", false);
	citusExtensionAddr.objectSubId = 0;

	/* dependency from table entry to extension */
	recordDependencyOn(&relationAddr, &citusExtensionAddr, DEPENDENCY_NORMAL);
}
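
/*
 * Note (illustrative, not part of the original file): the practical effect of
 * the pg_depend entry recorded above is that, once any table is distributed,
 * a plain "DROP EXTENSION citus" is rejected by PostgreSQL's standard
 * dependency checks unless CASCADE is used.
 */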


/*
 * DeletePartitionRow removes the row from pg_dist_partition where the logicalrelid
 * field equals distributedRelationId. Then, the function invalidates the
 * metadata cache.
 */
void
DeletePartitionRow(Oid distributedRelationId)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 1;

	Relation pgDistPartition = table_open(DistPartitionRelationId(), RowExclusiveLock);

	ScanKeyInit(&scanKey[0], Anum_pg_dist_partition_logicalrelid,
				BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(distributedRelationId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistPartition, InvalidOid, false,
													NULL,
													scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (!HeapTupleIsValid(heapTuple))
	{
		ereport(ERROR, (errmsg("could not find valid entry for partition %u",
							   distributedRelationId)));
	}

	simple_heap_delete(pgDistPartition, &heapTuple->t_self);

	systable_endscan(scanDescriptor);

	/* invalidate the cache */
	CitusInvalidateRelcacheByRelid(distributedRelationId);

	/* increment the counter so that the next command can see the change */
	CommandCounterIncrement();

	table_close(pgDistPartition, NoLock);
}


/*
 * DeleteShardRow opens the shard system catalog, finds the unique row that has
 * the given shardId, and deletes this row.
 */
void
DeleteShardRow(uint64 shardId)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 1;
	bool indexOK = true;

	Relation pgDistShard = table_open(DistShardRelationId(), RowExclusiveLock);

	ScanKeyInit(&scanKey[0], Anum_pg_dist_shard_shardid,
				BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(shardId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistShard,
													DistShardShardidIndexId(), indexOK,
													NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (!HeapTupleIsValid(heapTuple))
	{
		ereport(ERROR, (errmsg("could not find valid entry for shard "
							   UINT64_FORMAT, shardId)));
	}

	Form_pg_dist_shard pgDistShardForm = (Form_pg_dist_shard) GETSTRUCT(heapTuple);
	Oid distributedRelationId = pgDistShardForm->logicalrelid;

	simple_heap_delete(pgDistShard, &heapTuple->t_self);

	systable_endscan(scanDescriptor);

	/* invalidate previous cache entry */
	CitusInvalidateRelcacheByRelid(distributedRelationId);

	CommandCounterIncrement();
	table_close(pgDistShard, NoLock);
}
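
/*
 * Usage sketch (illustrative only, not part of the original file): dropping a
 * shard's metadata typically removes its placement rows before the shard row
 * itself, within the same transaction. The shardId and placementId values are
 * assumed to come from a prior metadata lookup.
 *
 *     DeleteShardPlacementRow(placementId);
 *     DeleteShardRow(shardId);
 */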


/*
 * DeleteShardPlacementRow opens the shard placement system catalog, finds the placement
 * with the given placementId, and deletes it.
 */
void
DeleteShardPlacementRow(uint64 placementId)
{
	const int scanKeyCount = 1;
	ScanKeyData scanKey[1];
	bool indexOK = true;
	bool isNull = false;

	Relation pgDistPlacement = table_open(DistPlacementRelationId(), RowExclusiveLock);
	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPlacement);

	ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_placementid,
				BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(placementId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistPlacement,
													DistPlacementPlacementidIndexId(),
													indexOK,
													NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (heapTuple == NULL)
	{
		ereport(ERROR, (errmsg("could not find valid entry for shard placement "
							   INT64_FORMAT, placementId)));
	}

	uint64 shardId = heap_getattr(heapTuple, Anum_pg_dist_placement_shardid,
								  tupleDescriptor, &isNull);
	if (HeapTupleHeaderGetNatts(heapTuple->t_data) != Natts_pg_dist_placement ||
		HeapTupleHasNulls(heapTuple))
	{
		ereport(ERROR, (errmsg("unexpected null in pg_dist_placement tuple")));
	}

	simple_heap_delete(pgDistPlacement, &heapTuple->t_self);
	systable_endscan(scanDescriptor);

	CitusInvalidateRelcacheByShardId(shardId);

	CommandCounterIncrement();
	table_close(pgDistPlacement, NoLock);
}


/*
 * UpdatePartitionShardPlacementStates gets a shard placement which is asserted to
 * belong to a partitioned table. The function goes over the corresponding placements
 * of its partitions, and sets their state to the input shardState.
 */
void
UpdatePartitionShardPlacementStates(ShardPlacement *parentShardPlacement, char shardState)
{
	ShardInterval *parentShardInterval =
		LoadShardInterval(parentShardPlacement->shardId);
	Oid partitionedTableOid = parentShardInterval->relationId;

	/* this function should only be called for partitioned tables */
	Assert(PartitionedTable(partitionedTableOid));

	List *partitionList = PartitionList(partitionedTableOid);
	Oid partitionOid = InvalidOid;
	foreach_oid(partitionOid, partitionList)
	{
		uint64 partitionShardId =
			ColocatedShardIdInRelation(partitionOid, parentShardInterval->shardIndex);

		ShardPlacement *partitionPlacement =
			ShardPlacementOnGroupIncludingOrphanedPlacements(
				parentShardPlacement->groupId, partitionShardId);

		/* the partition should have a placement with the same group */
		Assert(partitionPlacement != NULL);

		UpdateShardPlacementState(partitionPlacement->placementId, shardState);
	}
}


/*
 * MarkShardPlacementInactive is a wrapper around UpdateShardPlacementState where
 * the state is set to SHARD_STATE_INACTIVE. It also marks partitions of the
 * shard placements as inactive if shardPlacement belongs to a partitioned table.
 */
void
MarkShardPlacementInactive(ShardPlacement *shardPlacement)
{
	UpdateShardPlacementState(shardPlacement->placementId, SHARD_STATE_INACTIVE);

	/*
	 * In case the shard belongs to a partitioned table, we make sure to update
	 * the states of its partitions. Repairing a shard already ensures that all
	 * of its partitions are recreated.
	 */
	ShardInterval *shardInterval = LoadShardInterval(shardPlacement->shardId);
	if (PartitionedTable(shardInterval->relationId))
	{
		UpdatePartitionShardPlacementStates(shardPlacement, SHARD_STATE_INACTIVE);
	}
}
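
/*
 * Usage sketch (illustrative only, not part of the original file): when a write
 * to one placement fails while writes to the others succeed, the failed
 * placement can be marked inactive so that it is repaired later. The
 * writeSucceeded flag is a hypothetical result of the remote execution.
 *
 *     if (!writeSucceeded)
 *     {
 *         MarkShardPlacementInactive(shardPlacement);
 *     }
 */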


/*
 * UpdateShardPlacementState sets the shardState for the placement identified
 * by placementId.
 */
void
UpdateShardPlacementState(uint64 placementId, char shardState)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 1;
	bool indexOK = true;
	Datum values[Natts_pg_dist_placement];
	bool isnull[Natts_pg_dist_placement];
	bool replace[Natts_pg_dist_placement];
	bool colIsNull = false;

	Relation pgDistPlacement = table_open(DistPlacementRelationId(), RowExclusiveLock);
	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPlacement);
	ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_placementid,
				BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(placementId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistPlacement,
													DistPlacementPlacementidIndexId(),
													indexOK,
													NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (!HeapTupleIsValid(heapTuple))
	{
		ereport(ERROR, (errmsg("could not find valid entry for shard placement "
							   UINT64_FORMAT,
							   placementId)));
	}

	memset(replace, 0, sizeof(replace));

	values[Anum_pg_dist_placement_shardstate - 1] = CharGetDatum(shardState);
	isnull[Anum_pg_dist_placement_shardstate - 1] = false;
	replace[Anum_pg_dist_placement_shardstate - 1] = true;

	heapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isnull, replace);

	CatalogTupleUpdate(pgDistPlacement, &heapTuple->t_self, heapTuple);

	uint64 shardId = DatumGetInt64(heap_getattr(heapTuple,
												Anum_pg_dist_placement_shardid,
												tupleDescriptor, &colIsNull));
	Assert(!colIsNull);
	CitusInvalidateRelcacheByShardId(shardId);

	CommandCounterIncrement();

	systable_endscan(scanDescriptor);
	table_close(pgDistPlacement, NoLock);
}
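
/*
 * Usage sketch (illustrative only, not part of the original file): after a
 * placement has been repaired by re-copying its data from a healthy placement,
 * it can be flipped back to the active state.
 *
 *     UpdateShardPlacementState(placementId, SHARD_STATE_ACTIVE);
 */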


/*
 * UpdatePlacementGroupId sets the groupId for the placement identified
 * by placementId.
 */
void
UpdatePlacementGroupId(uint64 placementId, int groupId)
{
	ScanKeyData scanKey[1];
	int scanKeyCount = 1;
	bool indexOK = true;
	Datum values[Natts_pg_dist_placement];
	bool isnull[Natts_pg_dist_placement];
	bool replace[Natts_pg_dist_placement];
	bool colIsNull = false;

	Relation pgDistPlacement = table_open(DistPlacementRelationId(), RowExclusiveLock);
	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPlacement);
	ScanKeyInit(&scanKey[0], Anum_pg_dist_placement_placementid,
				BTEqualStrategyNumber, F_INT8EQ, Int64GetDatum(placementId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistPlacement,
													DistPlacementPlacementidIndexId(),
													indexOK,
													NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (!HeapTupleIsValid(heapTuple))
	{
		ereport(ERROR, (errmsg("could not find valid entry for shard placement "
							   UINT64_FORMAT,
							   placementId)));
	}

	memset(replace, 0, sizeof(replace));

	values[Anum_pg_dist_placement_groupid - 1] = Int32GetDatum(groupId);
	isnull[Anum_pg_dist_placement_groupid - 1] = false;
	replace[Anum_pg_dist_placement_groupid - 1] = true;

	heapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isnull, replace);

	CatalogTupleUpdate(pgDistPlacement, &heapTuple->t_self, heapTuple);

	uint64 shardId = DatumGetInt64(heap_getattr(heapTuple,
												Anum_pg_dist_placement_shardid,
												tupleDescriptor, &colIsNull));
	Assert(!colIsNull);
	CitusInvalidateRelcacheByShardId(shardId);

	CommandCounterIncrement();

	systable_endscan(scanDescriptor);
	table_close(pgDistPlacement, NoLock);
}


/*
 * Check that the current user has `mode` permissions on relationId, error out
 * if not. Superusers always have such permissions.
 */
void
EnsureTablePermissions(Oid relationId, AclMode mode)
{
	AclResult aclresult = pg_class_aclcheck(relationId, GetUserId(), mode);

	if (aclresult != ACLCHECK_OK)
	{
		aclcheck_error(aclresult, OBJECT_TABLE, get_rel_name(relationId));
	}
}
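
/*
 * Usage sketch (illustrative only, not part of the original file): before
 * routing a read-only query to a distributed table's shards, a caller could
 * verify SELECT rights on the table; the check errors out for unprivileged
 * users and silently passes for the owner or a superuser.
 *
 *     EnsureTablePermissions(relationId, ACL_SELECT);
 */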


/*
 * Check that the current user has owner rights to relationId, error out if
 * not. Superusers are regarded as owners.
 */
void
EnsureTableOwner(Oid relationId)
{
	if (!pg_class_ownercheck(relationId, GetUserId()))
	{
		aclcheck_error(ACLCHECK_NOT_OWNER, OBJECT_TABLE,
					   get_rel_name(relationId));
	}
}


/*
 * Check that the current user has owner rights to the schema, error out if
 * not. Superusers are regarded as owners.
 */
void
EnsureSchemaOwner(Oid schemaId)
{
	if (!pg_namespace_ownercheck(schemaId, GetUserId()))
	{
		aclcheck_error(ACLCHECK_NOT_OWNER, OBJECT_SCHEMA,
					   get_namespace_name(schemaId));
	}
}


/*
 * Check that the current user has owner rights to functionId, error out if
 * not. Superusers are regarded as owners. Functions and procedures are
 * treated equally.
 */
void
EnsureFunctionOwner(Oid functionId)
{
	if (!pg_proc_ownercheck(functionId, GetUserId()))
	{
		aclcheck_error(ACLCHECK_NOT_OWNER, OBJECT_FUNCTION,
					   get_func_name(functionId));
	}
}


/*
 * EnsureHashDistributedTable errors out if the given relation is not a
 * hash distributed table.
 */
void
EnsureHashDistributedTable(Oid relationId)
{
	if (!IsCitusTableType(relationId, HASH_DISTRIBUTED))
	{
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("relation %s should be a "
							   "hash distributed table", get_rel_name(relationId))));
	}
}


/*
 * EnsureSuperUser checks that the current user is a superuser and errors out if not.
 */
void
EnsureSuperUser(void)
{
	if (!superuser())
	{
		ereport(ERROR, (errmsg("operation is not allowed"),
						errhint("Run the command with a superuser.")));
	}
}


/*
 * Return a table's owner as a string.
 */
char *
TableOwner(Oid relationId)
{
	HeapTuple tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relationId));
	if (!HeapTupleIsValid(tuple))
	{
		ereport(ERROR, (errcode(ERRCODE_UNDEFINED_TABLE),
						errmsg("relation with OID %u does not exist", relationId)));
	}

	Oid userId = ((Form_pg_class) GETSTRUCT(tuple))->relowner;

	ReleaseSysCache(tuple);

	return GetUserNameFromId(userId, false);
}
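

/*
 * Usage sketch (illustrative only, not part of the original file): DDL that is
 * propagated to worker nodes is typically executed as the distributed table's
 * owner. SendCommandToWorkerAsUser() is assumed to be available for this
 * purpose; the nodeName, nodePort, and ddlCommand values are hypothetical.
 *
 *     char *tableOwner = TableOwner(relationId);
 *     SendCommandToWorkerAsUser(nodeName, nodePort, tableOwner, ddlCommand);
 */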