1 /*
2  * contrib/pg_trgm/trgm_gist.c
3  */
4 #include "postgres.h"
5 
6 #include "access/reloptions.h"
7 #include "access/stratnum.h"
8 #include "fmgr.h"
9 #include "port/pg_bitutils.h"
10 #include "trgm.h"
11 
12 /* gist_trgm_ops opclass options */
13 typedef struct
14 {
15 	int32		vl_len_;		/* varlena header (do not touch directly!) */
16 	int			siglen;			/* signature length in bytes */
17 } TrgmGistOptions;
18 
/*
 * Signature length to use in the current support-function call: the
 * "siglen" opclass option when options are present, SIGLEN_DEFAULT
 * otherwise.
 */
#define GET_SIGLEN() \
	(PG_HAS_OPCLASS_OPTIONS() ? \
	 ((TrgmGistOptions *) PG_GET_OPCLASS_OPTIONS())->siglen : \
	 SIGLEN_DEFAULT)
22 
23 typedef struct
24 {
25 	/* most recent inputs to gtrgm_consistent */
26 	StrategyNumber strategy;
27 	text	   *query;
28 	/* extracted trigrams for query */
29 	TRGM	   *trigrams;
30 	/* if a regex operator, the extracted graph */
31 	TrgmPackedGraph *graph;
32 
33 	/*
34 	 * The "query" and "trigrams" are stored in the same palloc block as this
35 	 * cache struct, at MAXALIGN'ed offsets.  The graph however isn't.
36 	 */
37 } gtrgm_consistent_cache;
38 
/* Fetch the key of entry number "pos" of a GistEntryVector as a TRGM pointer */
#define GETENTRY(vec, pos) \
	((TRGM *) DatumGetPointer((vec)->vector[(pos)].key))
40 
41 
/* type I/O stubs */
PG_FUNCTION_INFO_V1(gtrgm_in);
PG_FUNCTION_INFO_V1(gtrgm_out);

/* key conversion */
PG_FUNCTION_INFO_V1(gtrgm_compress);
PG_FUNCTION_INFO_V1(gtrgm_decompress);

/* search */
PG_FUNCTION_INFO_V1(gtrgm_consistent);
PG_FUNCTION_INFO_V1(gtrgm_distance);

/* tree maintenance */
PG_FUNCTION_INFO_V1(gtrgm_union);
PG_FUNCTION_INFO_V1(gtrgm_same);
PG_FUNCTION_INFO_V1(gtrgm_penalty);
PG_FUNCTION_INFO_V1(gtrgm_picksplit);

