/*
 * contrib/pg_trgm/trgm_gist.c
 *
 * GiST index support functions for pg_trgm trigram keys.
 */
4 #include "postgres.h"
5 
6 #include "trgm.h"
7 
8 #include "access/stratnum.h"
9 #include "fmgr.h"
10 
11 
12 typedef struct
13 {
14 	/* most recent inputs to gtrgm_consistent */
15 	StrategyNumber strategy;
16 	text	   *query;
17 	/* extracted trigrams for query */
18 	TRGM	   *trigrams;
19 	/* if a regex operator, the extracted graph */
20 	TrgmPackedGraph *graph;
21 
22 	/*
23 	 * The "query" and "trigrams" are stored in the same palloc block as this
24 	 * cache struct, at MAXALIGN'ed offsets.  The graph however isn't.
25 	 */
26 } gtrgm_consistent_cache;
27 
/* TRGM key stored at position "pos" of a GistEntryVector */
#define GETENTRY(vec, pos) ((TRGM *) DatumGetPointer((vec)->vector[(pos)].key))


/* Version-1 calling-convention declarations for the SQL-callable entry points */
PG_FUNCTION_INFO_V1(gtrgm_in);
PG_FUNCTION_INFO_V1(gtrgm_out);
PG_FUNCTION_INFO_V1(gtrgm_compress);
PG_FUNCTION_INFO_V1(gtrgm_decompress);
PG_FUNCTION_INFO_V1(gtrgm_consistent);
PG_FUNCTION_INFO_V1(gtrgm_distance);
PG_FUNCTION_INFO_V1(gtrgm_union);
PG_FUNCTION_INFO_V1(gtrgm_same);
PG_FUNCTION_INFO_V1(gtrgm_penalty);
PG_FUNCTION_INFO_V1(gtrgm_picksplit);
41 
42 /* Number of one-bits in an unsigned byte */
43 static const uint8 number_of_ones[256] = {
44 	0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
45 	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
46 	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
47 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
48 	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
49 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
50 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
51 	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
52 	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
53 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
54 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
55 	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
56 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
57 	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
58 	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
59 	4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
60 };
61 
62 
63 Datum
gtrgm_in(PG_FUNCTION_ARGS)64 gtrgm_in(PG_FUNCTION_ARGS)
65 {
66 	elog(ERROR, "not implemented");
67 	PG_RETURN_DATUM(0);
68 }
69 
70 Datum
gtrgm_out(PG_FUNCTION_ARGS)71 gtrgm_out(PG_FUNCTION_ARGS)
72 {
73 	elog(ERROR, "not implemented");
74 	PG_RETURN_DATUM(0);
75 }
76 
77 static void
makesign(BITVECP sign,TRGM * a)78 makesign(BITVECP sign, TRGM *a)
79 {
80 	int32		k,
81 				len = ARRNELEM(a);
82 	trgm	   *ptr = GETARR(a);
83 	int32		tmp = 0;
84 
85 	MemSet((void *) sign, 0, sizeof(BITVEC));
86 	SETBIT(sign, SIGLENBIT);	/* set last unused bit */
87 	for (k = 0; k < len; k++)
88 	{
89 		CPTRGM(((char *) &tmp), ptr + k);
90 		HASH(sign, tmp);
91 	}
92 }
93 
94 Datum
gtrgm_compress(PG_FUNCTION_ARGS)95 gtrgm_compress(PG_FUNCTION_ARGS)
96 {
97 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
98 	GISTENTRY  *retval = entry;
99 
100 	if (entry->leafkey)
101 	{							/* trgm */
102 		TRGM	   *res;
103 		text	   *val = DatumGetTextPP(entry->key);
104 
105 		res = generate_trgm(VARDATA_ANY(val), VARSIZE_ANY_EXHDR(val));
106 		retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
107 		gistentryinit(*retval, PointerGetDatum(res),
108 					  entry->rel, entry->page,
109 					  entry->offset, false);
110 	}
111 	else if (ISSIGNKEY(DatumGetPointer(entry->key)) &&
112 			 !ISALLTRUE(DatumGetPointer(entry->key)))
113 	{
114 		int32		i,
115 					len;
116 		TRGM	   *res;
117 		BITVECP		sign = GETSIGN(DatumGetPointer(entry->key));
118 
119 		LOOPBYTE
120 		{
121 			if ((sign[i] & 0xff) != 0xff)
122 				PG_RETURN_POINTER(retval);
123 		}
124 
125 		len = CALCGTSIZE(SIGNKEY | ALLISTRUE, 0);
126 		res = (TRGM *) palloc(len);
127 		SET_VARSIZE(res, len);
128 		res->flag = SIGNKEY | ALLISTRUE;
129 
130 		retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
131 		gistentryinit(*retval, PointerGetDatum(res),
132 					  entry->rel, entry->page,
133 					  entry->offset, false);
134 	}
135 	PG_RETURN_POINTER(retval);
136 }
137 
138 Datum
gtrgm_decompress(PG_FUNCTION_ARGS)139 gtrgm_decompress(PG_FUNCTION_ARGS)
140 {
141 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
142 	GISTENTRY  *retval;
143 	text	   *key;
144 
145 	key = DatumGetTextPP(entry->key);
146 
147 	if (key != (text *) DatumGetPointer(entry->key))
148 	{
149 		/* need to pass back the decompressed item */
150 		retval = palloc(sizeof(GISTENTRY));
151 		gistentryinit(*retval, PointerGetDatum(key),
152 					  entry->rel, entry->page, entry->offset, entry->leafkey);
153 		PG_RETURN_POINTER(retval);
154 	}
155 	else
156 	{
157 		/* we can return the entry as-is */
158 		PG_RETURN_POINTER(entry);
159 	}
160 }
161 
162 static int32
cnt_sml_sign_common(TRGM * qtrg,BITVECP sign)163 cnt_sml_sign_common(TRGM *qtrg, BITVECP sign)
164 {
165 	int32		count = 0;
166 	int32		k,
167 				len = ARRNELEM(qtrg);
168 	trgm	   *ptr = GETARR(qtrg);
169 	int32		tmp = 0;
170 
171 	for (k = 0; k < len; k++)
172 	{
173 		CPTRGM(((char *) &tmp), ptr + k);
174 		count += GETBIT(sign, HASHVAL(tmp));
175 	}
176 
177 	return count;
178 }
179 
180 Datum
gtrgm_consistent(PG_FUNCTION_ARGS)181 gtrgm_consistent(PG_FUNCTION_ARGS)
182 {
183 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
184 	text	   *query = PG_GETARG_TEXT_P(1);
185 	StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);
186 
187 	/* Oid		subtype = PG_GETARG_OID(3); */
188 	bool	   *recheck = (bool *) PG_GETARG_POINTER(4);
189 	TRGM	   *key = (TRGM *) DatumGetPointer(entry->key);
190 	TRGM	   *qtrg;
191 	bool		res;
192 	Size		querysize = VARSIZE(query);
193 	gtrgm_consistent_cache *cache;
194 	double		nlimit;
195 
196 	/*
197 	 * We keep the extracted trigrams in cache, because trigram extraction is
198 	 * relatively CPU-expensive.  When trying to reuse a cached value, check
199 	 * strategy number not just query itself, because trigram extraction
200 	 * depends on strategy.
201 	 *
202 	 * The cached structure is a single palloc chunk containing the
203 	 * gtrgm_consistent_cache header, then the input query (4-byte length
204 	 * word, uncompressed, starting at a MAXALIGN boundary), then the TRGM
205 	 * value (also starting at a MAXALIGN boundary).  However we don't try to
206 	 * include the regex graph (if any) in that struct.  (XXX currently, this
207 	 * approach can leak regex graphs across index rescans.  Not clear if
208 	 * that's worth fixing.)
209 	 */
210 	cache = (gtrgm_consistent_cache *) fcinfo->flinfo->fn_extra;
211 	if (cache == NULL ||
212 		cache->strategy != strategy ||
213 		VARSIZE(cache->query) != querysize ||
214 		memcmp((char *) cache->query, (char *) query, querysize) != 0)
215 	{
216 		gtrgm_consistent_cache *newcache;
217 		TrgmPackedGraph *graph = NULL;
218 		Size		qtrgsize;
219 
220 		switch (strategy)
221 		{
222 			case SimilarityStrategyNumber:
223 			case WordSimilarityStrategyNumber:
224 			case StrictWordSimilarityStrategyNumber:
225 				qtrg = generate_trgm(VARDATA(query),
226 									 querysize - VARHDRSZ);
227 				break;
228 			case ILikeStrategyNumber:
229 #ifndef IGNORECASE
230 				elog(ERROR, "cannot handle ~~* with case-sensitive trigrams");
231 #endif
232 				/* FALL THRU */
233 			case LikeStrategyNumber:
234 				qtrg = generate_wildcard_trgm(VARDATA(query),
235 											  querysize - VARHDRSZ);
236 				break;
237 			case RegExpICaseStrategyNumber:
238 #ifndef IGNORECASE
239 				elog(ERROR, "cannot handle ~* with case-sensitive trigrams");
240 #endif
241 				/* FALL THRU */
242 			case RegExpStrategyNumber:
243 				qtrg = createTrgmNFA(query, PG_GET_COLLATION(),
244 									 &graph, fcinfo->flinfo->fn_mcxt);
245 				/* just in case an empty array is returned ... */
246 				if (qtrg && ARRNELEM(qtrg) <= 0)
247 				{
248 					pfree(qtrg);
249 					qtrg = NULL;
250 				}
251 				break;
252 			default:
253 				elog(ERROR, "unrecognized strategy number: %d", strategy);
254 				qtrg = NULL;	/* keep compiler quiet */
255 				break;
256 		}
257 
258 		qtrgsize = qtrg ? VARSIZE(qtrg) : 0;
259 
260 		newcache = (gtrgm_consistent_cache *)
261 			MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
262 							   MAXALIGN(sizeof(gtrgm_consistent_cache)) +
263 							   MAXALIGN(querysize) +
264 							   qtrgsize);
265 
266 		newcache->strategy = strategy;
267 		newcache->query = (text *)
268 			((char *) newcache + MAXALIGN(sizeof(gtrgm_consistent_cache)));
269 		memcpy((char *) newcache->query, (char *) query, querysize);
270 		if (qtrg)
271 		{
272 			newcache->trigrams = (TRGM *)
273 				((char *) newcache->query + MAXALIGN(querysize));
274 			memcpy((char *) newcache->trigrams, (char *) qtrg, qtrgsize);
275 			/* release qtrg in case it was made in fn_mcxt */
276 			pfree(qtrg);
277 		}
278 		else
279 			newcache->trigrams = NULL;
280 		newcache->graph = graph;
281 
282 		if (cache)
283 			pfree(cache);
284 		fcinfo->flinfo->fn_extra = (void *) newcache;
285 		cache = newcache;
286 	}
287 
288 	qtrg = cache->trigrams;
289 
290 	switch (strategy)
291 	{
292 		case SimilarityStrategyNumber:
293 		case WordSimilarityStrategyNumber:
294 		case StrictWordSimilarityStrategyNumber:
295 
296 			/*
297 			 * Similarity search is exact. (Strict) word similarity search is
298 			 * inexact
299 			 */
300 			*recheck = (strategy != SimilarityStrategyNumber);
301 
302 			nlimit = index_strategy_get_limit(strategy);
303 
304 			if (GIST_LEAF(entry))
305 			{					/* all leafs contains orig trgm */
306 				double		tmpsml = cnt_sml(qtrg, key, *recheck);
307 
308 				res = (tmpsml >= nlimit);
309 			}
310 			else if (ISALLTRUE(key))
311 			{					/* non-leaf contains signature */
312 				res = true;
313 			}
314 			else
315 			{					/* non-leaf contains signature */
316 				int32		count = cnt_sml_sign_common(qtrg, GETSIGN(key));
317 				int32		len = ARRNELEM(qtrg);
318 
319 				if (len == 0)
320 					res = false;
321 				else
322 					res = (((((float8) count) / ((float8) len))) >= nlimit);
323 			}
324 			break;
325 		case ILikeStrategyNumber:
326 #ifndef IGNORECASE
327 			elog(ERROR, "cannot handle ~~* with case-sensitive trigrams");
328 #endif
329 			/* FALL THRU */
330 		case LikeStrategyNumber:
331 			/* Wildcard search is inexact */
332 			*recheck = true;
333 
334 			/*
335 			 * Check if all the extracted trigrams can be present in child
336 			 * nodes.
337 			 */
338 			if (GIST_LEAF(entry))
339 			{					/* all leafs contains orig trgm */
340 				res = trgm_contained_by(qtrg, key);
341 			}
342 			else if (ISALLTRUE(key))
343 			{					/* non-leaf contains signature */
344 				res = true;
345 			}
346 			else
347 			{					/* non-leaf contains signature */
348 				int32		k,
349 							tmp = 0,
350 							len = ARRNELEM(qtrg);
351 				trgm	   *ptr = GETARR(qtrg);
352 				BITVECP		sign = GETSIGN(key);
353 
354 				res = true;
355 				for (k = 0; k < len; k++)
356 				{
357 					CPTRGM(((char *) &tmp), ptr + k);
358 					if (!GETBIT(sign, HASHVAL(tmp)))
359 					{
360 						res = false;
361 						break;
362 					}
363 				}
364 			}
365 			break;
366 		case RegExpICaseStrategyNumber:
367 #ifndef IGNORECASE
368 			elog(ERROR, "cannot handle ~* with case-sensitive trigrams");
369 #endif
370 			/* FALL THRU */
371 		case RegExpStrategyNumber:
372 			/* Regexp search is inexact */
373 			*recheck = true;
374 
375 			/* Check regex match as much as we can with available info */
376 			if (qtrg)
377 			{
378 				if (GIST_LEAF(entry))
379 				{				/* all leafs contains orig trgm */
380 					bool	   *check;
381 
382 					check = trgm_presence_map(qtrg, key);
383 					res = trigramsMatchGraph(cache->graph, check);
384 					pfree(check);
385 				}
386 				else if (ISALLTRUE(key))
387 				{				/* non-leaf contains signature */
388 					res = true;
389 				}
390 				else
391 				{				/* non-leaf contains signature */
392 					int32		k,
393 								tmp = 0,
394 								len = ARRNELEM(qtrg);
395 					trgm	   *ptr = GETARR(qtrg);
396 					BITVECP		sign = GETSIGN(key);
397 					bool	   *check;
398 
399 					/*
400 					 * GETBIT() tests may give false positives, due to limited
401 					 * size of the sign array.  But since trigramsMatchGraph()
402 					 * implements a monotone boolean function, false positives
403 					 * in the check array can't lead to false negative answer.
404 					 * So we can apply trigramsMatchGraph despite uncertainty,
405 					 * and that usefully improves the quality of the search.
406 					 */
407 					check = (bool *) palloc(len * sizeof(bool));
408 					for (k = 0; k < len; k++)
409 					{
410 						CPTRGM(((char *) &tmp), ptr + k);
411 						check[k] = GETBIT(sign, HASHVAL(tmp));
412 					}
413 					res = trigramsMatchGraph(cache->graph, check);
414 					pfree(check);
415 				}
416 			}
417 			else
418 			{
419 				/* trigram-free query must be rechecked everywhere */
420 				res = true;
421 			}
422 			break;
423 		default:
424 			elog(ERROR, "unrecognized strategy number: %d", strategy);
425 			res = false;		/* keep compiler quiet */
426 			break;
427 	}
428 
429 	PG_RETURN_BOOL(res);
430 }
431 
432 Datum
gtrgm_distance(PG_FUNCTION_ARGS)433 gtrgm_distance(PG_FUNCTION_ARGS)
434 {
435 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
436 	text	   *query = PG_GETARG_TEXT_P(1);
437 	StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);
438 
439 	/* Oid		subtype = PG_GETARG_OID(3); */
440 	bool	   *recheck = (bool *) PG_GETARG_POINTER(4);
441 	TRGM	   *key = (TRGM *) DatumGetPointer(entry->key);
442 	TRGM	   *qtrg;
443 	float8		res;
444 	Size		querysize = VARSIZE(query);
445 	char	   *cache = (char *) fcinfo->flinfo->fn_extra;
446 
447 	/*
448 	 * Cache the generated trigrams across multiple calls with the same query.
449 	 */
450 	if (cache == NULL ||
451 		VARSIZE(cache) != querysize ||
452 		memcmp(cache, query, querysize) != 0)
453 	{
454 		char	   *newcache;
455 
456 		qtrg = generate_trgm(VARDATA(query), querysize - VARHDRSZ);
457 
458 		newcache = MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
459 									  MAXALIGN(querysize) +
460 									  VARSIZE(qtrg));
461 
462 		memcpy(newcache, query, querysize);
463 		memcpy(newcache + MAXALIGN(querysize), qtrg, VARSIZE(qtrg));
464 
465 		if (cache)
466 			pfree(cache);
467 		fcinfo->flinfo->fn_extra = newcache;
468 		cache = newcache;
469 	}
470 
471 	qtrg = (TRGM *) (cache + MAXALIGN(querysize));
472 
473 	switch (strategy)
474 	{
475 		case DistanceStrategyNumber:
476 		case WordDistanceStrategyNumber:
477 		case StrictWordDistanceStrategyNumber:
478 			/* Only plain trigram distance is exact */
479 			*recheck = (strategy != DistanceStrategyNumber);
480 			if (GIST_LEAF(entry))
481 			{					/* all leafs contains orig trgm */
482 
483 				/*
484 				 * Prevent gcc optimizing the sml variable using volatile
485 				 * keyword. Otherwise res can differ from the
486 				 * word_similarity_dist_op() function.
487 				 */
488 				float4 volatile sml = cnt_sml(qtrg, key, *recheck);
489 
490 				res = 1.0 - sml;
491 			}
492 			else if (ISALLTRUE(key))
493 			{					/* all leafs contains orig trgm */
494 				res = 0.0;
495 			}
496 			else
497 			{					/* non-leaf contains signature */
498 				int32		count = cnt_sml_sign_common(qtrg, GETSIGN(key));
499 				int32		len = ARRNELEM(qtrg);
500 
501 				res = (len == 0) ? -1.0 : 1.0 - ((float8) count) / ((float8) len);
502 			}
503 			break;
504 		default:
505 			elog(ERROR, "unrecognized strategy number: %d", strategy);
506 			res = 0;			/* keep compiler quiet */
507 			break;
508 	}
509 
510 	PG_RETURN_FLOAT8(res);
511 }
512 
513 static int32
unionkey(BITVECP sbase,TRGM * add)514 unionkey(BITVECP sbase, TRGM *add)
515 {
516 	int32		i;
517 
518 	if (ISSIGNKEY(add))
519 	{
520 		BITVECP		sadd = GETSIGN(add);
521 
522 		if (ISALLTRUE(add))
523 			return 1;
524 
525 		LOOPBYTE
526 			sbase[i] |= sadd[i];
527 	}
528 	else
529 	{
530 		trgm	   *ptr = GETARR(add);
531 		int32		tmp = 0;
532 
533 		for (i = 0; i < ARRNELEM(add); i++)
534 		{
535 			CPTRGM(((char *) &tmp), ptr + i);
536 			HASH(sbase, tmp);
537 		}
538 	}
539 	return 0;
540 }
541 
542 
543 Datum
gtrgm_union(PG_FUNCTION_ARGS)544 gtrgm_union(PG_FUNCTION_ARGS)
545 {
546 	GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
547 	int32		len = entryvec->n;
548 	int		   *size = (int *) PG_GETARG_POINTER(1);
549 	BITVEC		base;
550 	int32		i;
551 	int32		flag = 0;
552 	TRGM	   *result;
553 
554 	MemSet((void *) base, 0, sizeof(BITVEC));
555 	for (i = 0; i < len; i++)
556 	{
557 		if (unionkey(base, GETENTRY(entryvec, i)))
558 		{
559 			flag = ALLISTRUE;
560 			break;
561 		}
562 	}
563 
564 	flag |= SIGNKEY;
565 	len = CALCGTSIZE(flag, 0);
566 	result = (TRGM *) palloc(len);
567 	SET_VARSIZE(result, len);
568 	result->flag = flag;
569 	if (!ISALLTRUE(result))
570 		memcpy((void *) GETSIGN(result), (void *) base, sizeof(BITVEC));
571 	*size = len;
572 
573 	PG_RETURN_POINTER(result);
574 }
575 
576 Datum
gtrgm_same(PG_FUNCTION_ARGS)577 gtrgm_same(PG_FUNCTION_ARGS)
578 {
579 	TRGM	   *a = (TRGM *) PG_GETARG_POINTER(0);
580 	TRGM	   *b = (TRGM *) PG_GETARG_POINTER(1);
581 	bool	   *result = (bool *) PG_GETARG_POINTER(2);
582 
583 	if (ISSIGNKEY(a))
584 	{							/* then b also ISSIGNKEY */
585 		if (ISALLTRUE(a) && ISALLTRUE(b))
586 			*result = true;
587 		else if (ISALLTRUE(a))
588 			*result = false;
589 		else if (ISALLTRUE(b))
590 			*result = false;
591 		else
592 		{
593 			int32		i;
594 			BITVECP		sa = GETSIGN(a),
595 						sb = GETSIGN(b);
596 
597 			*result = true;
598 			LOOPBYTE
599 			{
600 				if (sa[i] != sb[i])
601 				{
602 					*result = false;
603 					break;
604 				}
605 			}
606 		}
607 	}
608 	else
609 	{							/* a and b ISARRKEY */
610 		int32		lena = ARRNELEM(a),
611 					lenb = ARRNELEM(b);
612 
613 		if (lena != lenb)
614 			*result = false;
615 		else
616 		{
617 			trgm	   *ptra = GETARR(a),
618 					   *ptrb = GETARR(b);
619 			int32		i;
620 
621 			*result = true;
622 			for (i = 0; i < lena; i++)
623 				if (CMPTRGM(ptra + i, ptrb + i))
624 				{
625 					*result = false;
626 					break;
627 				}
628 		}
629 	}
630 
631 	PG_RETURN_POINTER(result);
632 }
633 
634 static int32
sizebitvec(BITVECP sign)635 sizebitvec(BITVECP sign)
636 {
637 	int32		size = 0,
638 				i;
639 
640 	LOOPBYTE
641 		size += number_of_ones[(unsigned char) sign[i]];
642 	return size;
643 }
644 
645 static int
hemdistsign(BITVECP a,BITVECP b)646 hemdistsign(BITVECP a, BITVECP b)
647 {
648 	int			i,
649 				diff,
650 				dist = 0;
651 
652 	LOOPBYTE
653 	{
654 		diff = (unsigned char) (a[i] ^ b[i]);
655 		dist += number_of_ones[diff];
656 	}
657 	return dist;
658 }
659 
660 static int
hemdist(TRGM * a,TRGM * b)661 hemdist(TRGM *a, TRGM *b)
662 {
663 	if (ISALLTRUE(a))
664 	{
665 		if (ISALLTRUE(b))
666 			return 0;
667 		else
668 			return SIGLENBIT - sizebitvec(GETSIGN(b));
669 	}
670 	else if (ISALLTRUE(b))
671 		return SIGLENBIT - sizebitvec(GETSIGN(a));
672 
673 	return hemdistsign(GETSIGN(a), GETSIGN(b));
674 }
675 
676 Datum
gtrgm_penalty(PG_FUNCTION_ARGS)677 gtrgm_penalty(PG_FUNCTION_ARGS)
678 {
679 	GISTENTRY  *origentry = (GISTENTRY *) PG_GETARG_POINTER(0); /* always ISSIGNKEY */
680 	GISTENTRY  *newentry = (GISTENTRY *) PG_GETARG_POINTER(1);
681 	float	   *penalty = (float *) PG_GETARG_POINTER(2);
682 	TRGM	   *origval = (TRGM *) DatumGetPointer(origentry->key);
683 	TRGM	   *newval = (TRGM *) DatumGetPointer(newentry->key);
684 	BITVECP		orig = GETSIGN(origval);
685 
686 	*penalty = 0.0;
687 
688 	if (ISARRKEY(newval))
689 	{
690 		char	   *cache = (char *) fcinfo->flinfo->fn_extra;
691 		TRGM	   *cachedVal = (TRGM *) (cache + MAXALIGN(sizeof(BITVEC)));
692 		Size		newvalsize = VARSIZE(newval);
693 		BITVECP		sign;
694 
695 		/*
696 		 * Cache the sign data across multiple calls with the same newval.
697 		 */
698 		if (cache == NULL ||
699 			VARSIZE(cachedVal) != newvalsize ||
700 			memcmp(cachedVal, newval, newvalsize) != 0)
701 		{
702 			char	   *newcache;
703 
704 			newcache = MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
705 										  MAXALIGN(sizeof(BITVEC)) +
706 										  newvalsize);
707 
708 			makesign((BITVECP) newcache, newval);
709 
710 			cachedVal = (TRGM *) (newcache + MAXALIGN(sizeof(BITVEC)));
711 			memcpy(cachedVal, newval, newvalsize);
712 
713 			if (cache)
714 				pfree(cache);
715 			fcinfo->flinfo->fn_extra = newcache;
716 			cache = newcache;
717 		}
718 
719 		sign = (BITVECP) cache;
720 
721 		if (ISALLTRUE(origval))
722 			*penalty = ((float) (SIGLENBIT - sizebitvec(sign))) / (float) (SIGLENBIT + 1);
723 		else
724 			*penalty = hemdistsign(sign, orig);
725 	}
726 	else
727 		*penalty = hemdist(origval, newval);
728 	PG_RETURN_POINTER(penalty);
729 }
730 
731 typedef struct
732 {
733 	bool		allistrue;
734 	BITVEC		sign;
735 } CACHESIGN;
736 
737 static void
fillcache(CACHESIGN * item,TRGM * key)738 fillcache(CACHESIGN *item, TRGM *key)
739 {
740 	item->allistrue = false;
741 	if (ISARRKEY(key))
742 		makesign(item->sign, key);
743 	else if (ISALLTRUE(key))
744 		item->allistrue = true;
745 	else
746 		memcpy((void *) item->sign, (void *) GETSIGN(key), sizeof(BITVEC));
747 }
748 
749 #define WISH_F(a,b,c) (double)( -(double)(((a)-(b))*((a)-(b))*((a)-(b)))*(c) )
750 typedef struct
751 {
752 	OffsetNumber pos;
753 	int32		cost;
754 } SPLITCOST;
755 
756 static int
comparecost(const void * a,const void * b)757 comparecost(const void *a, const void *b)
758 {
759 	if (((const SPLITCOST *) a)->cost == ((const SPLITCOST *) b)->cost)
760 		return 0;
761 	else
762 		return (((const SPLITCOST *) a)->cost > ((const SPLITCOST *) b)->cost) ? 1 : -1;
763 }
764 
765 
766 static int
hemdistcache(CACHESIGN * a,CACHESIGN * b)767 hemdistcache(CACHESIGN *a, CACHESIGN *b)
768 {
769 	if (a->allistrue)
770 	{
771 		if (b->allistrue)
772 			return 0;
773 		else
774 			return SIGLENBIT - sizebitvec(b->sign);
775 	}
776 	else if (b->allistrue)
777 		return SIGLENBIT - sizebitvec(a->sign);
778 
779 	return hemdistsign(a->sign, b->sign);
780 }
781 
782 Datum
gtrgm_picksplit(PG_FUNCTION_ARGS)783 gtrgm_picksplit(PG_FUNCTION_ARGS)
784 {
785 	GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
786 	OffsetNumber maxoff = entryvec->n - 1;
787 	GIST_SPLITVEC *v = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
788 	OffsetNumber k,
789 				j;
790 	TRGM	   *datum_l,
791 			   *datum_r;
792 	BITVECP		union_l,
793 				union_r;
794 	int32		size_alpha,
795 				size_beta;
796 	int32		size_waste,
797 				waste = -1;
798 	int32		nbytes;
799 	OffsetNumber seed_1 = 0,
800 				seed_2 = 0;
801 	OffsetNumber *left,
802 			   *right;
803 	BITVECP		ptr;
804 	int			i;
805 	CACHESIGN  *cache;
806 	SPLITCOST  *costvector;
807 
808 	/* cache the sign data for each existing item */
809 	cache = (CACHESIGN *) palloc(sizeof(CACHESIGN) * (maxoff + 1));
810 	for (k = FirstOffsetNumber; k <= maxoff; k = OffsetNumberNext(k))
811 		fillcache(&cache[k], GETENTRY(entryvec, k));
812 
813 	/* now find the two furthest-apart items */
814 	for (k = FirstOffsetNumber; k < maxoff; k = OffsetNumberNext(k))
815 	{
816 		for (j = OffsetNumberNext(k); j <= maxoff; j = OffsetNumberNext(j))
817 		{
818 			size_waste = hemdistcache(&(cache[j]), &(cache[k]));
819 			if (size_waste > waste)
820 			{
821 				waste = size_waste;
822 				seed_1 = k;
823 				seed_2 = j;
824 			}
825 		}
826 	}
827 
828 	/* just in case we didn't make a selection ... */
829 	if (seed_1 == 0 || seed_2 == 0)
830 	{
831 		seed_1 = 1;
832 		seed_2 = 2;
833 	}
834 
835 	/* initialize the result vectors */
836 	nbytes = maxoff * sizeof(OffsetNumber);
837 	v->spl_left = left = (OffsetNumber *) palloc(nbytes);
838 	v->spl_right = right = (OffsetNumber *) palloc(nbytes);
839 	v->spl_nleft = 0;
840 	v->spl_nright = 0;
841 
842 	/* form initial .. */
843 	if (cache[seed_1].allistrue)
844 	{
845 		datum_l = (TRGM *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
846 		SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
847 		datum_l->flag = SIGNKEY | ALLISTRUE;
848 	}
849 	else
850 	{
851 		datum_l = (TRGM *) palloc(CALCGTSIZE(SIGNKEY, 0));
852 		SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY, 0));
853 		datum_l->flag = SIGNKEY;
854 		memcpy((void *) GETSIGN(datum_l), (void *) cache[seed_1].sign, sizeof(BITVEC));
855 	}
856 	if (cache[seed_2].allistrue)
857 	{
858 		datum_r = (TRGM *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
859 		SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
860 		datum_r->flag = SIGNKEY | ALLISTRUE;
861 	}
862 	else
863 	{
864 		datum_r = (TRGM *) palloc(CALCGTSIZE(SIGNKEY, 0));
865 		SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY, 0));
866 		datum_r->flag = SIGNKEY;
867 		memcpy((void *) GETSIGN(datum_r), (void *) cache[seed_2].sign, sizeof(BITVEC));
868 	}
869 
870 	union_l = GETSIGN(datum_l);
871 	union_r = GETSIGN(datum_r);
872 
873 	/* sort before ... */
874 	costvector = (SPLITCOST *) palloc(sizeof(SPLITCOST) * maxoff);
875 	for (j = FirstOffsetNumber; j <= maxoff; j = OffsetNumberNext(j))
876 	{
877 		costvector[j - 1].pos = j;
878 		size_alpha = hemdistcache(&(cache[seed_1]), &(cache[j]));
879 		size_beta = hemdistcache(&(cache[seed_2]), &(cache[j]));
880 		costvector[j - 1].cost = abs(size_alpha - size_beta);
881 	}
882 	qsort((void *) costvector, maxoff, sizeof(SPLITCOST), comparecost);
883 
884 	for (k = 0; k < maxoff; k++)
885 	{
886 		j = costvector[k].pos;
887 		if (j == seed_1)
888 		{
889 			*left++ = j;
890 			v->spl_nleft++;
891 			continue;
892 		}
893 		else if (j == seed_2)
894 		{
895 			*right++ = j;
896 			v->spl_nright++;
897 			continue;
898 		}
899 
900 		if (ISALLTRUE(datum_l) || cache[j].allistrue)
901 		{
902 			if (ISALLTRUE(datum_l) && cache[j].allistrue)
903 				size_alpha = 0;
904 			else
905 				size_alpha = SIGLENBIT - sizebitvec(
906 													(cache[j].allistrue) ? GETSIGN(datum_l) : GETSIGN(cache[j].sign)
907 					);
908 		}
909 		else
910 			size_alpha = hemdistsign(cache[j].sign, GETSIGN(datum_l));
911 
912 		if (ISALLTRUE(datum_r) || cache[j].allistrue)
913 		{
914 			if (ISALLTRUE(datum_r) && cache[j].allistrue)
915 				size_beta = 0;
916 			else
917 				size_beta = SIGLENBIT - sizebitvec(
918 												   (cache[j].allistrue) ? GETSIGN(datum_r) : GETSIGN(cache[j].sign)
919 					);
920 		}
921 		else
922 			size_beta = hemdistsign(cache[j].sign, GETSIGN(datum_r));
923 
924 		if (size_alpha < size_beta + WISH_F(v->spl_nleft, v->spl_nright, 0.1))
925 		{
926 			if (ISALLTRUE(datum_l) || cache[j].allistrue)
927 			{
928 				if (!ISALLTRUE(datum_l))
929 					MemSet((void *) GETSIGN(datum_l), 0xff, sizeof(BITVEC));
930 			}
931 			else
932 			{
933 				ptr = cache[j].sign;
934 				LOOPBYTE
935 					union_l[i] |= ptr[i];
936 			}
937 			*left++ = j;
938 			v->spl_nleft++;
939 		}
940 		else
941 		{
942 			if (ISALLTRUE(datum_r) || cache[j].allistrue)
943 			{
944 				if (!ISALLTRUE(datum_r))
945 					MemSet((void *) GETSIGN(datum_r), 0xff, sizeof(BITVEC));
946 			}
947 			else
948 			{
949 				ptr = cache[j].sign;
950 				LOOPBYTE
951 					union_r[i] |= ptr[i];
952 			}
953 			*right++ = j;
954 			v->spl_nright++;
955 		}
956 	}
957 
958 	v->spl_ldatum = PointerGetDatum(datum_l);
959 	v->spl_rdatum = PointerGetDatum(datum_r);
960 
961 	PG_RETURN_POINTER(v);
962 }
963