1 /*-------------------------------------------------------------------------
2  *
3  * shardinterval_utils.c
4  *
5  * This file contains functions to perform useful operations on shard intervals.
6  *
7  * Copyright (c) Citus Data, Inc.
8  *
9  *-------------------------------------------------------------------------
10  */
11 #include "stdint.h"
12 #include "postgres.h"
13 
14 #include "access/nbtree.h"
15 #include "catalog/pg_am.h"
16 #include "catalog/pg_collation.h"
17 #include "catalog/pg_type.h"
18 #include "distributed/listutils.h"
19 #include "distributed/metadata_cache.h"
20 #include "distributed/multi_join_order.h"
21 #include "distributed/distributed_planner.h"
22 #include "distributed/shard_pruning.h"
23 #include "distributed/shardinterval_utils.h"
24 #include "distributed/pg_dist_partition.h"
25 #include "distributed/worker_protocol.h"
26 #include "utils/catcache.h"
27 #include "utils/memutils.h"
28 
29 
30 /*
31  * SortedShardIntervalArray sorts the input shardIntervalArray. Shard intervals with
32  * no min/max values are placed at the end of the array.
33  */
34 ShardInterval **
SortShardIntervalArray(ShardInterval ** shardIntervalArray,int shardCount,Oid collation,FmgrInfo * shardIntervalSortCompareFunction)35 SortShardIntervalArray(ShardInterval **shardIntervalArray, int shardCount,
36 					   Oid collation, FmgrInfo *shardIntervalSortCompareFunction)
37 {
38 	SortShardIntervalContext sortContext = {
39 		.comparisonFunction = shardIntervalSortCompareFunction,
40 		.collation = collation
41 	};
42 
43 	/* short cut if there are no shard intervals in the array */
44 	if (shardCount == 0)
45 	{
46 		return shardIntervalArray;
47 	}
48 
49 	/* if a shard doesn't have min/max values, it's placed in the end of the array */
50 	qsort_arg(shardIntervalArray, shardCount, sizeof(ShardInterval *),
51 			  (qsort_arg_comparator) CompareShardIntervals, (void *) &sortContext);
52 
53 	return shardIntervalArray;
54 }
55 
56 
57 /*
58  * CompareShardIntervals acts as a helper function to compare two shard intervals
59  * by their minimum values, using the value's type comparison function.
60  *
61  * If a shard interval does not have min/max value, it's treated as being greater
62  * than the other.
63  */
64 int
CompareShardIntervals(const void * leftElement,const void * rightElement,SortShardIntervalContext * sortContext)65 CompareShardIntervals(const void *leftElement, const void *rightElement,
66 					  SortShardIntervalContext *sortContext)
67 {
68 	ShardInterval *leftShardInterval = *((ShardInterval **) leftElement);
69 	ShardInterval *rightShardInterval = *((ShardInterval **) rightElement);
70 	int comparisonResult = 0;
71 	bool leftHasNull = (!leftShardInterval->minValueExists ||
72 						!leftShardInterval->maxValueExists);
73 	bool rightHasNull = (!rightShardInterval->minValueExists ||
74 						 !rightShardInterval->maxValueExists);
75 
76 	Assert(sortContext->comparisonFunction != NULL);
77 
78 	if (leftHasNull && rightHasNull)
79 	{
80 		comparisonResult = 0;
81 	}
82 	else if (leftHasNull)
83 	{
84 		comparisonResult = 1;
85 	}
86 	else if (rightHasNull)
87 	{
88 		comparisonResult = -1;
89 	}
90 	else
91 	{
92 		/* if both shard interval have min/max values, calculate comparison result */
93 		Datum leftDatum = leftShardInterval->minValue;
94 		Datum rightDatum = rightShardInterval->minValue;
95 		Datum comparisonDatum = FunctionCall2Coll(sortContext->comparisonFunction,
96 												  sortContext->collation, leftDatum,
97 												  rightDatum);
98 		comparisonResult = DatumGetInt32(comparisonDatum);
99 	}
100 
101 	/* Two different shards should never be equal */
102 	if (comparisonResult == 0)
103 	{
104 		return CompareShardIntervalsById(leftElement, rightElement);
105 	}
106 
107 	return comparisonResult;
108 }
109 
110 
111 /*
112  * CompareShardIntervalsById is a comparison function for sort shard
113  * intervals by their shard ID.
114  */
115 int
CompareShardIntervalsById(const void * leftElement,const void * rightElement)116 CompareShardIntervalsById(const void *leftElement, const void *rightElement)
117 {
118 	ShardInterval *leftInterval = *((ShardInterval **) leftElement);
119 	ShardInterval *rightInterval = *((ShardInterval **) rightElement);
120 	int64 leftShardId = leftInterval->shardId;
121 	int64 rightShardId = rightInterval->shardId;
122 
123 	/* we compare 64-bit integers, instead of casting their difference to int */
124 	if (leftShardId > rightShardId)
125 	{
126 		return 1;
127 	}
128 	else if (leftShardId < rightShardId)
129 	{
130 		return -1;
131 	}
132 	else
133 	{
134 		return 0;
135 	}
136 }
137 
138 
139 /*
140  * CompareShardPlacementsByShardId is a comparison function for sorting shard
141  * placement by their shard ID.
142  */
143 int
CompareShardPlacementsByShardId(const void * leftElement,const void * rightElement)144 CompareShardPlacementsByShardId(const void *leftElement, const void *rightElement)
145 {
146 	GroupShardPlacement *left = *((GroupShardPlacement **) leftElement);
147 	GroupShardPlacement *right = *((GroupShardPlacement **) rightElement);
148 	int64 leftShardId = left->shardId;
149 	int64 rightShardId = right->shardId;
150 
151 	/* we compare 64-bit integers, instead of casting their difference to int */
152 	if (leftShardId > rightShardId)
153 	{
154 		return 1;
155 	}
156 	else if (leftShardId < rightShardId)
157 	{
158 		return -1;
159 	}
160 	else
161 	{
162 		return 0;
163 	}
164 }
165 
166 
167 /*
168  * CompareRelationShards is a comparison function for sorting relation
169  * to shard mappings by their relation ID and then shard ID.
170  */
171 int
CompareRelationShards(const void * leftElement,const void * rightElement)172 CompareRelationShards(const void *leftElement, const void *rightElement)
173 {
174 	RelationShard *leftRelationShard = *((RelationShard **) leftElement);
175 	RelationShard *rightRelationShard = *((RelationShard **) rightElement);
176 	Oid leftRelationId = leftRelationShard->relationId;
177 	Oid rightRelationId = rightRelationShard->relationId;
178 	int64 leftShardId = leftRelationShard->shardId;
179 	int64 rightShardId = rightRelationShard->shardId;
180 
181 	if (leftRelationId > rightRelationId)
182 	{
183 		return 1;
184 	}
185 	else if (leftRelationId < rightRelationId)
186 	{
187 		return -1;
188 	}
189 	else if (leftShardId > rightShardId)
190 	{
191 		return 1;
192 	}
193 	else if (leftShardId < rightShardId)
194 	{
195 		return -1;
196 	}
197 	else
198 	{
199 		return 0;
200 	}
201 }
202 
203 
204 /*
205  * ShardIndex finds the index of given shard in sorted shard interval array.
206  *
207  * For hash partitioned tables, it calculates hash value of a number in its
208  * range (e.g. min value) and finds which shard should contain the hashed
209  * value. For reference tables and citus local tables, it simply returns 0.
210  * For the other table types, the function errors out.
211  */
212 int
ShardIndex(ShardInterval * shardInterval)213 ShardIndex(ShardInterval *shardInterval)
214 {
215 	int shardIndex = INVALID_SHARD_INDEX;
216 	Oid distributedTableId = shardInterval->relationId;
217 	Datum shardMinValue = shardInterval->minValue;
218 
219 	CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId);
220 
221 	/*
222 	 * Note that, we can also support append and range distributed tables, but
223 	 * currently it is not required.
224 	 */
225 	if (!IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED) &&
226 		!IsCitusTableTypeCacheEntry(
227 			cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
228 	{
229 		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
230 						errmsg("finding index of a given shard is only supported for "
231 							   "hash distributed tables, reference tables and local "
232 							   "tables that are added to citus metadata")));
233 	}
234 
235 	/* short-circuit for reference tables */
236 	if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
237 	{
238 		/*
239 		 * Reference tables and citus local tables have only a single shard,
240 		 * so the index is fixed to 0.
241 		 */
242 		shardIndex = 0;
243 
244 		return shardIndex;
245 	}
246 
247 	shardIndex = FindShardIntervalIndex(shardMinValue, cacheEntry);
248 
249 	return shardIndex;
250 }
251 
252 
253 /*
254  * FindShardInterval finds a single shard interval in the cache for the
255  * given partition column value. Note that reference tables do not have
256  * partition columns, thus, pass partitionColumnValue and compareFunction
257  * as NULL for them.
258  */
259 ShardInterval *
FindShardInterval(Datum partitionColumnValue,CitusTableCacheEntry * cacheEntry)260 FindShardInterval(Datum partitionColumnValue, CitusTableCacheEntry *cacheEntry)
261 {
262 	Datum searchedValue = partitionColumnValue;
263 
264 	if (IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED))
265 	{
266 		searchedValue = FunctionCall1Coll(cacheEntry->hashFunction,
267 										  cacheEntry->partitionColumn->varcollid,
268 										  partitionColumnValue);
269 	}
270 
271 	int shardIndex = FindShardIntervalIndex(searchedValue, cacheEntry);
272 
273 	if (shardIndex == INVALID_SHARD_INDEX)
274 	{
275 		return NULL;
276 	}
277 
278 	return cacheEntry->sortedShardIntervalArray[shardIndex];
279 }
280 
281 
282 /*
283  * FindShardIntervalIndex finds the index of the shard interval which covers
284  * the searched value. Note that the searched value must be the hashed value
285  * of the original value if the distribution method is hash.
286  *
287  * Note that, if the searched value can not be found for hash partitioned
288  * tables, we error out (unless there are no shards, in which case
289  * INVALID_SHARD_INDEX is returned). This should only happen if something is
290  * terribly wrong, either metadata tables are corrupted or we have a bug
291  * somewhere. Such as a hash function which returns a value not in the range
292  * of [PG_INT32_MIN, PG_INT32_MAX] can fire this.
293  */
294 int
FindShardIntervalIndex(Datum searchedValue,CitusTableCacheEntry * cacheEntry)295 FindShardIntervalIndex(Datum searchedValue, CitusTableCacheEntry *cacheEntry)
296 {
297 	ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray;
298 	int shardCount = cacheEntry->shardIntervalArrayLength;
299 	FmgrInfo *compareFunction = cacheEntry->shardIntervalCompareFunction;
300 	bool useBinarySearch = (!IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED) ||
301 							!cacheEntry->hasUniformHashDistribution);
302 	int shardIndex = INVALID_SHARD_INDEX;
303 
304 	if (shardCount == 0)
305 	{
306 		return INVALID_SHARD_INDEX;
307 	}
308 
309 	if (IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED))
310 	{
311 		if (useBinarySearch)
312 		{
313 			Assert(compareFunction != NULL);
314 
315 			Oid shardIntervalCollation = cacheEntry->partitionColumn->varcollid;
316 			shardIndex = SearchCachedShardInterval(searchedValue, shardIntervalCache,
317 												   shardCount, shardIntervalCollation,
318 												   compareFunction);
319 
320 			/* we should always return a valid shard index for hash partitioned tables */
321 			if (shardIndex == INVALID_SHARD_INDEX)
322 			{
323 				ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION),
324 								errmsg("cannot find shard interval"),
325 								errdetail("Hash of the partition column value "
326 										  "does not fall into any shards.")));
327 			}
328 		}
329 		else
330 		{
331 			int hashedValue = DatumGetInt32(searchedValue);
332 
333 			shardIndex = CalculateUniformHashRangeIndex(hashedValue, shardCount);
334 		}
335 	}
336 	else if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
337 	{
338 		/* non-distributed tables have a single shard, all values mapped to that shard */
339 		Assert(shardCount == 1);
340 
341 		shardIndex = 0;
342 	}
343 	else
344 	{
345 		Assert(compareFunction != NULL);
346 
347 		Oid shardIntervalCollation = cacheEntry->partitionColumn->varcollid;
348 		shardIndex = SearchCachedShardInterval(searchedValue, shardIntervalCache,
349 											   shardCount, shardIntervalCollation,
350 											   compareFunction);
351 	}
352 
353 	return shardIndex;
354 }
355 
356 
357 /*
358  * SearchCachedShardInterval performs a binary search for a shard interval
359  * matching a given partition column value and returns it's index in the cached
360  * array. If it can not find any shard interval with the given value, it returns
361  * INVALID_SHARD_INDEX.
362  *
363  * TODO: Data re-partitioning logic (e.g., worker_hash_partition_table())
364  * on the worker nodes relies on this function in order to be consistent
365  * with shard pruning. Since the worker nodes don't have the metadata, a
366  * synthetically generated ShardInterval ** is passed to the to this
367  * function. The synthetic shard intervals contain only shardmin and shardmax
368  * values. A proper implementation of this approach should be introducing an
369  * intermediate data structure (e.g., ShardRange) on which this function
370  * operates instead of operating shard intervals.
371  */
372 int
SearchCachedShardInterval(Datum partitionColumnValue,ShardInterval ** shardIntervalCache,int shardCount,Oid shardIntervalCollation,FmgrInfo * compareFunction)373 SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache,
374 						  int shardCount, Oid shardIntervalCollation,
375 						  FmgrInfo *compareFunction)
376 {
377 	int lowerBoundIndex = 0;
378 	int upperBoundIndex = shardCount;
379 
380 	while (lowerBoundIndex < upperBoundIndex)
381 	{
382 		int middleIndex = (lowerBoundIndex + upperBoundIndex) / 2;
383 
384 		int minValueComparison = FunctionCall2Coll(compareFunction,
385 												   shardIntervalCollation,
386 												   partitionColumnValue,
387 												   shardIntervalCache[middleIndex]->
388 												   minValue);
389 
390 		if (DatumGetInt32(minValueComparison) < 0)
391 		{
392 			upperBoundIndex = middleIndex;
393 			continue;
394 		}
395 
396 		int maxValueComparison = FunctionCall2Coll(compareFunction,
397 												   shardIntervalCollation,
398 												   partitionColumnValue,
399 												   shardIntervalCache[middleIndex]->
400 												   maxValue);
401 
402 		if (DatumGetInt32(maxValueComparison) <= 0)
403 		{
404 			return middleIndex;
405 		}
406 
407 		lowerBoundIndex = middleIndex + 1;
408 	}
409 
410 	return INVALID_SHARD_INDEX;
411 }
412 
413 
414 /*
415  * CalculateUniformHashRangeIndex returns the index of the hash range in
416  * which hashedValue falls, assuming shardCount uniform hash ranges.
417  *
418  * We use 64-bit integers to avoid overflow issues during arithmetic.
419  *
420  * NOTE: This function is ONLY for hash-distributed tables with uniform
421  * hash ranges.
422  */
423 int
CalculateUniformHashRangeIndex(int hashedValue,int shardCount)424 CalculateUniformHashRangeIndex(int hashedValue, int shardCount)
425 {
426 	int64 hashedValue64 = (int64) hashedValue;
427 
428 	/* normalize to the 0-UINT32_MAX range */
429 	int64 normalizedHashValue = hashedValue64 - PG_INT32_MIN;
430 
431 	/* size of each hash range */
432 	int64 hashRangeSize = HASH_TOKEN_COUNT / shardCount;
433 
434 	/* index of hash range into which the hash value falls */
435 	int shardIndex = (int) (normalizedHashValue / hashRangeSize);
436 
437 	if (shardIndex < 0 || shardIndex > shardCount)
438 	{
439 		ereport(ERROR, (errmsg("bug: shard index %d out of bounds", shardIndex)));
440 	}
441 
442 	/*
443 	 * If the shard count is not power of 2, the range of the last
444 	 * shard becomes larger than others. For that extra piece of range,
445 	 * we still need to use the last shard.
446 	 */
447 	if (shardIndex == shardCount)
448 	{
449 		shardIndex = shardCount - 1;
450 	}
451 
452 	return shardIndex;
453 }
454 
455 
456 /*
457  * SingleReplicatedTable checks whether all shards of a distributed table, do not have
458  * more than one replica. If even one shard has more than one replica, this function
459  * returns false, otherwise it returns true.
460  */
461 bool
SingleReplicatedTable(Oid relationId)462 SingleReplicatedTable(Oid relationId)
463 {
464 	List *shardList = LoadShardList(relationId);
465 	List *shardPlacementList = NIL;
466 
467 	/* we could have append/range distributed tables without shards */
468 	if (list_length(shardList) == 0)
469 	{
470 		return false;
471 	}
472 
473 	/* for hash distributed tables, it is sufficient to only check one shard */
474 	if (IsCitusTableType(relationId, HASH_DISTRIBUTED))
475 	{
476 		/* checking only for the first shard id should suffice */
477 		uint64 shardId = *(uint64 *) linitial(shardList);
478 
479 		shardPlacementList = ShardPlacementListWithoutOrphanedPlacements(shardId);
480 		if (list_length(shardPlacementList) != 1)
481 		{
482 			return false;
483 		}
484 	}
485 	else
486 	{
487 		List *shardIntervalList = LoadShardList(relationId);
488 		uint64 *shardIdPointer = NULL;
489 		foreach_ptr(shardIdPointer, shardIntervalList)
490 		{
491 			uint64 shardId = *shardIdPointer;
492 			shardPlacementList = ShardPlacementListWithoutOrphanedPlacements(shardId);
493 
494 			if (list_length(shardPlacementList) != 1)
495 			{
496 				return false;
497 			}
498 		}
499 	}
500 
501 	return true;
502 }
503