1 /*-------------------------------------------------------------------------
2 *
3 * shardinterval_utils.c
4 *
5 * This file contains functions to perform useful operations on shard intervals.
6 *
7 * Copyright (c) Citus Data, Inc.
8 *
9 *-------------------------------------------------------------------------
10 */
11 #include "stdint.h"
12 #include "postgres.h"
13
14 #include "access/nbtree.h"
15 #include "catalog/pg_am.h"
16 #include "catalog/pg_collation.h"
17 #include "catalog/pg_type.h"
18 #include "distributed/listutils.h"
19 #include "distributed/metadata_cache.h"
20 #include "distributed/multi_join_order.h"
21 #include "distributed/distributed_planner.h"
22 #include "distributed/shard_pruning.h"
23 #include "distributed/shardinterval_utils.h"
24 #include "distributed/pg_dist_partition.h"
25 #include "distributed/worker_protocol.h"
26 #include "utils/catcache.h"
27 #include "utils/memutils.h"
28
29
30 /*
31 * SortedShardIntervalArray sorts the input shardIntervalArray. Shard intervals with
32 * no min/max values are placed at the end of the array.
33 */
34 ShardInterval **
SortShardIntervalArray(ShardInterval ** shardIntervalArray,int shardCount,Oid collation,FmgrInfo * shardIntervalSortCompareFunction)35 SortShardIntervalArray(ShardInterval **shardIntervalArray, int shardCount,
36 Oid collation, FmgrInfo *shardIntervalSortCompareFunction)
37 {
38 SortShardIntervalContext sortContext = {
39 .comparisonFunction = shardIntervalSortCompareFunction,
40 .collation = collation
41 };
42
43 /* short cut if there are no shard intervals in the array */
44 if (shardCount == 0)
45 {
46 return shardIntervalArray;
47 }
48
49 /* if a shard doesn't have min/max values, it's placed in the end of the array */
50 qsort_arg(shardIntervalArray, shardCount, sizeof(ShardInterval *),
51 (qsort_arg_comparator) CompareShardIntervals, (void *) &sortContext);
52
53 return shardIntervalArray;
54 }
55
56
57 /*
58 * CompareShardIntervals acts as a helper function to compare two shard intervals
59 * by their minimum values, using the value's type comparison function.
60 *
61 * If a shard interval does not have min/max value, it's treated as being greater
62 * than the other.
63 */
64 int
CompareShardIntervals(const void * leftElement,const void * rightElement,SortShardIntervalContext * sortContext)65 CompareShardIntervals(const void *leftElement, const void *rightElement,
66 SortShardIntervalContext *sortContext)
67 {
68 ShardInterval *leftShardInterval = *((ShardInterval **) leftElement);
69 ShardInterval *rightShardInterval = *((ShardInterval **) rightElement);
70 int comparisonResult = 0;
71 bool leftHasNull = (!leftShardInterval->minValueExists ||
72 !leftShardInterval->maxValueExists);
73 bool rightHasNull = (!rightShardInterval->minValueExists ||
74 !rightShardInterval->maxValueExists);
75
76 Assert(sortContext->comparisonFunction != NULL);
77
78 if (leftHasNull && rightHasNull)
79 {
80 comparisonResult = 0;
81 }
82 else if (leftHasNull)
83 {
84 comparisonResult = 1;
85 }
86 else if (rightHasNull)
87 {
88 comparisonResult = -1;
89 }
90 else
91 {
92 /* if both shard interval have min/max values, calculate comparison result */
93 Datum leftDatum = leftShardInterval->minValue;
94 Datum rightDatum = rightShardInterval->minValue;
95 Datum comparisonDatum = FunctionCall2Coll(sortContext->comparisonFunction,
96 sortContext->collation, leftDatum,
97 rightDatum);
98 comparisonResult = DatumGetInt32(comparisonDatum);
99 }
100
101 /* Two different shards should never be equal */
102 if (comparisonResult == 0)
103 {
104 return CompareShardIntervalsById(leftElement, rightElement);
105 }
106
107 return comparisonResult;
108 }
109
110
111 /*
112 * CompareShardIntervalsById is a comparison function for sort shard
113 * intervals by their shard ID.
114 */
115 int
CompareShardIntervalsById(const void * leftElement,const void * rightElement)116 CompareShardIntervalsById(const void *leftElement, const void *rightElement)
117 {
118 ShardInterval *leftInterval = *((ShardInterval **) leftElement);
119 ShardInterval *rightInterval = *((ShardInterval **) rightElement);
120 int64 leftShardId = leftInterval->shardId;
121 int64 rightShardId = rightInterval->shardId;
122
123 /* we compare 64-bit integers, instead of casting their difference to int */
124 if (leftShardId > rightShardId)
125 {
126 return 1;
127 }
128 else if (leftShardId < rightShardId)
129 {
130 return -1;
131 }
132 else
133 {
134 return 0;
135 }
136 }
137
138
139 /*
140 * CompareShardPlacementsByShardId is a comparison function for sorting shard
141 * placement by their shard ID.
142 */
143 int
CompareShardPlacementsByShardId(const void * leftElement,const void * rightElement)144 CompareShardPlacementsByShardId(const void *leftElement, const void *rightElement)
145 {
146 GroupShardPlacement *left = *((GroupShardPlacement **) leftElement);
147 GroupShardPlacement *right = *((GroupShardPlacement **) rightElement);
148 int64 leftShardId = left->shardId;
149 int64 rightShardId = right->shardId;
150
151 /* we compare 64-bit integers, instead of casting their difference to int */
152 if (leftShardId > rightShardId)
153 {
154 return 1;
155 }
156 else if (leftShardId < rightShardId)
157 {
158 return -1;
159 }
160 else
161 {
162 return 0;
163 }
164 }
165
166
167 /*
168 * CompareRelationShards is a comparison function for sorting relation
169 * to shard mappings by their relation ID and then shard ID.
170 */
171 int
CompareRelationShards(const void * leftElement,const void * rightElement)172 CompareRelationShards(const void *leftElement, const void *rightElement)
173 {
174 RelationShard *leftRelationShard = *((RelationShard **) leftElement);
175 RelationShard *rightRelationShard = *((RelationShard **) rightElement);
176 Oid leftRelationId = leftRelationShard->relationId;
177 Oid rightRelationId = rightRelationShard->relationId;
178 int64 leftShardId = leftRelationShard->shardId;
179 int64 rightShardId = rightRelationShard->shardId;
180
181 if (leftRelationId > rightRelationId)
182 {
183 return 1;
184 }
185 else if (leftRelationId < rightRelationId)
186 {
187 return -1;
188 }
189 else if (leftShardId > rightShardId)
190 {
191 return 1;
192 }
193 else if (leftShardId < rightShardId)
194 {
195 return -1;
196 }
197 else
198 {
199 return 0;
200 }
201 }
202
203
204 /*
205 * ShardIndex finds the index of given shard in sorted shard interval array.
206 *
207 * For hash partitioned tables, it calculates hash value of a number in its
208 * range (e.g. min value) and finds which shard should contain the hashed
209 * value. For reference tables and citus local tables, it simply returns 0.
210 * For the other table types, the function errors out.
211 */
212 int
ShardIndex(ShardInterval * shardInterval)213 ShardIndex(ShardInterval *shardInterval)
214 {
215 int shardIndex = INVALID_SHARD_INDEX;
216 Oid distributedTableId = shardInterval->relationId;
217 Datum shardMinValue = shardInterval->minValue;
218
219 CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId);
220
221 /*
222 * Note that, we can also support append and range distributed tables, but
223 * currently it is not required.
224 */
225 if (!IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED) &&
226 !IsCitusTableTypeCacheEntry(
227 cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
228 {
229 ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
230 errmsg("finding index of a given shard is only supported for "
231 "hash distributed tables, reference tables and local "
232 "tables that are added to citus metadata")));
233 }
234
235 /* short-circuit for reference tables */
236 if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
237 {
238 /*
239 * Reference tables and citus local tables have only a single shard,
240 * so the index is fixed to 0.
241 */
242 shardIndex = 0;
243
244 return shardIndex;
245 }
246
247 shardIndex = FindShardIntervalIndex(shardMinValue, cacheEntry);
248
249 return shardIndex;
250 }
251
252
253 /*
254 * FindShardInterval finds a single shard interval in the cache for the
255 * given partition column value. Note that reference tables do not have
256 * partition columns, thus, pass partitionColumnValue and compareFunction
257 * as NULL for them.
258 */
259 ShardInterval *
FindShardInterval(Datum partitionColumnValue,CitusTableCacheEntry * cacheEntry)260 FindShardInterval(Datum partitionColumnValue, CitusTableCacheEntry *cacheEntry)
261 {
262 Datum searchedValue = partitionColumnValue;
263
264 if (IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED))
265 {
266 searchedValue = FunctionCall1Coll(cacheEntry->hashFunction,
267 cacheEntry->partitionColumn->varcollid,
268 partitionColumnValue);
269 }
270
271 int shardIndex = FindShardIntervalIndex(searchedValue, cacheEntry);
272
273 if (shardIndex == INVALID_SHARD_INDEX)
274 {
275 return NULL;
276 }
277
278 return cacheEntry->sortedShardIntervalArray[shardIndex];
279 }
280
281
282 /*
283 * FindShardIntervalIndex finds the index of the shard interval which covers
284 * the searched value. Note that the searched value must be the hashed value
285 * of the original value if the distribution method is hash.
286 *
287 * Note that, if the searched value can not be found for hash partitioned
288 * tables, we error out (unless there are no shards, in which case
289 * INVALID_SHARD_INDEX is returned). This should only happen if something is
290 * terribly wrong, either metadata tables are corrupted or we have a bug
291 * somewhere. Such as a hash function which returns a value not in the range
292 * of [PG_INT32_MIN, PG_INT32_MAX] can fire this.
293 */
294 int
FindShardIntervalIndex(Datum searchedValue,CitusTableCacheEntry * cacheEntry)295 FindShardIntervalIndex(Datum searchedValue, CitusTableCacheEntry *cacheEntry)
296 {
297 ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray;
298 int shardCount = cacheEntry->shardIntervalArrayLength;
299 FmgrInfo *compareFunction = cacheEntry->shardIntervalCompareFunction;
300 bool useBinarySearch = (!IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED) ||
301 !cacheEntry->hasUniformHashDistribution);
302 int shardIndex = INVALID_SHARD_INDEX;
303
304 if (shardCount == 0)
305 {
306 return INVALID_SHARD_INDEX;
307 }
308
309 if (IsCitusTableTypeCacheEntry(cacheEntry, HASH_DISTRIBUTED))
310 {
311 if (useBinarySearch)
312 {
313 Assert(compareFunction != NULL);
314
315 Oid shardIntervalCollation = cacheEntry->partitionColumn->varcollid;
316 shardIndex = SearchCachedShardInterval(searchedValue, shardIntervalCache,
317 shardCount, shardIntervalCollation,
318 compareFunction);
319
320 /* we should always return a valid shard index for hash partitioned tables */
321 if (shardIndex == INVALID_SHARD_INDEX)
322 {
323 ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION),
324 errmsg("cannot find shard interval"),
325 errdetail("Hash of the partition column value "
326 "does not fall into any shards.")));
327 }
328 }
329 else
330 {
331 int hashedValue = DatumGetInt32(searchedValue);
332
333 shardIndex = CalculateUniformHashRangeIndex(hashedValue, shardCount);
334 }
335 }
336 else if (IsCitusTableTypeCacheEntry(cacheEntry, CITUS_TABLE_WITH_NO_DIST_KEY))
337 {
338 /* non-distributed tables have a single shard, all values mapped to that shard */
339 Assert(shardCount == 1);
340
341 shardIndex = 0;
342 }
343 else
344 {
345 Assert(compareFunction != NULL);
346
347 Oid shardIntervalCollation = cacheEntry->partitionColumn->varcollid;
348 shardIndex = SearchCachedShardInterval(searchedValue, shardIntervalCache,
349 shardCount, shardIntervalCollation,
350 compareFunction);
351 }
352
353 return shardIndex;
354 }
355
356
357 /*
358 * SearchCachedShardInterval performs a binary search for a shard interval
359 * matching a given partition column value and returns it's index in the cached
360 * array. If it can not find any shard interval with the given value, it returns
361 * INVALID_SHARD_INDEX.
362 *
363 * TODO: Data re-partitioning logic (e.g., worker_hash_partition_table())
364 * on the worker nodes relies on this function in order to be consistent
365 * with shard pruning. Since the worker nodes don't have the metadata, a
366 * synthetically generated ShardInterval ** is passed to the to this
367 * function. The synthetic shard intervals contain only shardmin and shardmax
368 * values. A proper implementation of this approach should be introducing an
369 * intermediate data structure (e.g., ShardRange) on which this function
370 * operates instead of operating shard intervals.
371 */
372 int
SearchCachedShardInterval(Datum partitionColumnValue,ShardInterval ** shardIntervalCache,int shardCount,Oid shardIntervalCollation,FmgrInfo * compareFunction)373 SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache,
374 int shardCount, Oid shardIntervalCollation,
375 FmgrInfo *compareFunction)
376 {
377 int lowerBoundIndex = 0;
378 int upperBoundIndex = shardCount;
379
380 while (lowerBoundIndex < upperBoundIndex)
381 {
382 int middleIndex = (lowerBoundIndex + upperBoundIndex) / 2;
383
384 int minValueComparison = FunctionCall2Coll(compareFunction,
385 shardIntervalCollation,
386 partitionColumnValue,
387 shardIntervalCache[middleIndex]->
388 minValue);
389
390 if (DatumGetInt32(minValueComparison) < 0)
391 {
392 upperBoundIndex = middleIndex;
393 continue;
394 }
395
396 int maxValueComparison = FunctionCall2Coll(compareFunction,
397 shardIntervalCollation,
398 partitionColumnValue,
399 shardIntervalCache[middleIndex]->
400 maxValue);
401
402 if (DatumGetInt32(maxValueComparison) <= 0)
403 {
404 return middleIndex;
405 }
406
407 lowerBoundIndex = middleIndex + 1;
408 }
409
410 return INVALID_SHARD_INDEX;
411 }
412
413
414 /*
415 * CalculateUniformHashRangeIndex returns the index of the hash range in
416 * which hashedValue falls, assuming shardCount uniform hash ranges.
417 *
418 * We use 64-bit integers to avoid overflow issues during arithmetic.
419 *
420 * NOTE: This function is ONLY for hash-distributed tables with uniform
421 * hash ranges.
422 */
423 int
CalculateUniformHashRangeIndex(int hashedValue,int shardCount)424 CalculateUniformHashRangeIndex(int hashedValue, int shardCount)
425 {
426 int64 hashedValue64 = (int64) hashedValue;
427
428 /* normalize to the 0-UINT32_MAX range */
429 int64 normalizedHashValue = hashedValue64 - PG_INT32_MIN;
430
431 /* size of each hash range */
432 int64 hashRangeSize = HASH_TOKEN_COUNT / shardCount;
433
434 /* index of hash range into which the hash value falls */
435 int shardIndex = (int) (normalizedHashValue / hashRangeSize);
436
437 if (shardIndex < 0 || shardIndex > shardCount)
438 {
439 ereport(ERROR, (errmsg("bug: shard index %d out of bounds", shardIndex)));
440 }
441
442 /*
443 * If the shard count is not power of 2, the range of the last
444 * shard becomes larger than others. For that extra piece of range,
445 * we still need to use the last shard.
446 */
447 if (shardIndex == shardCount)
448 {
449 shardIndex = shardCount - 1;
450 }
451
452 return shardIndex;
453 }
454
455
456 /*
457 * SingleReplicatedTable checks whether all shards of a distributed table, do not have
458 * more than one replica. If even one shard has more than one replica, this function
459 * returns false, otherwise it returns true.
460 */
461 bool
SingleReplicatedTable(Oid relationId)462 SingleReplicatedTable(Oid relationId)
463 {
464 List *shardList = LoadShardList(relationId);
465 List *shardPlacementList = NIL;
466
467 /* we could have append/range distributed tables without shards */
468 if (list_length(shardList) == 0)
469 {
470 return false;
471 }
472
473 /* for hash distributed tables, it is sufficient to only check one shard */
474 if (IsCitusTableType(relationId, HASH_DISTRIBUTED))
475 {
476 /* checking only for the first shard id should suffice */
477 uint64 shardId = *(uint64 *) linitial(shardList);
478
479 shardPlacementList = ShardPlacementListWithoutOrphanedPlacements(shardId);
480 if (list_length(shardPlacementList) != 1)
481 {
482 return false;
483 }
484 }
485 else
486 {
487 List *shardIntervalList = LoadShardList(relationId);
488 uint64 *shardIdPointer = NULL;
489 foreach_ptr(shardIdPointer, shardIntervalList)
490 {
491 uint64 shardId = *shardIdPointer;
492 shardPlacementList = ShardPlacementListWithoutOrphanedPlacements(shardId);
493
494 if (list_length(shardPlacementList) != 1)
495 {
496 return false;
497 }
498 }
499 }
500
501 return true;
502 }
503