/* opclass options */
PG_FUNCTION_INFO_V1(gtrgm_options);
53 
54 
55 Datum
gtrgm_in(PG_FUNCTION_ARGS)56 gtrgm_in(PG_FUNCTION_ARGS)
57 {
58 	elog(ERROR, "not implemented");
59 	PG_RETURN_DATUM(0);
60 }
61 
62 Datum
gtrgm_out(PG_FUNCTION_ARGS)63 gtrgm_out(PG_FUNCTION_ARGS)
64 {
65 	elog(ERROR, "not implemented");
66 	PG_RETURN_DATUM(0);
67 }
68 
69 static TRGM *
gtrgm_alloc(bool isalltrue,int siglen,BITVECP sign)70 gtrgm_alloc(bool isalltrue, int siglen, BITVECP sign)
71 {
72 	int			flag = SIGNKEY | (isalltrue ? ALLISTRUE : 0);
73 	int			size = CALCGTSIZE(flag, siglen);
74 	TRGM	   *res = palloc(size);
75 
76 	SET_VARSIZE(res, size);
77 	res->flag = flag;
78 
79 	if (!isalltrue)
80 	{
81 		if (sign)
82 			memcpy(GETSIGN(res), sign, siglen);
83 		else
84 			memset(GETSIGN(res), 0, siglen);
85 	}
86 
87 	return res;
88 }
89 
90 static void
makesign(BITVECP sign,TRGM * a,int siglen)91 makesign(BITVECP sign, TRGM *a, int siglen)
92 {
93 	int32		k,
94 				len = ARRNELEM(a);
95 	trgm	   *ptr = GETARR(a);
96 	int32		tmp = 0;
97 
98 	MemSet((void *) sign, 0, siglen);
99 	SETBIT(sign, SIGLENBIT(siglen));	/* set last unused bit */
100 	for (k = 0; k < len; k++)
101 	{
102 		CPTRGM(((char *) &tmp), ptr + k);
103 		HASH(sign, tmp, siglen);
104 	}
105 }
106 
107 Datum
gtrgm_compress(PG_FUNCTION_ARGS)108 gtrgm_compress(PG_FUNCTION_ARGS)
109 {
110 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
111 	int			siglen = GET_SIGLEN();
112 	GISTENTRY  *retval = entry;
113 
114 	if (entry->leafkey)
115 	{							/* trgm */
116 		TRGM	   *res;
117 		text	   *val = DatumGetTextPP(entry->key);
118 
119 		res = generate_trgm(VARDATA_ANY(val), VARSIZE_ANY_EXHDR(val));
120 		retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
121 		gistentryinit(*retval, PointerGetDatum(res),
122 					  entry->rel, entry->page,
123 					  entry->offset, false);
124 	}
125 	else if (ISSIGNKEY(DatumGetPointer(entry->key)) &&
126 			 !ISALLTRUE(DatumGetPointer(entry->key)))
127 	{
128 		int32		i;
129 		TRGM	   *res;
130 		BITVECP		sign = GETSIGN(DatumGetPointer(entry->key));
131 
132 		LOOPBYTE(siglen)
133 		{
134 			if ((sign[i] & 0xff) != 0xff)
135 				PG_RETURN_POINTER(retval);
136 		}
137 
138 		res = gtrgm_alloc(true, siglen, sign);
139 		retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
140 		gistentryinit(*retval, PointerGetDatum(res),
141 					  entry->rel, entry->page,
142 					  entry->offset, false);
143 	}
144 	PG_RETURN_POINTER(retval);
145 }
146 
147 Datum
gtrgm_decompress(PG_FUNCTION_ARGS)148 gtrgm_decompress(PG_FUNCTION_ARGS)
149 {
150 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
151 	GISTENTRY  *retval;
152 	text	   *key;
153 
154 	key = DatumGetTextPP(entry->key);
155 
156 	if (key != (text *) DatumGetPointer(entry->key))
157 	{
158 		/* need to pass back the decompressed item */
159 		retval = palloc(sizeof(GISTENTRY));
160 		gistentryinit(*retval, PointerGetDatum(key),
161 					  entry->rel, entry->page, entry->offset, entry->leafkey);
162 		PG_RETURN_POINTER(retval);
163 	}
164 	else
165 	{
166 		/* we can return the entry as-is */
167 		PG_RETURN_POINTER(entry);
168 	}
169 }
170 
171 static int32
cnt_sml_sign_common(TRGM * qtrg,BITVECP sign,int siglen)172 cnt_sml_sign_common(TRGM *qtrg, BITVECP sign, int siglen)
173 {
174 	int32		count = 0;
175 	int32		k,
176 				len = ARRNELEM(qtrg);
177 	trgm	   *ptr = GETARR(qtrg);
178 	int32		tmp = 0;
179 
180 	for (k = 0; k < len; k++)
181 	{
182 		CPTRGM(((char *) &tmp), ptr + k);
183 		count += GETBIT(sign, HASHVAL(tmp, siglen));
184 	}
185 
186 	return count;
187 }
188 
189 Datum
gtrgm_consistent(PG_FUNCTION_ARGS)190 gtrgm_consistent(PG_FUNCTION_ARGS)
191 {
192 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
193 	text	   *query = PG_GETARG_TEXT_P(1);
194 	StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);
195 
196 	/* Oid		subtype = PG_GETARG_OID(3); */
197 	bool	   *recheck = (bool *) PG_GETARG_POINTER(4);
198 	int			siglen = GET_SIGLEN();
199 	TRGM	   *key = (TRGM *) DatumGetPointer(entry->key);
200 	TRGM	   *qtrg;
201 	bool		res;
202 	Size		querysize = VARSIZE(query);
203 	gtrgm_consistent_cache *cache;
204 	double		nlimit;
205 
206 	/*
207 	 * We keep the extracted trigrams in cache, because trigram extraction is
208 	 * relatively CPU-expensive.  When trying to reuse a cached value, check
209 	 * strategy number not just query itself, because trigram extraction
210 	 * depends on strategy.
211 	 *
212 	 * The cached structure is a single palloc chunk containing the
213 	 * gtrgm_consistent_cache header, then the input query (4-byte length
214 	 * word, uncompressed, starting at a MAXALIGN boundary), then the TRGM
215 	 * value (also starting at a MAXALIGN boundary).  However we don't try to
216 	 * include the regex graph (if any) in that struct.  (XXX currently, this
217 	 * approach can leak regex graphs across index rescans.  Not clear if
218 	 * that's worth fixing.)
219 	 */
220 	cache = (gtrgm_consistent_cache *) fcinfo->flinfo->fn_extra;
221 	if (cache == NULL ||
222 		cache->strategy != strategy ||
223 		VARSIZE(cache->query) != querysize ||
224 		memcmp((char *) cache->query, (char *) query, querysize) != 0)
225 	{
226 		gtrgm_consistent_cache *newcache;
227 		TrgmPackedGraph *graph = NULL;
228 		Size		qtrgsize;
229 
230 		switch (strategy)
231 		{
232 			case SimilarityStrategyNumber:
233 			case WordSimilarityStrategyNumber:
234 			case StrictWordSimilarityStrategyNumber:
235 				qtrg = generate_trgm(VARDATA(query),
236 									 querysize - VARHDRSZ);
237 				break;
238 			case ILikeStrategyNumber:
239 #ifndef IGNORECASE
240 				elog(ERROR, "cannot handle ~~* with case-sensitive trigrams");
241 #endif
242 				/* FALL THRU */
243 			case LikeStrategyNumber:
244 				qtrg = generate_wildcard_trgm(VARDATA(query),
245 											  querysize - VARHDRSZ);
246 				break;
247 			case RegExpICaseStrategyNumber:
248 #ifndef IGNORECASE
249 				elog(ERROR, "cannot handle ~* with case-sensitive trigrams");
250 #endif
251 				/* FALL THRU */
252 			case RegExpStrategyNumber:
253 				qtrg = createTrgmNFA(query, PG_GET_COLLATION(),
254 									 &graph, fcinfo->flinfo->fn_mcxt);
255 				/* just in case an empty array is returned ... */
256 				if (qtrg && ARRNELEM(qtrg) <= 0)
257 				{
258 					pfree(qtrg);
259 					qtrg = NULL;
260 				}
261 				break;
262 			default:
263 				elog(ERROR, "unrecognized strategy number: %d", strategy);
264 				qtrg = NULL;	/* keep compiler quiet */
265 				break;
266 		}
267 
268 		qtrgsize = qtrg ? VARSIZE(qtrg) : 0;
269 
270 		newcache = (gtrgm_consistent_cache *)
271 			MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
272 							   MAXALIGN(sizeof(gtrgm_consistent_cache)) +
273 							   MAXALIGN(querysize) +
274 							   qtrgsize);
275 
276 		newcache->strategy = strategy;
277 		newcache->query = (text *)
278 			((char *) newcache + MAXALIGN(sizeof(gtrgm_consistent_cache)));
279 		memcpy((char *) newcache->query, (char *) query, querysize);
280 		if (qtrg)
281 		{
282 			newcache->trigrams = (TRGM *)
283 				((char *) newcache->query + MAXALIGN(querysize));
284 			memcpy((char *) newcache->trigrams, (char *) qtrg, qtrgsize);
285 			/* release qtrg in case it was made in fn_mcxt */
286 			pfree(qtrg);
287 		}
288 		else
289 			newcache->trigrams = NULL;
290 		newcache->graph = graph;
291 
292 		if (cache)
293 			pfree(cache);
294 		fcinfo->flinfo->fn_extra = (void *) newcache;
295 		cache = newcache;
296 	}
297 
298 	qtrg = cache->trigrams;
299 
300 	switch (strategy)
301 	{
302 		case SimilarityStrategyNumber:
303 		case WordSimilarityStrategyNumber:
304 		case StrictWordSimilarityStrategyNumber:
305 
306 			/*
307 			 * Similarity search is exact. (Strict) word similarity search is
308 			 * inexact
309 			 */
310 			*recheck = (strategy != SimilarityStrategyNumber);
311 
312 			nlimit = index_strategy_get_limit(strategy);
313 
314 			if (GIST_LEAF(entry))
315 			{					/* all leafs contains orig trgm */
316 				double		tmpsml = cnt_sml(qtrg, key, *recheck);
317 
318 				res = (tmpsml >= nlimit);
319 			}
320 			else if (ISALLTRUE(key))
321 			{					/* non-leaf contains signature */
322 				res = true;
323 			}
324 			else
325 			{					/* non-leaf contains signature */
326 				int32		count = cnt_sml_sign_common(qtrg, GETSIGN(key), siglen);
327 				int32		len = ARRNELEM(qtrg);
328 
329 				if (len == 0)
330 					res = false;
331 				else
332 					res = (((((float8) count) / ((float8) len))) >= nlimit);
333 			}
334 			break;
335 		case ILikeStrategyNumber:
336 #ifndef IGNORECASE
337 			elog(ERROR, "cannot handle ~~* with case-sensitive trigrams");
338 #endif
339 			/* FALL THRU */
340 		case LikeStrategyNumber:
341 			/* Wildcard search is inexact */
342 			*recheck = true;
343 
344 			/*
345 			 * Check if all the extracted trigrams can be present in child
346 			 * nodes.
347 			 */
348 			if (GIST_LEAF(entry))
349 			{					/* all leafs contains orig trgm */
350 				res = trgm_contained_by(qtrg, key);
351 			}
352 			else if (ISALLTRUE(key))
353 			{					/* non-leaf contains signature */
354 				res = true;
355 			}
356 			else
357 			{					/* non-leaf contains signature */
358 				int32		k,
359 							tmp = 0,
360 							len = ARRNELEM(qtrg);
361 				trgm	   *ptr = GETARR(qtrg);
362 				BITVECP		sign = GETSIGN(key);
363 
364 				res = true;
365 				for (k = 0; k < len; k++)
366 				{
367 					CPTRGM(((char *) &tmp), ptr + k);
368 					if (!GETBIT(sign, HASHVAL(tmp, siglen)))
369 					{
370 						res = false;
371 						break;
372 					}
373 				}
374 			}
375 			break;
376 		case RegExpICaseStrategyNumber:
377 #ifndef IGNORECASE
378 			elog(ERROR, "cannot handle ~* with case-sensitive trigrams");
379 #endif
380 			/* FALL THRU */
381 		case RegExpStrategyNumber:
382 			/* Regexp search is inexact */
383 			*recheck = true;
384 
385 			/* Check regex match as much as we can with available info */
386 			if (qtrg)
387 			{
388 				if (GIST_LEAF(entry))
389 				{				/* all leafs contains orig trgm */
390 					bool	   *check;
391 
392 					check = trgm_presence_map(qtrg, key);
393 					res = trigramsMatchGraph(cache->graph, check);
394 					pfree(check);
395 				}
396 				else if (ISALLTRUE(key))
397 				{				/* non-leaf contains signature */
398 					res = true;
399 				}
400 				else
401 				{				/* non-leaf contains signature */
402 					int32		k,
403 								tmp = 0,
404 								len = ARRNELEM(qtrg);
405 					trgm	   *ptr = GETARR(qtrg);
406 					BITVECP		sign = GETSIGN(key);
407 					bool	   *check;
408 
409 					/*
410 					 * GETBIT() tests may give false positives, due to limited
411 					 * size of the sign array.  But since trigramsMatchGraph()
412 					 * implements a monotone boolean function, false positives
413 					 * in the check array can't lead to false negative answer.
414 					 * So we can apply trigramsMatchGraph despite uncertainty,
415 					 * and that usefully improves the quality of the search.
416 					 */
417 					check = (bool *) palloc(len * sizeof(bool));
418 					for (k = 0; k < len; k++)
419 					{
420 						CPTRGM(((char *) &tmp), ptr + k);
421 						check[k] = GETBIT(sign, HASHVAL(tmp, siglen));
422 					}
423 					res = trigramsMatchGraph(cache->graph, check);
424 					pfree(check);
425 				}
426 			}
427 			else
428 			{
429 				/* trigram-free query must be rechecked everywhere */
430 				res = true;
431 			}
432 			break;
433 		default:
434 			elog(ERROR, "unrecognized strategy number: %d", strategy);
435 			res = false;		/* keep compiler quiet */
436 			break;
437 	}
438 
439 	PG_RETURN_BOOL(res);
440 }
441 
442 Datum
gtrgm_distance(PG_FUNCTION_ARGS)443 gtrgm_distance(PG_FUNCTION_ARGS)
444 {
445 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
446 	text	   *query = PG_GETARG_TEXT_P(1);
447 	StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);
448 
449 	/* Oid		subtype = PG_GETARG_OID(3); */
450 	bool	   *recheck = (bool *) PG_GETARG_POINTER(4);
451 	int			siglen = GET_SIGLEN();
452 	TRGM	   *key = (TRGM *) DatumGetPointer(entry->key);
453 	TRGM	   *qtrg;
454 	float8		res;
455 	Size		querysize = VARSIZE(query);
456 	char	   *cache = (char *) fcinfo->flinfo->fn_extra;
457 
458 	/*
459 	 * Cache the generated trigrams across multiple calls with the same query.
460 	 */
461 	if (cache == NULL ||
462 		VARSIZE(cache) != querysize ||
463 		memcmp(cache, query, querysize) != 0)
464 	{
465 		char	   *newcache;
466 
467 		qtrg = generate_trgm(VARDATA(query), querysize - VARHDRSZ);
468 
469 		newcache = MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
470 									  MAXALIGN(querysize) +
471 									  VARSIZE(qtrg));
472 
473 		memcpy(newcache, query, querysize);
474 		memcpy(newcache + MAXALIGN(querysize), qtrg, VARSIZE(qtrg));
475 
476 		if (cache)
477 			pfree(cache);
478 		fcinfo->flinfo->fn_extra = newcache;
479 		cache = newcache;
480 	}
481 
482 	qtrg = (TRGM *) (cache + MAXALIGN(querysize));
483 
484 	switch (strategy)
485 	{
486 		case DistanceStrategyNumber:
487 		case WordDistanceStrategyNumber:
488 		case StrictWordDistanceStrategyNumber:
489 			/* Only plain trigram distance is exact */
490 			*recheck = (strategy != DistanceStrategyNumber);
491 			if (GIST_LEAF(entry))
492 			{					/* all leafs contains orig trgm */
493 
494 				/*
495 				 * Prevent gcc optimizing the sml variable using volatile
496 				 * keyword. Otherwise res can differ from the
497 				 * word_similarity_dist_op() function.
498 				 */
499 				float4 volatile sml = cnt_sml(qtrg, key, *recheck);
500 
501 				res = 1.0 - sml;
502 			}
503 			else if (ISALLTRUE(key))
504 			{					/* all leafs contains orig trgm */
505 				res = 0.0;
506 			}
507 			else
508 			{					/* non-leaf contains signature */
509 				int32		count = cnt_sml_sign_common(qtrg, GETSIGN(key), siglen);
510 				int32		len = ARRNELEM(qtrg);
511 
512 				res = (len == 0) ? -1.0 : 1.0 - ((float8) count) / ((float8) len);
513 			}
514 			break;
515 		default:
516 			elog(ERROR, "unrecognized strategy number: %d", strategy);
517 			res = 0;			/* keep compiler quiet */
518 			break;
519 	}
520 
521 	PG_RETURN_FLOAT8(res);
522 }
523 
524 static int32
unionkey(BITVECP sbase,TRGM * add,int siglen)525 unionkey(BITVECP sbase, TRGM *add, int siglen)
526 {
527 	int32		i;
528 
529 	if (ISSIGNKEY(add))
530 	{
531 		BITVECP		sadd = GETSIGN(add);
532 
533 		if (ISALLTRUE(add))
534 			return 1;
535 
536 		LOOPBYTE(siglen)
537 			sbase[i] |= sadd[i];
538 	}
539 	else
540 	{
541 		trgm	   *ptr = GETARR(add);
542 		int32		tmp = 0;
543 
544 		for (i = 0; i < ARRNELEM(add); i++)
545 		{
546 			CPTRGM(((char *) &tmp), ptr + i);
547 			HASH(sbase, tmp, siglen);
548 		}
549 	}
550 	return 0;
551 }
552 
553 
554 Datum
gtrgm_union(PG_FUNCTION_ARGS)555 gtrgm_union(PG_FUNCTION_ARGS)
556 {
557 	GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
558 	int32		len = entryvec->n;
559 	int		   *size = (int *) PG_GETARG_POINTER(1);
560 	int			siglen = GET_SIGLEN();
561 	int32		i;
562 	TRGM	   *result = gtrgm_alloc(false, siglen, NULL);
563 	BITVECP		base = GETSIGN(result);
564 
565 	for (i = 0; i < len; i++)
566 	{
567 		if (unionkey(base, GETENTRY(entryvec, i), siglen))
568 		{
569 			result->flag = ALLISTRUE;
570 			SET_VARSIZE(result, CALCGTSIZE(ALLISTRUE, siglen));
571 			break;
572 		}
573 	}
574 
575 	*size = VARSIZE(result);
576 
577 	PG_RETURN_POINTER(result);
578 }
579 
580 Datum
gtrgm_same(PG_FUNCTION_ARGS)581 gtrgm_same(PG_FUNCTION_ARGS)
582 {
583 	TRGM	   *a = (TRGM *) PG_GETARG_POINTER(0);
584 	TRGM	   *b = (TRGM *) PG_GETARG_POINTER(1);
585 	bool	   *result = (bool *) PG_GETARG_POINTER(2);
586 	int			siglen = GET_SIGLEN();
587 
588 	if (ISSIGNKEY(a))
589 	{							/* then b also ISSIGNKEY */
590 		if (ISALLTRUE(a) && ISALLTRUE(b))
591 			*result = true;
592 		else if (ISALLTRUE(a))
593 			*result = false;
594 		else if (ISALLTRUE(b))
595 			*result = false;
596 		else
597 		{
598 			int32		i;
599 			BITVECP		sa = GETSIGN(a),
600 						sb = GETSIGN(b);
601 
602 			*result = true;
603 			LOOPBYTE(siglen)
604 			{
605 				if (sa[i] != sb[i])
606 				{
607 					*result = false;
608 					break;
609 				}
610 			}
611 		}
612 	}
613 	else
614 	{							/* a and b ISARRKEY */
615 		int32		lena = ARRNELEM(a),
616 					lenb = ARRNELEM(b);
617 
618 		if (lena != lenb)
619 			*result = false;
620 		else
621 		{
622 			trgm	   *ptra = GETARR(a),
623 					   *ptrb = GETARR(b);
624 			int32		i;
625 
626 			*result = true;
627 			for (i = 0; i < lena; i++)
628 				if (CMPTRGM(ptra + i, ptrb + i))
629 				{
630 					*result = false;
631 					break;
632 				}
633 		}
634 	}
635 
636 	PG_RETURN_POINTER(result);
637 }
638 
639 static int32
sizebitvec(BITVECP sign,int siglen)640 sizebitvec(BITVECP sign, int siglen)
641 {
642 	return pg_popcount(sign, siglen);
643 }
644 
645 static int
hemdistsign(BITVECP a,BITVECP b,int siglen)646 hemdistsign(BITVECP a, BITVECP b, int siglen)
647 {
648 	int			i,
649 				diff,
650 				dist = 0;
651 
652 	LOOPBYTE(siglen)
653 	{
654 		diff = (unsigned char) (a[i] ^ b[i]);
655 		/* Using the popcount functions here isn't likely to win */
656 		dist += pg_number_of_ones[diff];
657 	}
658 	return dist;
659 }
660 
661 static int
hemdist(TRGM * a,TRGM * b,int siglen)662 hemdist(TRGM *a, TRGM *b, int siglen)
663 {
664 	if (ISALLTRUE(a))
665 	{
666 		if (ISALLTRUE(b))
667 			return 0;
668 		else
669 			return SIGLENBIT(siglen) - sizebitvec(GETSIGN(b), siglen);
670 	}
671 	else if (ISALLTRUE(b))
672 		return SIGLENBIT(siglen) - sizebitvec(GETSIGN(a), siglen);
673 
674 	return hemdistsign(GETSIGN(a), GETSIGN(b), siglen);
675 }
676 
677 Datum
gtrgm_penalty(PG_FUNCTION_ARGS)678 gtrgm_penalty(PG_FUNCTION_ARGS)
679 {
680 	GISTENTRY  *origentry = (GISTENTRY *) PG_GETARG_POINTER(0); /* always ISSIGNKEY */
681 	GISTENTRY  *newentry = (GISTENTRY *) PG_GETARG_POINTER(1);
682 	float	   *penalty = (float *) PG_GETARG_POINTER(2);
683 	int			siglen = GET_SIGLEN();
684 	TRGM	   *origval = (TRGM *) DatumGetPointer(origentry->key);
685 	TRGM	   *newval = (TRGM *) DatumGetPointer(newentry->key);
686 	BITVECP		orig = GETSIGN(origval);
687 
688 	*penalty = 0.0;
689 
690 	if (ISARRKEY(newval))
691 	{
692 		char	   *cache = (char *) fcinfo->flinfo->fn_extra;
693 		TRGM	   *cachedVal = (TRGM *) (cache + MAXALIGN(siglen));
694 		Size		newvalsize = VARSIZE(newval);
695 		BITVECP		sign;
696 
697 		/*
698 		 * Cache the sign data across multiple calls with the same newval.
699 		 */
700 		if (cache == NULL ||
701 			VARSIZE(cachedVal) != newvalsize ||
702 			memcmp(cachedVal, newval, newvalsize) != 0)
703 		{
704 			char	   *newcache;
705 
706 			newcache = MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
707 										  MAXALIGN(siglen) +
708 										  newvalsize);
709 
710 			makesign((BITVECP) newcache, newval, siglen);
711 
712 			cachedVal = (TRGM *) (newcache + MAXALIGN(siglen));
713 			memcpy(cachedVal, newval, newvalsize);
714 
715 			if (cache)
716 				pfree(cache);
717 			fcinfo->flinfo->fn_extra = newcache;
718 			cache = newcache;
719 		}
720 
721 		sign = (BITVECP) cache;
722 
723 		if (ISALLTRUE(origval))
724 			*penalty = ((float) (SIGLENBIT(siglen) - sizebitvec(sign, siglen))) / (float) (SIGLENBIT(siglen) + 1);
725 		else
726 			*penalty = hemdistsign(sign, orig, siglen);
727 	}
728 	else
729 		*penalty = hemdist(origval, newval, siglen);
730 	PG_RETURN_POINTER(penalty);
731 }
732 
733 typedef struct
734 {
735 	bool		allistrue;
736 	BITVECP		sign;
737 } CACHESIGN;
738 
739 static void
fillcache(CACHESIGN * item,TRGM * key,BITVECP sign,int siglen)740 fillcache(CACHESIGN *item, TRGM *key, BITVECP sign, int siglen)
741 {
742 	item->allistrue = false;
743 	item->sign = sign;
744 	if (ISARRKEY(key))
745 		makesign(item->sign, key, siglen);
746 	else if (ISALLTRUE(key))
747 		item->allistrue = true;
748 	else
749 		memcpy((void *) item->sign, (void *) GETSIGN(key), siglen);
750 }
751 
/*
 * Balancing term used by gtrgm_picksplit(): minus the cube of (a - b),
 * scaled by c.  The cube is computed in the arguments' own type before
 * being converted to double.
 */
