/*-------------------------------------------------------------------------
 *
 * colocation_utils.c
 *
 * This file contains functions to perform useful operations on co-located tables.
 *
 * Copyright (c) Citus Data, Inc.
 *
 *-------------------------------------------------------------------------
 */

#include "postgres.h"
#include "miscadmin.h"

#include "access/genam.h"
#include "access/heapam.h"
#include "access/htup_details.h"
#include "access/xact.h"
#include "catalog/indexing.h"
#include "catalog/pg_type.h"
#include "commands/sequence.h"
#include "distributed/colocation_utils.h"
#include "distributed/listutils.h"
#include "distributed/metadata_utility.h"
#include "distributed/coordinator_protocol.h"
#include "distributed/metadata_cache.h"
#include "distributed/metadata_sync.h"
#include "distributed/multi_logical_planner.h"
#include "distributed/multi_partitioning_utils.h"
#include "distributed/pg_dist_colocation.h"
#include "distributed/resource_lock.h"
#include "distributed/shardinterval_utils.h"
#include "distributed/version_compat.h"
#include "distributed/worker_protocol.h"
#include "distributed/worker_transaction.h"
#include "storage/lmgr.h"
#include "utils/builtins.h"
#include "utils/fmgroids.h"
#include "utils/lsyscache.h"
#include "utils/rel.h"


/* local function forward declarations */
static void MarkTablesColocated(Oid sourceRelationId, Oid targetRelationId);
static bool ShardsIntervalsEqual(ShardInterval *leftShardInterval,
								 ShardInterval *rightShardInterval);
static bool HashPartitionedShardIntervalsEqual(ShardInterval *leftShardInterval,
											   ShardInterval *rightShardInterval);
static int CompareShardPlacementsByNode(const void *leftElement,
										const void *rightElement);
static void DeleteColocationGroup(uint32 colocationId);
static uint32 CreateColocationGroupForRelation(Oid sourceRelationId);
static void BreakColocation(Oid sourceRelationId);

/* exports for SQL callable functions */
PG_FUNCTION_INFO_V1(mark_tables_colocated);
PG_FUNCTION_INFO_V1(get_colocated_shard_array);
PG_FUNCTION_INFO_V1(update_distributed_table_colocation);


/*
 * mark_tables_colocated puts the target tables into the same colocation group
 * as the source table. If the source table is in the INVALID_COLOCATION_ID
 * group, it first creates a new colocation group and assigns all tables to
 * this new colocation group.
 */
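/*
 * Example usage from SQL, assuming the usual UDF signature of
 * mark_tables_colocated(regclass, regclass[]); table names are purely
 * illustrative:
 *
 *   SELECT mark_tables_colocated('orders', ARRAY['order_items']::regclass[]);
 */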
Datum
mark_tables_colocated(PG_FUNCTION_ARGS)
{
	CheckCitusVersion(ERROR);
	EnsureCoordinator();

	Oid sourceRelationId = PG_GETARG_OID(0);
	ArrayType *relationIdArrayObject = PG_GETARG_ARRAYTYPE_P(1);

	int relationCount = ArrayObjectCount(relationIdArrayObject);
	if (relationCount < 1)
	{
		ereport(ERROR, (errmsg("at least one target table is required for this "
							   "operation")));
	}

	EnsureTableOwner(sourceRelationId);

	Datum *relationIdDatumArray = DeconstructArrayObject(relationIdArrayObject);

	for (int relationIndex = 0; relationIndex < relationCount; relationIndex++)
	{
		Oid nextRelationOid = DatumGetObjectId(relationIdDatumArray[relationIndex]);

		/* we require that the user either owns all tables or is superuser */
		EnsureTableOwner(nextRelationOid);

		MarkTablesColocated(sourceRelationId, nextRelationOid);
	}

	PG_RETURN_VOID();
}


/*
 * update_distributed_table_colocation updates the colocation of a table.
 * If colocate_with is the special value 'none', the table is assigned a
 * new colocation group of its own.
 */
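/*
 * Example usage from SQL, with illustrative table names (the UDF takes a
 * table and a colocate_with text argument):
 *
 *   SELECT update_distributed_table_colocation('orders', colocate_with => 'customers');
 *   SELECT update_distributed_table_colocation('orders', colocate_with => 'none');
 */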
Datum
update_distributed_table_colocation(PG_FUNCTION_ARGS)
{
	CheckCitusVersion(ERROR);
	EnsureCoordinator();

	Oid targetRelationId = PG_GETARG_OID(0);
	text *colocateWithTableNameText = PG_GETARG_TEXT_P(1);

	EnsureTableOwner(targetRelationId);

	char *colocateWithTableName = text_to_cstring(colocateWithTableNameText);
	if (IsColocateWithNone(colocateWithTableName))
	{
		EnsureHashDistributedTable(targetRelationId);
		BreakColocation(targetRelationId);
	}
	else
	{
		Oid colocateWithTableId = ResolveRelationId(colocateWithTableNameText, false);
		EnsureTableOwner(colocateWithTableId);
		MarkTablesColocated(colocateWithTableId, targetRelationId);
	}
	PG_RETURN_VOID();
}


/*
 * IsColocateWithNone returns true if the given colocate_with table name is
 * the special keyword "none".
 */
bool
IsColocateWithNone(char *colocateWithTableName)
{
	return pg_strncasecmp(colocateWithTableName, "none", NAMEDATALEN) == 0;
}


/*
 * BreakColocation breaks the colocations of the given relation id.
 * If t1, t2 and t3 are colocated and we call this function with t2,
 * t1 and t3 will stay colocated but t2 will get a new colocation id.
 * Note that this function does not move any data around for the new colocation.
 */
static void
BreakColocation(Oid sourceRelationId)
{
	/*
	 * Get an exclusive lock on the colocation system catalog. Therefore, we
	 * can be sure that there will be no modifications on the colocation table
	 * until this transaction is committed.
	 */
	Relation pgDistColocation = table_open(DistColocationRelationId(), ExclusiveLock);

	/* remember the old group, since we may need to drop it if it becomes empty */
	uint32 oldColocationId = TableColocationId(sourceRelationId);

	uint32 newColocationId = GetNextColocationId();
	bool localOnly = false;
	UpdateRelationColocationGroup(sourceRelationId, newColocationId, localOnly);

	/* if no tables remain in the old colocation group, delete it */
	DeleteColocationGroupIfNoTablesBelong(oldColocationId);

	table_close(pgDistColocation, NoLock);
}


/*
 * get_colocated_shard_array returns an array of shard ids that are co-located
 * with the given shard.
 */
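/*
 * Example usage from SQL (the shard id is illustrative):
 *
 *   SELECT get_colocated_shard_array(102008);
 */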
Datum
get_colocated_shard_array(PG_FUNCTION_ARGS)
{
	uint32 shardId = PG_GETARG_UINT32(0);
	ShardInterval *shardInterval = LoadShardInterval(shardId);

	List *colocatedShardList = ColocatedShardIntervalList(shardInterval);
	int colocatedShardCount = list_length(colocatedShardList);
	Datum *colocatedShardsDatumArray = palloc0(colocatedShardCount * sizeof(Datum));
	Oid arrayTypeId = OIDOID;
	int colocatedShardIndex = 0;

	ShardInterval *colocatedShardInterval = NULL;
	foreach_ptr(colocatedShardInterval, colocatedShardList)
	{
		uint64 colocatedShardId = colocatedShardInterval->shardId;

		Datum colocatedShardDatum = Int64GetDatum(colocatedShardId);

		colocatedShardsDatumArray[colocatedShardIndex] = colocatedShardDatum;
		colocatedShardIndex++;
	}

	ArrayType *colocatedShardsArrayType = DatumArrayToArrayType(colocatedShardsDatumArray,
																colocatedShardCount,
																arrayTypeId);

	PG_RETURN_ARRAYTYPE_P(colocatedShardsArrayType);
}


/*
 * CreateColocationGroupForRelation creates a colocation entry in
 * pg_dist_colocation and updates the colocation id in pg_dist_partition
 * for the given relation.
 */
static uint32
CreateColocationGroupForRelation(Oid sourceRelationId)
{
	uint32 shardCount = ShardIntervalCount(sourceRelationId);
	uint32 shardReplicationFactor = TableShardReplicationFactor(sourceRelationId);

	Var *sourceDistributionColumn = DistPartitionKey(sourceRelationId);
	Oid sourceDistributionColumnType = InvalidOid;
	Oid sourceDistributionColumnCollation = InvalidOid;

	/* reference tables have a NULL distribution column */
	if (sourceDistributionColumn != NULL)
	{
		sourceDistributionColumnType = sourceDistributionColumn->vartype;
		sourceDistributionColumnCollation = sourceDistributionColumn->varcollid;
	}

	uint32 sourceColocationId = CreateColocationGroup(shardCount, shardReplicationFactor,
													  sourceDistributionColumnType,
													  sourceDistributionColumnCollation);
	bool localOnly = false;
	UpdateRelationColocationGroup(sourceRelationId, sourceColocationId, localOnly);
	return sourceColocationId;
}


/*
 * MarkTablesColocated puts both tables into the same colocation group. If the
 * source table is in the INVALID_COLOCATION_ID group, it creates a new
 * colocation group and assigns both tables to it. Otherwise, it adds the
 * target table to the colocation group of the source table.
 */
static void
MarkTablesColocated(Oid sourceRelationId, Oid targetRelationId)
{
	if (IsCitusTableType(sourceRelationId, CITUS_LOCAL_TABLE) ||
		IsCitusTableType(targetRelationId, CITUS_LOCAL_TABLE))
	{
		ereport(ERROR, (errmsg("local tables cannot be colocated with "
							   "other tables")));
	}

	EnsureHashDistributedTable(sourceRelationId);
	EnsureHashDistributedTable(targetRelationId);
	CheckReplicationModel(sourceRelationId, targetRelationId);
	CheckDistributionColumnType(sourceRelationId, targetRelationId);

	/*
	 * Get an exclusive lock on the colocation system catalog. Therefore, we
	 * can be sure that there will be no modifications on the colocation table
	 * until this transaction is committed.
	 */
	Relation pgDistColocation = table_open(DistColocationRelationId(), ExclusiveLock);

	/* check if shard placements are colocated */
	ErrorIfShardPlacementsNotColocated(sourceRelationId, targetRelationId);

	/*
	 * Get the colocation group of the source table. If the source table does
	 * not have a colocation group, create a new one and set it for the source
	 * table.
	 */
	uint32 sourceColocationId = TableColocationId(sourceRelationId);
	if (sourceColocationId == INVALID_COLOCATION_ID)
	{
		sourceColocationId = CreateColocationGroupForRelation(sourceRelationId);
	}

	uint32 targetColocationId = TableColocationId(targetRelationId);

	/* finally set the colocation group for the target relation */
	bool localOnly = false;
	UpdateRelationColocationGroup(targetRelationId, sourceColocationId, localOnly);

	/* if no tables remain in the target's old colocation group, delete it */
	DeleteColocationGroupIfNoTablesBelong(targetColocationId);

	table_close(pgDistColocation, NoLock);
}


/*
 * ErrorIfShardPlacementsNotColocated checks if the shard placements of the
 * given two relations are physically colocated. It errors out in any of the
 * following cases:
 * 1. Shard counts are different,
 * 2. Shard intervals don't match,
 * 3. Matching shard intervals have a different number of shard placements,
 * 4. Shard placements are not colocated (not on the same node),
 * 5. Shard placements have different health states.
 *
 * Note that this function assumes that both tables are hash distributed.
 */
void
ErrorIfShardPlacementsNotColocated(Oid leftRelationId, Oid rightRelationId)
{
	ListCell *leftShardIntervalCell = NULL;
	ListCell *rightShardIntervalCell = NULL;

	/* get sorted shard interval lists for both tables */
	List *leftShardIntervalList = LoadShardIntervalList(leftRelationId);
	List *rightShardIntervalList = LoadShardIntervalList(rightRelationId);

	/* prevent concurrent placement changes */
	LockShardListMetadata(leftShardIntervalList, ShareLock);
	LockShardListMetadata(rightShardIntervalList, ShareLock);

	char *leftRelationName = get_rel_name(leftRelationId);
	char *rightRelationName = get_rel_name(rightRelationId);

	uint32 leftShardCount = list_length(leftShardIntervalList);
	uint32 rightShardCount = list_length(rightShardIntervalList);

	if (leftShardCount != rightShardCount)
	{
		ereport(ERROR, (errmsg("cannot colocate tables %s and %s",
							   leftRelationName, rightRelationName),
						errdetail("Shard counts don't match for %s and %s.",
								  leftRelationName, rightRelationName)));
	}

	/* compare shard intervals one by one */
	forboth(leftShardIntervalCell, leftShardIntervalList,
			rightShardIntervalCell, rightShardIntervalList)
	{
		ShardInterval *leftInterval = (ShardInterval *) lfirst(leftShardIntervalCell);
		ShardInterval *rightInterval = (ShardInterval *) lfirst(rightShardIntervalCell);

		ListCell *leftPlacementCell = NULL;
		ListCell *rightPlacementCell = NULL;

		uint64 leftShardId = leftInterval->shardId;
		uint64 rightShardId = rightInterval->shardId;

		bool shardsIntervalsEqual = ShardsIntervalsEqual(leftInterval, rightInterval);
		if (!shardsIntervalsEqual)
		{
			ereport(ERROR, (errmsg("cannot colocate tables %s and %s",
								   leftRelationName, rightRelationName),
							errdetail("Shard intervals don't match for %s and %s.",
									  leftRelationName, rightRelationName)));
		}

		List *leftPlacementList = ShardPlacementListWithoutOrphanedPlacements(
			leftShardId);
		List *rightPlacementList = ShardPlacementListWithoutOrphanedPlacements(
			rightShardId);

		if (list_length(leftPlacementList) != list_length(rightPlacementList))
		{
			ereport(ERROR, (errmsg("cannot colocate tables %s and %s",
								   leftRelationName, rightRelationName),
							errdetail("Shard " UINT64_FORMAT
									  " of %s and shard " UINT64_FORMAT
									  " of %s have different number of shard placements.",
									  leftShardId, leftRelationName,
									  rightShardId, rightRelationName)));
		}

		/* sort shard placements according to the node */
		List *sortedLeftPlacementList = SortList(leftPlacementList,
												 CompareShardPlacementsByNode);
		List *sortedRightPlacementList = SortList(rightPlacementList,
												  CompareShardPlacementsByNode);

		/* compare shard placements one by one */
		forboth(leftPlacementCell, sortedLeftPlacementList,
				rightPlacementCell, sortedRightPlacementList)
		{
			ShardPlacement *leftPlacement =
				(ShardPlacement *) lfirst(leftPlacementCell);
			ShardPlacement *rightPlacement =
				(ShardPlacement *) lfirst(rightPlacementCell);

			/*
			 * If shard placements are on different nodes, these shard
			 * placements are not colocated.
			 */
			int nodeCompare = CompareShardPlacementsByNode((void *) &leftPlacement,
														   (void *) &rightPlacement);
			if (nodeCompare != 0)
			{
				ereport(ERROR, (errmsg("cannot colocate tables %s and %s",
									   leftRelationName, rightRelationName),
								errdetail("Shard " UINT64_FORMAT " of %s and shard "
										  UINT64_FORMAT " of %s are not colocated.",
										  leftShardId, leftRelationName,
										  rightShardId, rightRelationName)));
			}

			/* we also don't allow colocated shards to be in different shard states */
			if (leftPlacement->shardState != rightPlacement->shardState)
			{
				ereport(ERROR, (errmsg("cannot colocate tables %s and %s",
									   leftRelationName, rightRelationName),
								errdetail("%s and %s have shard placements in "
										  "different shard states.",
										  leftRelationName, rightRelationName)));
			}
		}
	}
}


/*
 * ShardsIntervalsEqual checks if two shard intervals of distributed
 * tables are equal.
 *
 * Notes on the function:
 * (i)   The function returns true if both shard intervals are the same.
 * (ii)  The function returns false if the shard intervals are equal but
 *       their distribution methods are different.
 * (iii) The function returns false for append and range partitioned tables,
 *       except for case (i).
 * (iv)  For reference tables, all shards are equal (i.e., same replication
 *       factor and shard min/max values). Thus, it always returns true for
 *       shards of reference tables.
 */
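/*
 * For example, a hash-distributed shard and a range-distributed shard whose
 * [min, max] values happen to be identical still compare unequal here,
 * because their distribution methods differ (note (ii) above).
 */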
static bool
ShardsIntervalsEqual(ShardInterval *leftShardInterval,
					 ShardInterval *rightShardInterval)
{
	/* if both shards are the same, return true */
	if (leftShardInterval->shardId == rightShardInterval->shardId)
	{
		return true;
	}

	/* if partition methods are not the same, shards cannot be considered as co-located */
	char leftIntervalPartitionMethod = PartitionMethod(leftShardInterval->relationId);
	char rightIntervalPartitionMethod = PartitionMethod(rightShardInterval->relationId);
	if (leftIntervalPartitionMethod != rightIntervalPartitionMethod)
	{
		return false;
	}

	if (IsCitusTableType(leftShardInterval->relationId, HASH_DISTRIBUTED))
	{
		return HashPartitionedShardIntervalsEqual(leftShardInterval, rightShardInterval);
	}
	else if (IsCitusTableType(leftShardInterval->relationId,
							  CITUS_TABLE_WITH_NO_DIST_KEY))
	{
		/*
		 * Reference tables have only a single shard and all reference tables
		 * are always co-located with each other.
		 */

		return true;
	}

	/* append and range partitioned shards are never co-located */
	return false;
}


/*
 * HashPartitionedShardIntervalsEqual checks if two shard intervals of hash
 * distributed tables are equal. Note that this function doesn't work with
 * shards of non-hash partitioned tables.
 *
 * We compare the shards' min/max values here to decide whether they are
 * colocated. Alternatively, we could call ShardIndex() on both shards and
 * compare the resulting indexes, but comparing the cached min/max values
 * directly is cheaper.
 */
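/*
 * For instance, with a shard count of 4, the int32 hash space is split into
 * equal ranges of 2^30 values, so the shards at index 0 of two such tables
 * both cover [-2147483648, -1073741825] and compare equal here, whereas
 * shards of a table with a different shard count never line up.
 */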
static bool
HashPartitionedShardIntervalsEqual(ShardInterval *leftShardInterval,
								   ShardInterval *rightShardInterval)
{
	int32 leftShardMinValue = DatumGetInt32(leftShardInterval->minValue);
	int32 leftShardMaxValue = DatumGetInt32(leftShardInterval->maxValue);
	int32 rightShardMinValue = DatumGetInt32(rightShardInterval->minValue);
	int32 rightShardMaxValue = DatumGetInt32(rightShardInterval->maxValue);

	bool minValuesEqual = leftShardMinValue == rightShardMinValue;
	bool maxValuesEqual = leftShardMaxValue == rightShardMaxValue;

	return minValuesEqual && maxValuesEqual;
}


/*
 * CompareShardPlacementsByNode compares two shard placements by the id of
 * the node they are placed on.
 */
static int
CompareShardPlacementsByNode(const void *leftElement, const void *rightElement)
{
	const ShardPlacement *leftPlacement = *((const ShardPlacement **) leftElement);
	const ShardPlacement *rightPlacement = *((const ShardPlacement **) rightElement);

	if (leftPlacement->nodeId < rightPlacement->nodeId)
	{
		return -1;
	}
	else if (leftPlacement->nodeId > rightPlacement->nodeId)
	{
		return 1;
	}
	else
	{
		return 0;
	}
}


/*
 * ColocationId searches pg_dist_colocation for shard count, replication factor,
 * distribution column type, and distribution column collation. If a matching entry
 * is found, it returns the colocation id, otherwise returns INVALID_COLOCATION_ID.
 */
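/*
 * A rough SQL sketch of the lookup below (the actual code scans through the
 * configuration index rather than running SQL):
 *
 *   SELECT min(colocationid) FROM pg_dist_colocation
 *   WHERE distributioncolumntype = :type AND shardcount = :count
 *     AND replicationfactor = :factor
 *     AND distributioncolumncollation = :collation;
 */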
uint32
ColocationId(int shardCount, int replicationFactor, Oid distributionColumnType,
			 Oid distributionColumnCollation)
{
	uint32 colocationId = INVALID_COLOCATION_ID;
	const int scanKeyCount = 4;
	ScanKeyData scanKey[4];
	bool indexOK = true;

	Relation pgDistColocation = table_open(DistColocationRelationId(), AccessShareLock);

	/* set scan arguments */
	ScanKeyInit(&scanKey[0], Anum_pg_dist_colocation_distributioncolumntype,
				BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(distributionColumnType));
	ScanKeyInit(&scanKey[1], Anum_pg_dist_colocation_shardcount,
				BTEqualStrategyNumber, F_INT4EQ, UInt32GetDatum(shardCount));
	ScanKeyInit(&scanKey[2], Anum_pg_dist_colocation_replicationfactor,
				BTEqualStrategyNumber, F_INT4EQ, Int32GetDatum(replicationFactor));
	ScanKeyInit(&scanKey[3], Anum_pg_dist_colocation_distributioncolumncollation,
				BTEqualStrategyNumber, F_OIDEQ,
				ObjectIdGetDatum(distributionColumnCollation));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistColocation,
													DistColocationConfigurationIndexId(),
													indexOK, NULL, scanKeyCount, scanKey);

	HeapTuple colocationTuple = systable_getnext(scanDescriptor);

	while (HeapTupleIsValid(colocationTuple))
	{
		Form_pg_dist_colocation colocationForm =
			(Form_pg_dist_colocation) GETSTRUCT(colocationTuple);

		if (colocationId == INVALID_COLOCATION_ID ||
			colocationId > colocationForm->colocationid)
		{
			/*
			 * We assign the smallest colocation id among all the matches so that we
			 * assign the same colocation group for similar distributed tables.
			 */
			colocationId = colocationForm->colocationid;
		}
		colocationTuple = systable_getnext(scanDescriptor);
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistColocation, AccessShareLock);

	return colocationId;
}


/*
 * CreateColocationGroup creates a new colocation id and writes it into
 * pg_dist_colocation with the given configuration. It also returns the created
 * colocation id.
 */
uint32
CreateColocationGroup(int shardCount, int replicationFactor, Oid distributionColumnType,
					  Oid distributionColumnCollation)
{
	uint32 colocationId = GetNextColocationId();
	Datum values[Natts_pg_dist_colocation];
	bool isNulls[Natts_pg_dist_colocation];

	/* form new colocation tuple */
	memset(values, 0, sizeof(values));
	memset(isNulls, false, sizeof(isNulls));

	values[Anum_pg_dist_colocation_colocationid - 1] = UInt32GetDatum(colocationId);
	values[Anum_pg_dist_colocation_shardcount - 1] = UInt32GetDatum(shardCount);
	values[Anum_pg_dist_colocation_replicationfactor - 1] =
		UInt32GetDatum(replicationFactor);
	values[Anum_pg_dist_colocation_distributioncolumntype - 1] =
		ObjectIdGetDatum(distributionColumnType);
	values[Anum_pg_dist_colocation_distributioncolumncollation - 1] =
		ObjectIdGetDatum(distributionColumnCollation);

	/* open colocation relation and insert the new tuple */
	Relation pgDistColocation = table_open(DistColocationRelationId(), RowExclusiveLock);

	TupleDesc tupleDescriptor = RelationGetDescr(pgDistColocation);
	HeapTuple heapTuple = heap_form_tuple(tupleDescriptor, values, isNulls);

	CatalogTupleInsert(pgDistColocation, heapTuple);

	/* increment the counter so that next command can see the row */
	CommandCounterIncrement();
	table_close(pgDistColocation, RowExclusiveLock);

	return colocationId;
}


/*
 * GetNextColocationId allocates and returns a unique colocationId for the
 * colocation group to be created. This allocation occurs both in shared memory
 * and in write ahead logs; writing to logs avoids the risk of having
 * colocationId collisions.
 *
 * Please note that the caller is still responsible for finalizing colocationId
 * with the master node. Further note that this function relies on an internal
 * sequence created in initdb to generate unique identifiers.
 */
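/*
 * A minimal SQL sketch of the allocation below, assuming the sequence behind
 * COLOCATIONID_SEQUENCE_NAME is pg_catalog.pg_dist_colocationid_seq:
 *
 *   SELECT nextval('pg_catalog.pg_dist_colocationid_seq');
 */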
uint32
GetNextColocationId()
{
	text *sequenceName = cstring_to_text(COLOCATIONID_SEQUENCE_NAME);
	Oid sequenceId = ResolveRelationId(sequenceName, false);
	Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
	Oid savedUserId = InvalidOid;
	int savedSecurityContext = 0;

	GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
	SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);

	/* generate new and unique colocation id from sequence */
	Datum colocationIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);

	SetUserIdAndSecContext(savedUserId, savedSecurityContext);

	uint32 colocationId = DatumGetUInt32(colocationIdDatum);

	return colocationId;
}


/*
 * CheckReplicationModel checks if the given relations use the same
 * replication model. Otherwise, it errors out.
 */
void
CheckReplicationModel(Oid sourceRelationId, Oid targetRelationId)
{
	CitusTableCacheEntry *sourceTableEntry = GetCitusTableCacheEntry(sourceRelationId);
	char sourceReplicationModel = sourceTableEntry->replicationModel;

	CitusTableCacheEntry *targetTableEntry = GetCitusTableCacheEntry(targetRelationId);
	char targetReplicationModel = targetTableEntry->replicationModel;

	if (sourceReplicationModel != targetReplicationModel)
	{
		char *sourceRelationName = get_rel_name(sourceRelationId);
		char *targetRelationName = get_rel_name(targetRelationId);

		ereport(ERROR, (errmsg("cannot colocate tables %s and %s",
							   sourceRelationName, targetRelationName),
						errdetail("Replication models don't match for %s and %s.",
								  sourceRelationName, targetRelationName)));
	}
}


/*
 * CheckDistributionColumnType checks if the distribution column types of the
 * given relations are the same. Otherwise, it errors out.
 */
void
CheckDistributionColumnType(Oid sourceRelationId, Oid targetRelationId)
{
	/* reference tables have a NULL distribution column */
	Var *sourceDistributionColumn = DistPartitionKey(sourceRelationId);
	Var *targetDistributionColumn = DistPartitionKey(targetRelationId);

	EnsureColumnTypeEquality(sourceRelationId, targetRelationId,
							 sourceDistributionColumn, targetDistributionColumn);
}


/*
 * EnsureColumnTypeEquality checks if the distribution column types and
 * collations of the given columns are the same. Otherwise, it errors out.
 */
void
EnsureColumnTypeEquality(Oid sourceRelationId, Oid targetRelationId,
						 Var *sourceDistributionColumn, Var *targetDistributionColumn)
{
	Oid sourceDistributionColumnType = InvalidOid;
	Oid targetDistributionColumnType = InvalidOid;
	Oid sourceDistributionColumnCollation = InvalidOid;
	Oid targetDistributionColumnCollation = InvalidOid;

	if (sourceDistributionColumn != NULL)
	{
		sourceDistributionColumnType = sourceDistributionColumn->vartype;
		sourceDistributionColumnCollation = sourceDistributionColumn->varcollid;
	}

	if (targetDistributionColumn != NULL)
	{
		targetDistributionColumnType = targetDistributionColumn->vartype;
		targetDistributionColumnCollation = targetDistributionColumn->varcollid;
	}

	bool columnTypesSame = sourceDistributionColumnType == targetDistributionColumnType;
	bool columnCollationsSame =
		sourceDistributionColumnCollation == targetDistributionColumnCollation;

	if (!columnTypesSame)
	{
		char *sourceRelationName = get_rel_name(sourceRelationId);
		char *targetRelationName = get_rel_name(targetRelationId);

		ereport(ERROR, (errmsg("cannot colocate tables %s and %s",
							   sourceRelationName, targetRelationName),
						errdetail("Distribution column types don't match for "
								  "%s and %s.", sourceRelationName,
								  targetRelationName)));
	}

	if (!columnCollationsSame)
	{
		char *sourceRelationName = get_rel_name(sourceRelationId);
		char *targetRelationName = get_rel_name(targetRelationId);

		ereport(ERROR, (errmsg("cannot colocate tables %s and %s",
							   sourceRelationName, targetRelationName),
						errdetail(
							"Distribution column collations don't match for "
							"%s and %s.", sourceRelationName,
							targetRelationName)));
	}
}


/*
 * UpdateRelationColocationGroup updates colocation group in pg_dist_partition
 * for the given relation.
 *
 * When localOnly is true, the function does not propagate changes to the
 * metadata workers.
 */
void
UpdateRelationColocationGroup(Oid distributedRelationId, uint32 colocationId,
							  bool localOnly)
{
	bool indexOK = true;
	int scanKeyCount = 1;
	ScanKeyData scanKey[1];
	Datum values[Natts_pg_dist_partition];
	bool isNull[Natts_pg_dist_partition];
	bool replace[Natts_pg_dist_partition];

	Relation pgDistPartition = table_open(DistPartitionRelationId(), RowExclusiveLock);
	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPartition);

	ScanKeyInit(&scanKey[0], Anum_pg_dist_partition_logicalrelid,
				BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(distributedRelationId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistPartition,
													DistPartitionLogicalRelidIndexId(),
													indexOK,
													NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (!HeapTupleIsValid(heapTuple))
	{
		char *distributedRelationName = get_rel_name(distributedRelationId);
		ereport(ERROR, (errmsg("could not find valid entry for relation %s",
							   distributedRelationName)));
	}

	memset(values, 0, sizeof(values));
	memset(isNull, false, sizeof(isNull));
	memset(replace, false, sizeof(replace));

	values[Anum_pg_dist_partition_colocationid - 1] = UInt32GetDatum(colocationId);
	isNull[Anum_pg_dist_partition_colocationid - 1] = false;
	replace[Anum_pg_dist_partition_colocationid - 1] = true;

	heapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isNull, replace);

	CatalogTupleUpdate(pgDistPartition, &heapTuple->t_self, heapTuple);

	CitusInvalidateRelcacheByRelid(distributedRelationId);

	CommandCounterIncrement();

	systable_endscan(scanDescriptor);
	table_close(pgDistPartition, NoLock);

	bool shouldSyncMetadata = ShouldSyncTableMetadata(distributedRelationId);
	if (shouldSyncMetadata && !localOnly)
	{
		char *updateColocationIdCommand = ColocationIdUpdateCommand(distributedRelationId,
																	colocationId);

		SendCommandToWorkersWithMetadata(updateColocationIdCommand);
	}
}


/*
 * TableColocationId returns the co-location id of the given table. This
 * function errors out if the given table is not distributed.
 */
uint32
TableColocationId(Oid distributedTableId)
{
	CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId);

	return cacheEntry->colocationId;
}


/*
 * TablesColocated checks whether the given two tables are co-located and
 * returns true if they are. A table is always co-located with itself.
 * If the given two tables are different and not distributed, this function
 * errors out.
 */
bool
TablesColocated(Oid leftDistributedTableId, Oid rightDistributedTableId)
{
	if (leftDistributedTableId == rightDistributedTableId)
	{
		return true;
	}

	uint32 leftColocationId = TableColocationId(leftDistributedTableId);
	uint32 rightColocationId = TableColocationId(rightDistributedTableId);
	if (leftColocationId == INVALID_COLOCATION_ID ||
		rightColocationId == INVALID_COLOCATION_ID)
	{
		return false;
	}

	return leftColocationId == rightColocationId;
}


/*
 * ShardsColocated checks whether the given two shards are co-located and
 * returns true if they are. Two shards are co-located if either:
 * - they are the same shard (a shard is always co-located with itself), or
 * - the tables are hash partitioned, the tables containing the shards are
 *   co-located, and the min/max values of the shards are the same.
 */
bool
ShardsColocated(ShardInterval *leftShardInterval, ShardInterval *rightShardInterval)
{
	bool tablesColocated = TablesColocated(leftShardInterval->relationId,
										   rightShardInterval->relationId);

	if (tablesColocated)
	{
		bool shardIntervalEqual = ShardsIntervalsEqual(leftShardInterval,
													   rightShardInterval);
		return shardIntervalEqual;
	}

	return false;
}


/*
 * ColocatedTableList returns the list of relation ids that are co-located
 * with the given table. If the given table is not hash distributed,
 * co-location is not valid for that table and it is only co-located with
 * itself.
 */
List *
ColocatedTableList(Oid distributedTableId)
{
	uint32 tableColocationId = TableColocationId(distributedTableId);
	List *colocatedTableList = NIL;

	/*
	 * If the table doesn't have a valid colocation id, the table is only
	 * co-located with itself.
	 */
	if (tableColocationId == INVALID_COLOCATION_ID)
	{
		colocatedTableList = lappend_oid(colocatedTableList, distributedTableId);
		return colocatedTableList;
	}

	int count = 0;
	colocatedTableList = ColocationGroupTableList(tableColocationId, count);

	return colocatedTableList;
}


/*
 * ColocationGroupTableList returns the list of tables in the given colocation
 * group. If the colocation group is INVALID_COLOCATION_ID, it returns NIL.
 *
 * If count is zero, all matching rows are fetched. If count is greater than
 * zero, no more than count rows are retrieved; the scan stops as soon as
 * count rows have been fetched, much like adding a LIMIT clause to a query.
 */
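/*
 * For example, DeleteColocationGroupIfNoTablesBelong below passes count = 1,
 * since fetching a single row is enough to decide whether any table still
 * belongs to the group.
 */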
List *
ColocationGroupTableList(uint32 colocationId, uint32 count)
{
	List *colocatedTableList = NIL;
	bool indexOK = true;
	int scanKeyCount = 1;
	ScanKeyData scanKey[1];

	/* a table with INVALID_COLOCATION_ID is not part of any colocation group */
	if (colocationId == INVALID_COLOCATION_ID)
	{
		return NIL;
	}

	ScanKeyInit(&scanKey[0], Anum_pg_dist_partition_colocationid,
				BTEqualStrategyNumber, F_INT4EQ, UInt32GetDatum(colocationId));

	Relation pgDistPartition = table_open(DistPartitionRelationId(), AccessShareLock);
	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPartition);
	SysScanDesc scanDescriptor = systable_beginscan(pgDistPartition,
													DistPartitionColocationidIndexId(),
													indexOK, NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	while (HeapTupleIsValid(heapTuple))
	{
		bool isNull = false;
		Oid colocatedTableId = heap_getattr(heapTuple,
											Anum_pg_dist_partition_logicalrelid,
											tupleDescriptor, &isNull);

		colocatedTableList = lappend_oid(colocatedTableList, colocatedTableId);
		heapTuple = systable_getnext(scanDescriptor);

		if (count == 0)
		{
			/* fetch all rows */
			continue;
		}
		else if (list_length(colocatedTableList) >= count)
		{
			/* we are done */
			break;
		}
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistPartition, AccessShareLock);

	return colocatedTableList;
}


/*
 * ColocatedShardIntervalList returns the list of shard intervals that are
 * co-located with the given shard. If the given shard belongs to an append or
 * range distributed table, co-location is not valid for that shard, and thus
 * the shard is only co-located with itself.
 */
List *
ColocatedShardIntervalList(ShardInterval *shardInterval)
{
	Oid distributedTableId = shardInterval->relationId;
	List *colocatedShardList = NIL;

	CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId);

	/*
	 * If the distribution type of the table is append or range, each shard of
	 * the table is only co-located with itself.
	 */
	if (IsCitusTableTypeCacheEntry(cacheEntry, APPEND_DISTRIBUTED) ||
		IsCitusTableTypeCacheEntry(cacheEntry, RANGE_DISTRIBUTED))
	{
		ShardInterval *copyShardInterval = CopyShardInterval(shardInterval);

		colocatedShardList = lappend(colocatedShardList, copyShardInterval);

		return colocatedShardList;
	}

	int shardIntervalIndex = ShardIndex(shardInterval);
	List *colocatedTableList = ColocatedTableList(distributedTableId);

	/* ShardIndex must have found the index of the given shard */
	Assert(shardIntervalIndex >= 0);

	Oid colocatedTableId = InvalidOid;
	foreach_oid(colocatedTableId, colocatedTableList)
	{
		CitusTableCacheEntry *colocatedTableCacheEntry =
			GetCitusTableCacheEntry(colocatedTableId);

		/*
		 * Since we iterate over co-located tables, the shard count of each
		 * table should be the same and greater than shardIntervalIndex.
		 */
		Assert(cacheEntry->shardIntervalArrayLength ==
			   colocatedTableCacheEntry->shardIntervalArrayLength);

		ShardInterval *colocatedShardInterval =
			colocatedTableCacheEntry->sortedShardIntervalArray[shardIntervalIndex];

		ShardInterval *copyShardInterval = CopyShardInterval(colocatedShardInterval);

		colocatedShardList = lappend(colocatedShardList, copyShardInterval);
	}

	Assert(list_length(colocatedTableList) == list_length(colocatedShardList));

	return SortList(colocatedShardList, CompareShardIntervalsById);
}


/*
 * ColocatedNonPartitionShardIntervalList returns the list of shard intervals
 * that are co-located with the given shard, excluding partitions. If the given
 * shard belongs to an append or range distributed table, co-location is not
 * valid for that shard, and thus the shard is only co-located with itself.
 */
List *
ColocatedNonPartitionShardIntervalList(ShardInterval *shardInterval)
{
	Oid distributedTableId = shardInterval->relationId;
	List *colocatedShardList = NIL;

	CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId);

	/*
	 * If the distribution type of the table is append or range, each shard of
	 * the table is only co-located with itself. We don't expect this case to
	 * happen, since distributing partitioned tables is only supported for
	 * hash-distributed tables. Therefore, we currently can't cover this path
	 * with a test.
	 */
	if (IsCitusTableTypeCacheEntry(cacheEntry, APPEND_DISTRIBUTED) ||
		IsCitusTableTypeCacheEntry(cacheEntry, RANGE_DISTRIBUTED))
	{
		ShardInterval *copyShardInterval = CopyShardInterval(shardInterval);

		colocatedShardList = lappend(colocatedShardList, copyShardInterval);

		return colocatedShardList;
	}

	ereport(DEBUG1, (errmsg("skipping child tables for relation named: %s",
							get_rel_name(distributedTableId))));

	int shardIntervalIndex = ShardIndex(shardInterval);
	List *colocatedTableList = ColocatedTableList(distributedTableId);

	/* ShardIndex must have found the index of the given shard */
	Assert(shardIntervalIndex >= 0);

	Oid colocatedTableId = InvalidOid;
	foreach_oid(colocatedTableId, colocatedTableList)
	{
		if (PartitionTable(colocatedTableId))
		{
			continue;
		}

		CitusTableCacheEntry *colocatedTableCacheEntry =
			GetCitusTableCacheEntry(colocatedTableId);

		/*
		 * Since we iterate over co-located tables, the shard count of each
		 * table should be the same and greater than shardIntervalIndex.
		 */
		Assert(cacheEntry->shardIntervalArrayLength ==
			   colocatedTableCacheEntry->shardIntervalArrayLength);

		ShardInterval *colocatedShardInterval =
			colocatedTableCacheEntry->sortedShardIntervalArray[shardIntervalIndex];

		ShardInterval *copyShardInterval = CopyShardInterval(colocatedShardInterval);

		colocatedShardList = lappend(colocatedShardList, copyShardInterval);
	}

	return SortList(colocatedShardList, CompareShardIntervalsById);
}


/*
 * ColocatedTableId returns an arbitrary table that belongs to the given
 * colocation group. If no such colocation group exists, it returns InvalidOid.
 *
 * This function also takes an AccessShareLock on the co-located table to
 * guarantee that the table isn't dropped for the remainder of the transaction.
 */
Oid
ColocatedTableId(Oid colocationId)
{
	Oid colocatedTableId = InvalidOid;
	bool indexOK = true;
	bool isNull = false;
	ScanKeyData scanKey[1];
	int scanKeyCount = 1;

	/*
	 * We may have a distributed table whose colocation id is INVALID_COLOCATION_ID.
	 * In this case, we do not want to send that table's id as the colocated table id.
	 */
	if (colocationId == INVALID_COLOCATION_ID)
	{
		return colocatedTableId;
	}

	ScanKeyInit(&scanKey[0], Anum_pg_dist_partition_colocationid,
				BTEqualStrategyNumber, F_INT4EQ, ObjectIdGetDatum(colocationId));

	Relation pgDistPartition = table_open(DistPartitionRelationId(), AccessShareLock);
	TupleDesc tupleDescriptor = RelationGetDescr(pgDistPartition);
	SysScanDesc scanDescriptor = systable_beginscan(pgDistPartition,
													DistPartitionColocationidIndexId(),
													indexOK, NULL, scanKeyCount, scanKey);

	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	while (HeapTupleIsValid(heapTuple))
	{
		colocatedTableId = heap_getattr(heapTuple, Anum_pg_dist_partition_logicalrelid,
										tupleDescriptor, &isNull);

		/*
		 * Make sure the relation isn't dropped for the remainder of
		 * the transaction.
		 */
		LockRelationOid(colocatedTableId, AccessShareLock);

		/*
		 * The relation might have been dropped just before we locked it.
		 * Let's look it up.
		 */
		Relation colocatedRelation = RelationIdGetRelation(colocatedTableId);
		if (RelationIsValid(colocatedRelation))
		{
			/* relation still exists, we can use it */
			RelationClose(colocatedRelation);
			break;
		}

		/* relation was dropped, try the next one */
		colocatedTableId = InvalidOid;

		heapTuple = systable_getnext(scanDescriptor);
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistPartition, AccessShareLock);

	return colocatedTableId;
}


/*
 * ColocatedShardIdInRelation returns the id of the shard at the given index
 * in the given relation, such that the returned shard is co-located with the
 * shard at the same index in any co-located relation.
 */
uint64
ColocatedShardIdInRelation(Oid relationId, int shardIndex)
{
	CitusTableCacheEntry *tableCacheEntry = GetCitusTableCacheEntry(relationId);

	return tableCacheEntry->sortedShardIntervalArray[shardIndex]->shardId;
}


/*
 * DeleteColocationGroupIfNoTablesBelong deletes the given co-location group if
 * no relation belongs to it. A co-location group may become empty after
 * mark_tables_colocated or upgrade_reference_table UDF calls. In that case we
 * need to remove the empty co-location group to prevent orphaned co-location
 * groups.
 */
void
DeleteColocationGroupIfNoTablesBelong(uint32 colocationId)
{
	if (colocationId != INVALID_COLOCATION_ID)
	{
		int count = 1;
		List *colocatedTableList = ColocationGroupTableList(colocationId, count);
		int colocatedTableCount = list_length(colocatedTableList);

		if (colocatedTableCount == 0)
		{
			DeleteColocationGroup(colocationId);
		}
	}
}


/*
 * DeleteColocationGroup deletes the colocation group from pg_dist_colocation.
 */
static void
DeleteColocationGroup(uint32 colocationId)
{
	int scanKeyCount = 1;
	ScanKeyData scanKey[1];
	bool indexOK = false;

	Relation pgDistColocation = table_open(DistColocationRelationId(), RowExclusiveLock);

	ScanKeyInit(&scanKey[0], Anum_pg_dist_colocation_colocationid,
				BTEqualStrategyNumber, F_INT4EQ, UInt32GetDatum(colocationId));

	SysScanDesc scanDescriptor = systable_beginscan(pgDistColocation, InvalidOid, indexOK,
													NULL, scanKeyCount, scanKey);

	/* if a record is found, delete it */
	HeapTuple heapTuple = systable_getnext(scanDescriptor);
	if (HeapTupleIsValid(heapTuple))
	{
		/*
		 * simple_heap_delete() expects that the caller has at least an
		 * AccessShareLock on replica identity index.
		 */
		Relation replicaIndex =
			index_open(RelationGetReplicaIndex(pgDistColocation),
					   AccessShareLock);
		simple_heap_delete(pgDistColocation, &(heapTuple->t_self));

		CitusInvalidateRelcacheByRelid(DistColocationRelationId());
		CommandCounterIncrement();
		table_close(replicaIndex, AccessShareLock);
	}

	systable_endscan(scanDescriptor);
	table_close(pgDistColocation, RowExclusiveLock);
}