#define WISH_F(a,b,c) \
	(double) ( -(double) (((a)-(b)) * ((a)-(b)) * ((a)-(b))) * (c) )
753 typedef struct
754 {
755 	OffsetNumber pos;
756 	int32		cost;
757 } SPLITCOST;
758 
759 static int
comparecost(const void * a,const void * b)760 comparecost(const void *a, const void *b)
761 {
762 	if (((const SPLITCOST *) a)->cost == ((const SPLITCOST *) b)->cost)
763 		return 0;
764 	else
765 		return (((const SPLITCOST *) a)->cost > ((const SPLITCOST *) b)->cost) ? 1 : -1;
766 }
767 
768 
769 static int
hemdistcache(CACHESIGN * a,CACHESIGN * b,int siglen)770 hemdistcache(CACHESIGN *a, CACHESIGN *b, int siglen)
771 {
772 	if (a->allistrue)
773 	{
774 		if (b->allistrue)
775 			return 0;
776 		else
777 			return SIGLENBIT(siglen) - sizebitvec(b->sign, siglen);
778 	}
779 	else if (b->allistrue)
780 		return SIGLENBIT(siglen) - sizebitvec(a->sign, siglen);
781 
782 	return hemdistsign(a->sign, b->sign, siglen);
783 }
784 
785 Datum
gtrgm_picksplit(PG_FUNCTION_ARGS)786 gtrgm_picksplit(PG_FUNCTION_ARGS)
787 {
788 	GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
789 	OffsetNumber maxoff = entryvec->n - 1;
790 	GIST_SPLITVEC *v = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
791 	int			siglen = GET_SIGLEN();
792 	OffsetNumber k,
793 				j;
794 	TRGM	   *datum_l,
795 			   *datum_r;
796 	BITVECP		union_l,
797 				union_r;
798 	int32		size_alpha,
799 				size_beta;
800 	int32		size_waste,
801 				waste = -1;
802 	int32		nbytes;
803 	OffsetNumber seed_1 = 0,
804 				seed_2 = 0;
805 	OffsetNumber *left,
806 			   *right;
807 	BITVECP		ptr;
808 	int			i;
809 	CACHESIGN  *cache;
810 	char	   *cache_sign;
811 	SPLITCOST  *costvector;
812 
813 	/* cache the sign data for each existing item */
814 	cache = (CACHESIGN *) palloc(sizeof(CACHESIGN) * (maxoff + 1));
815 	cache_sign = palloc(siglen * (maxoff + 1));
816 
817 	for (k = FirstOffsetNumber; k <= maxoff; k = OffsetNumberNext(k))
818 		fillcache(&cache[k], GETENTRY(entryvec, k), &cache_sign[siglen * k],
819 				  siglen);
820 
821 	/* now find the two furthest-apart items */
822 	for (k = FirstOffsetNumber; k < maxoff; k = OffsetNumberNext(k))
823 	{
824 		for (j = OffsetNumberNext(k); j <= maxoff; j = OffsetNumberNext(j))
825 		{
826 			size_waste = hemdistcache(&(cache[j]), &(cache[k]), siglen);
827 			if (size_waste > waste)
828 			{
829 				waste = size_waste;
830 				seed_1 = k;
831 				seed_2 = j;
832 			}
833 		}
834 	}
835 
836 	/* just in case we didn't make a selection ... */
837 	if (seed_1 == 0 || seed_2 == 0)
838 	{
839 		seed_1 = 1;
840 		seed_2 = 2;
841 	}
842 
843 	/* initialize the result vectors */
844 	nbytes = maxoff * sizeof(OffsetNumber);
845 	v->spl_left = left = (OffsetNumber *) palloc(nbytes);
846 	v->spl_right = right = (OffsetNumber *) palloc(nbytes);
847 	v->spl_nleft = 0;
848 	v->spl_nright = 0;
849 
850 	/* form initial .. */
851 	datum_l = gtrgm_alloc(cache[seed_1].allistrue, siglen, cache[seed_1].sign);
852 	datum_r = gtrgm_alloc(cache[seed_2].allistrue, siglen, cache[seed_2].sign);
853 
854 	union_l = GETSIGN(datum_l);
855 	union_r = GETSIGN(datum_r);
856 
857 	/* sort before ... */
858 	costvector = (SPLITCOST *) palloc(sizeof(SPLITCOST) * maxoff);
859 	for (j = FirstOffsetNumber; j <= maxoff; j = OffsetNumberNext(j))
860 	{
861 		costvector[j - 1].pos = j;
862 		size_alpha = hemdistcache(&(cache[seed_1]), &(cache[j]), siglen);
863 		size_beta = hemdistcache(&(cache[seed_2]), &(cache[j]), siglen);
864 		costvector[j - 1].cost = abs(size_alpha - size_beta);
865 	}
866 	qsort((void *) costvector, maxoff, sizeof(SPLITCOST), comparecost);
867 
868 	for (k = 0; k < maxoff; k++)
869 	{
870 		j = costvector[k].pos;
871 		if (j == seed_1)
872 		{
873 			*left++ = j;
874 			v->spl_nleft++;
875 			continue;
876 		}
877 		else if (j == seed_2)
878 		{
879 			*right++ = j;
880 			v->spl_nright++;
881 			continue;
882 		}
883 
884 		if (ISALLTRUE(datum_l) || cache[j].allistrue)
885 		{
886 			if (ISALLTRUE(datum_l) && cache[j].allistrue)
887 				size_alpha = 0;
888 			else
889 				size_alpha = SIGLENBIT(siglen) -
890 					sizebitvec((cache[j].allistrue) ? GETSIGN(datum_l) :
891 							   GETSIGN(cache[j].sign),
892 							   siglen);
893 		}
894 		else
895 			size_alpha = hemdistsign(cache[j].sign, GETSIGN(datum_l), siglen);
896 
897 		if (ISALLTRUE(datum_r) || cache[j].allistrue)
898 		{
899 			if (ISALLTRUE(datum_r) && cache[j].allistrue)
900 				size_beta = 0;
901 			else
902 				size_beta = SIGLENBIT(siglen) -
903 					sizebitvec((cache[j].allistrue) ? GETSIGN(datum_r) :
904 							   GETSIGN(cache[j].sign),
905 							   siglen);
906 		}
907 		else
908 			size_beta = hemdistsign(cache[j].sign, GETSIGN(datum_r), siglen);
909 
910 		if (size_alpha < size_beta + WISH_F(v->spl_nleft, v->spl_nright, 0.1))
911 		{
912 			if (ISALLTRUE(datum_l) || cache[j].allistrue)
913 			{
914 				if (!ISALLTRUE(datum_l))
915 					MemSet((void *) GETSIGN(datum_l), 0xff, siglen);
916 			}
917 			else
918 			{
919 				ptr = cache[j].sign;
920 				LOOPBYTE(siglen)
921 					union_l[i] |= ptr[i];
922 			}
923 			*left++ = j;
924 			v->spl_nleft++;
925 		}
926 		else
927 		{
928 			if (ISALLTRUE(datum_r) || cache[j].allistrue)
929 			{
930 				if (!ISALLTRUE(datum_r))
931 					MemSet((void *) GETSIGN(datum_r), 0xff, siglen);
932 			}
933 			else
934 			{
935 				ptr = cache[j].sign;
936 				LOOPBYTE(siglen)
937 					union_r[i] |= ptr[i];
938 			}
939 			*right++ = j;
940 			v->spl_nright++;
941 		}
942 	}
943 
944 	v->spl_ldatum = PointerGetDatum(datum_l);
945 	v->spl_rdatum = PointerGetDatum(datum_r);
946 
947 	PG_RETURN_POINTER(v);
948 }
949 
950 Datum
gtrgm_options(PG_FUNCTION_ARGS)951 gtrgm_options(PG_FUNCTION_ARGS)
952 {
953 	local_relopts *relopts = (local_relopts *) PG_GETARG_POINTER(0);
954 
955 	init_local_reloptions(relopts, sizeof(TrgmGistOptions));
956 	add_local_int_reloption(relopts, "siglen",
957 							"signature length in bytes",
958 							SIGLEN_DEFAULT, 1, SIGLEN_MAX,
959 							offsetof(TrgmGistOptions, siglen));
960 
961 	PG_RETURN_VOID();
962 }
963