1 /*
2  * contrib/pg_trgm/trgm_gist.c
3  */
4 #include "postgres.h"
5 
6 #include "trgm.h"
7 
8 #include "access/stratnum.h"
9 #include "fmgr.h"
10 
11 
12 typedef struct
13 {
14 	/* most recent inputs to gtrgm_consistent */
15 	StrategyNumber strategy;
16 	text	   *query;
17 	/* extracted trigrams for query */
18 	TRGM	   *trigrams;
19 	/* if a regex operator, the extracted graph */
20 	TrgmPackedGraph *graph;
21 
22 	/*
23 	 * The "query" and "trigrams" are stored in the same palloc block as this
24 	 * cache struct, at MAXALIGN'ed offsets.  The graph however isn't.
25 	 */
26 } gtrgm_consistent_cache;
27 
/* Key datum of the pos'th element of an entry vector, viewed as a TRGM */
#define GETENTRY(vec, pos) \
	((TRGM *) DatumGetPointer((vec)->vector[(pos)].key))
29 
30 
/* fmgr V1 declarations for the SQL-callable functions defined in this file */

/* gtrgm_in / gtrgm_out: both only raise "not implemented" */
PG_FUNCTION_INFO_V1(gtrgm_in);
PG_FUNCTION_INFO_V1(gtrgm_out);

/* key conversion */
PG_FUNCTION_INFO_V1(gtrgm_compress);
PG_FUNCTION_INFO_V1(gtrgm_decompress);

/* searching */
PG_FUNCTION_INFO_V1(gtrgm_consistent);
PG_FUNCTION_INFO_V1(gtrgm_distance);

/* tree maintenance */
PG_FUNCTION_INFO_V1(gtrgm_union);
PG_FUNCTION_INFO_V1(gtrgm_same);
PG_FUNCTION_INFO_V1(gtrgm_penalty);
PG_FUNCTION_INFO_V1(gtrgm_picksplit);
41 
42 /* Number of one-bits in an unsigned byte */
43 static const uint8 number_of_ones[256] = {
44 	0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
45 	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
46 	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
47 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
48 	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
49 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
50 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
51 	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
52 	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
53 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
54 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
55 	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
56 	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
57 	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
58 	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
59 	4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
60 };
61 
62 
63 Datum
gtrgm_in(PG_FUNCTION_ARGS)64 gtrgm_in(PG_FUNCTION_ARGS)
65 {
66 	elog(ERROR, "not implemented");
67 	PG_RETURN_DATUM(0);
68 }
69 
70 Datum
gtrgm_out(PG_FUNCTION_ARGS)71 gtrgm_out(PG_FUNCTION_ARGS)
72 {
73 	elog(ERROR, "not implemented");
74 	PG_RETURN_DATUM(0);
75 }
76 
77 static void
makesign(BITVECP sign,TRGM * a)78 makesign(BITVECP sign, TRGM *a)
79 {
80 	int32		k,
81 				len = ARRNELEM(a);
82 	trgm	   *ptr = GETARR(a);
83 	int32		tmp = 0;
84 
85 	MemSet((void *) sign, 0, sizeof(BITVEC));
86 	SETBIT(sign, SIGLENBIT);	/* set last unused bit */
87 	for (k = 0; k < len; k++)
88 	{
89 		CPTRGM(((char *) &tmp), ptr + k);
90 		HASH(sign, tmp);
91 	}
92 }
93 
94 Datum
gtrgm_compress(PG_FUNCTION_ARGS)95 gtrgm_compress(PG_FUNCTION_ARGS)
96 {
97 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
98 	GISTENTRY  *retval = entry;
99 
100 	if (entry->leafkey)
101 	{							/* trgm */
102 		TRGM	   *res;
103 		text	   *val = DatumGetTextPP(entry->key);
104 
105 		res = generate_trgm(VARDATA_ANY(val), VARSIZE_ANY_EXHDR(val));
106 		retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
107 		gistentryinit(*retval, PointerGetDatum(res),
108 					  entry->rel, entry->page,
109 					  entry->offset, FALSE);
110 	}
111 	else if (ISSIGNKEY(DatumGetPointer(entry->key)) &&
112 			 !ISALLTRUE(DatumGetPointer(entry->key)))
113 	{
114 		int32		i,
115 					len;
116 		TRGM	   *res;
117 		BITVECP		sign = GETSIGN(DatumGetPointer(entry->key));
118 
119 		LOOPBYTE
120 		{
121 			if ((sign[i] & 0xff) != 0xff)
122 				PG_RETURN_POINTER(retval);
123 		}
124 
125 		len = CALCGTSIZE(SIGNKEY | ALLISTRUE, 0);
126 		res = (TRGM *) palloc(len);
127 		SET_VARSIZE(res, len);
128 		res->flag = SIGNKEY | ALLISTRUE;
129 
130 		retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
131 		gistentryinit(*retval, PointerGetDatum(res),
132 					  entry->rel, entry->page,
133 					  entry->offset, FALSE);
134 	}
135 	PG_RETURN_POINTER(retval);
136 }
137 
138 Datum
gtrgm_decompress(PG_FUNCTION_ARGS)139 gtrgm_decompress(PG_FUNCTION_ARGS)
140 {
141 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
142 	GISTENTRY  *retval;
143 	text	   *key;
144 
145 	key = DatumGetTextPP(entry->key);
146 
147 	if (key != (text *) DatumGetPointer(entry->key))
148 	{
149 		/* need to pass back the decompressed item */
150 		retval = palloc(sizeof(GISTENTRY));
151 		gistentryinit(*retval, PointerGetDatum(key),
152 					  entry->rel, entry->page, entry->offset, entry->leafkey);
153 		PG_RETURN_POINTER(retval);
154 	}
155 	else
156 	{
157 		/* we can return the entry as-is */
158 		PG_RETURN_POINTER(entry);
159 	}
160 }
161 
162 static int32
cnt_sml_sign_common(TRGM * qtrg,BITVECP sign)163 cnt_sml_sign_common(TRGM *qtrg, BITVECP sign)
164 {
165 	int32		count = 0;
166 	int32		k,
167 				len = ARRNELEM(qtrg);
168 	trgm	   *ptr = GETARR(qtrg);
169 	int32		tmp = 0;
170 
171 	for (k = 0; k < len; k++)
172 	{
173 		CPTRGM(((char *) &tmp), ptr + k);
174 		count += GETBIT(sign, HASHVAL(tmp));
175 	}
176 
177 	return count;
178 }
179 
180 Datum
gtrgm_consistent(PG_FUNCTION_ARGS)181 gtrgm_consistent(PG_FUNCTION_ARGS)
182 {
183 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
184 	text	   *query = PG_GETARG_TEXT_P(1);
185 	StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);
186 
187 	/* Oid		subtype = PG_GETARG_OID(3); */
188 	bool	   *recheck = (bool *) PG_GETARG_POINTER(4);
189 	TRGM	   *key = (TRGM *) DatumGetPointer(entry->key);
190 	TRGM	   *qtrg;
191 	bool		res;
192 	Size		querysize = VARSIZE(query);
193 	gtrgm_consistent_cache *cache;
194 	double		nlimit;
195 
196 	/*
197 	 * We keep the extracted trigrams in cache, because trigram extraction is
198 	 * relatively CPU-expensive.  When trying to reuse a cached value, check
199 	 * strategy number not just query itself, because trigram extraction
200 	 * depends on strategy.
201 	 *
202 	 * The cached structure is a single palloc chunk containing the
203 	 * gtrgm_consistent_cache header, then the input query (4-byte length
204 	 * word, uncompressed, starting at a MAXALIGN boundary), then the TRGM
205 	 * value (also starting at a MAXALIGN boundary).  However we don't try to
206 	 * include the regex graph (if any) in that struct.  (XXX currently, this
207 	 * approach can leak regex graphs across index rescans.  Not clear if
208 	 * that's worth fixing.)
209 	 */
210 	cache = (gtrgm_consistent_cache *) fcinfo->flinfo->fn_extra;
211 	if (cache == NULL ||
212 		cache->strategy != strategy ||
213 		VARSIZE(cache->query) != querysize ||
214 		memcmp((char *) cache->query, (char *) query, querysize) != 0)
215 	{
216 		gtrgm_consistent_cache *newcache;
217 		TrgmPackedGraph *graph = NULL;
218 		Size		qtrgsize;
219 
220 		switch (strategy)
221 		{
222 			case SimilarityStrategyNumber:
223 			case WordSimilarityStrategyNumber:
224 				qtrg = generate_trgm(VARDATA(query),
225 									 querysize - VARHDRSZ);
226 				break;
227 			case ILikeStrategyNumber:
228 #ifndef IGNORECASE
229 				elog(ERROR, "cannot handle ~~* with case-sensitive trigrams");
230 #endif
231 				/* FALL THRU */
232 			case LikeStrategyNumber:
233 				qtrg = generate_wildcard_trgm(VARDATA(query),
234 											  querysize - VARHDRSZ);
235 				break;
236 			case RegExpICaseStrategyNumber:
237 #ifndef IGNORECASE
238 				elog(ERROR, "cannot handle ~* with case-sensitive trigrams");
239 #endif
240 				/* FALL THRU */
241 			case RegExpStrategyNumber:
242 				qtrg = createTrgmNFA(query, PG_GET_COLLATION(),
243 									 &graph, fcinfo->flinfo->fn_mcxt);
244 				/* just in case an empty array is returned ... */
245 				if (qtrg && ARRNELEM(qtrg) <= 0)
246 				{
247 					pfree(qtrg);
248 					qtrg = NULL;
249 				}
250 				break;
251 			default:
252 				elog(ERROR, "unrecognized strategy number: %d", strategy);
253 				qtrg = NULL;	/* keep compiler quiet */
254 				break;
255 		}
256 
257 		qtrgsize = qtrg ? VARSIZE(qtrg) : 0;
258 
259 		newcache = (gtrgm_consistent_cache *)
260 			MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
261 							   MAXALIGN(sizeof(gtrgm_consistent_cache)) +
262 							   MAXALIGN(querysize) +
263 							   qtrgsize);
264 
265 		newcache->strategy = strategy;
266 		newcache->query = (text *)
267 			((char *) newcache + MAXALIGN(sizeof(gtrgm_consistent_cache)));
268 		memcpy((char *) newcache->query, (char *) query, querysize);
269 		if (qtrg)
270 		{
271 			newcache->trigrams = (TRGM *)
272 				((char *) newcache->query + MAXALIGN(querysize));
273 			memcpy((char *) newcache->trigrams, (char *) qtrg, qtrgsize);
274 			/* release qtrg in case it was made in fn_mcxt */
275 			pfree(qtrg);
276 		}
277 		else
278 			newcache->trigrams = NULL;
279 		newcache->graph = graph;
280 
281 		if (cache)
282 			pfree(cache);
283 		fcinfo->flinfo->fn_extra = (void *) newcache;
284 		cache = newcache;
285 	}
286 
287 	qtrg = cache->trigrams;
288 
289 	switch (strategy)
290 	{
291 		case SimilarityStrategyNumber:
292 		case WordSimilarityStrategyNumber:
293 			/* Similarity search is exact. Word similarity search is inexact */
294 			*recheck = (strategy == WordSimilarityStrategyNumber);
295 			nlimit = (strategy == SimilarityStrategyNumber) ?
296 				similarity_threshold : word_similarity_threshold;
297 
298 			if (GIST_LEAF(entry))
299 			{					/* all leafs contains orig trgm */
300 				double		tmpsml = cnt_sml(qtrg, key, *recheck);
301 
302 				res = (tmpsml >= nlimit);
303 			}
304 			else if (ISALLTRUE(key))
305 			{					/* non-leaf contains signature */
306 				res = true;
307 			}
308 			else
309 			{					/* non-leaf contains signature */
310 				int32		count = cnt_sml_sign_common(qtrg, GETSIGN(key));
311 				int32		len = ARRNELEM(qtrg);
312 
313 				if (len == 0)
314 					res = false;
315 				else
316 					res = (((((float8) count) / ((float8) len))) >= nlimit);
317 			}
318 			break;
319 		case ILikeStrategyNumber:
320 #ifndef IGNORECASE
321 			elog(ERROR, "cannot handle ~~* with case-sensitive trigrams");
322 #endif
323 			/* FALL THRU */
324 		case LikeStrategyNumber:
325 			/* Wildcard search is inexact */
326 			*recheck = true;
327 
328 			/*
329 			 * Check if all the extracted trigrams can be present in child
330 			 * nodes.
331 			 */
332 			if (GIST_LEAF(entry))
333 			{					/* all leafs contains orig trgm */
334 				res = trgm_contained_by(qtrg, key);
335 			}
336 			else if (ISALLTRUE(key))
337 			{					/* non-leaf contains signature */
338 				res = true;
339 			}
340 			else
341 			{					/* non-leaf contains signature */
342 				int32		k,
343 							tmp = 0,
344 							len = ARRNELEM(qtrg);
345 				trgm	   *ptr = GETARR(qtrg);
346 				BITVECP		sign = GETSIGN(key);
347 
348 				res = true;
349 				for (k = 0; k < len; k++)
350 				{
351 					CPTRGM(((char *) &tmp), ptr + k);
352 					if (!GETBIT(sign, HASHVAL(tmp)))
353 					{
354 						res = false;
355 						break;
356 					}
357 				}
358 			}
359 			break;
360 		case RegExpICaseStrategyNumber:
361 #ifndef IGNORECASE
362 			elog(ERROR, "cannot handle ~* with case-sensitive trigrams");
363 #endif
364 			/* FALL THRU */
365 		case RegExpStrategyNumber:
366 			/* Regexp search is inexact */
367 			*recheck = true;
368 
369 			/* Check regex match as much as we can with available info */
370 			if (qtrg)
371 			{
372 				if (GIST_LEAF(entry))
373 				{				/* all leafs contains orig trgm */
374 					bool	   *check;
375 
376 					check = trgm_presence_map(qtrg, key);
377 					res = trigramsMatchGraph(cache->graph, check);
378 					pfree(check);
379 				}
380 				else if (ISALLTRUE(key))
381 				{				/* non-leaf contains signature */
382 					res = true;
383 				}
384 				else
385 				{				/* non-leaf contains signature */
386 					int32		k,
387 								tmp = 0,
388 								len = ARRNELEM(qtrg);
389 					trgm	   *ptr = GETARR(qtrg);
390 					BITVECP		sign = GETSIGN(key);
391 					bool	   *check;
392 
393 					/*
394 					 * GETBIT() tests may give false positives, due to limited
395 					 * size of the sign array.  But since trigramsMatchGraph()
396 					 * implements a monotone boolean function, false positives
397 					 * in the check array can't lead to false negative answer.
398 					 * So we can apply trigramsMatchGraph despite uncertainty,
399 					 * and that usefully improves the quality of the search.
400 					 */
401 					check = (bool *) palloc(len * sizeof(bool));
402 					for (k = 0; k < len; k++)
403 					{
404 						CPTRGM(((char *) &tmp), ptr + k);
405 						check[k] = GETBIT(sign, HASHVAL(tmp));
406 					}
407 					res = trigramsMatchGraph(cache->graph, check);
408 					pfree(check);
409 				}
410 			}
411 			else
412 			{
413 				/* trigram-free query must be rechecked everywhere */
414 				res = true;
415 			}
416 			break;
417 		default:
418 			elog(ERROR, "unrecognized strategy number: %d", strategy);
419 			res = false;		/* keep compiler quiet */
420 			break;
421 	}
422 
423 	PG_RETURN_BOOL(res);
424 }
425 
426 Datum
gtrgm_distance(PG_FUNCTION_ARGS)427 gtrgm_distance(PG_FUNCTION_ARGS)
428 {
429 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
430 	text	   *query = PG_GETARG_TEXT_P(1);
431 	StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);
432 
433 	/* Oid		subtype = PG_GETARG_OID(3); */
434 	bool	   *recheck = (bool *) PG_GETARG_POINTER(4);
435 	TRGM	   *key = (TRGM *) DatumGetPointer(entry->key);
436 	TRGM	   *qtrg;
437 	float8		res;
438 	Size		querysize = VARSIZE(query);
439 	char	   *cache = (char *) fcinfo->flinfo->fn_extra;
440 
441 	/*
442 	 * Cache the generated trigrams across multiple calls with the same query.
443 	 */
444 	if (cache == NULL ||
445 		VARSIZE(cache) != querysize ||
446 		memcmp(cache, query, querysize) != 0)
447 	{
448 		char	   *newcache;
449 
450 		qtrg = generate_trgm(VARDATA(query), querysize - VARHDRSZ);
451 
452 		newcache = MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
453 									  MAXALIGN(querysize) +
454 									  VARSIZE(qtrg));
455 
456 		memcpy(newcache, query, querysize);
457 		memcpy(newcache + MAXALIGN(querysize), qtrg, VARSIZE(qtrg));
458 
459 		if (cache)
460 			pfree(cache);
461 		fcinfo->flinfo->fn_extra = newcache;
462 		cache = newcache;
463 	}
464 
465 	qtrg = (TRGM *) (cache + MAXALIGN(querysize));
466 
467 	switch (strategy)
468 	{
469 		case DistanceStrategyNumber:
470 		case WordDistanceStrategyNumber:
471 			*recheck = strategy == WordDistanceStrategyNumber;
472 			if (GIST_LEAF(entry))
473 			{					/* all leafs contains orig trgm */
474 
475 				/*
476 				 * Prevent gcc optimizing the sml variable using volatile
477 				 * keyword. Otherwise res can differ from the
478 				 * word_similarity_dist_op() function.
479 				 */
480 				float4 volatile sml = cnt_sml(qtrg, key, *recheck);
481 
482 				res = 1.0 - sml;
483 			}
484 			else if (ISALLTRUE(key))
485 			{					/* all leafs contains orig trgm */
486 				res = 0.0;
487 			}
488 			else
489 			{					/* non-leaf contains signature */
490 				int32		count = cnt_sml_sign_common(qtrg, GETSIGN(key));
491 				int32		len = ARRNELEM(qtrg);
492 
493 				res = (len == 0) ? -1.0 : 1.0 - ((float8) count) / ((float8) len);
494 			}
495 			break;
496 		default:
497 			elog(ERROR, "unrecognized strategy number: %d", strategy);
498 			res = 0;			/* keep compiler quiet */
499 			break;
500 	}
501 
502 	PG_RETURN_FLOAT8(res);
503 }
504 
505 static int32
unionkey(BITVECP sbase,TRGM * add)506 unionkey(BITVECP sbase, TRGM *add)
507 {
508 	int32		i;
509 
510 	if (ISSIGNKEY(add))
511 	{
512 		BITVECP		sadd = GETSIGN(add);
513 
514 		if (ISALLTRUE(add))
515 			return 1;
516 
517 		LOOPBYTE
518 			sbase[i] |= sadd[i];
519 	}
520 	else
521 	{
522 		trgm	   *ptr = GETARR(add);
523 		int32		tmp = 0;
524 
525 		for (i = 0; i < ARRNELEM(add); i++)
526 		{
527 			CPTRGM(((char *) &tmp), ptr + i);
528 			HASH(sbase, tmp);
529 		}
530 	}
531 	return 0;
532 }
533 
534 
535 Datum
gtrgm_union(PG_FUNCTION_ARGS)536 gtrgm_union(PG_FUNCTION_ARGS)
537 {
538 	GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
539 	int32		len = entryvec->n;
540 	int		   *size = (int *) PG_GETARG_POINTER(1);
541 	BITVEC		base;
542 	int32		i;
543 	int32		flag = 0;
544 	TRGM	   *result;
545 
546 	MemSet((void *) base, 0, sizeof(BITVEC));
547 	for (i = 0; i < len; i++)
548 	{
549 		if (unionkey(base, GETENTRY(entryvec, i)))
550 		{
551 			flag = ALLISTRUE;
552 			break;
553 		}
554 	}
555 
556 	flag |= SIGNKEY;
557 	len = CALCGTSIZE(flag, 0);
558 	result = (TRGM *) palloc(len);
559 	SET_VARSIZE(result, len);
560 	result->flag = flag;
561 	if (!ISALLTRUE(result))
562 		memcpy((void *) GETSIGN(result), (void *) base, sizeof(BITVEC));
563 	*size = len;
564 
565 	PG_RETURN_POINTER(result);
566 }
567 
568 Datum
gtrgm_same(PG_FUNCTION_ARGS)569 gtrgm_same(PG_FUNCTION_ARGS)
570 {
571 	TRGM	   *a = (TRGM *) PG_GETARG_POINTER(0);
572 	TRGM	   *b = (TRGM *) PG_GETARG_POINTER(1);
573 	bool	   *result = (bool *) PG_GETARG_POINTER(2);
574 
575 	if (ISSIGNKEY(a))
576 	{							/* then b also ISSIGNKEY */
577 		if (ISALLTRUE(a) && ISALLTRUE(b))
578 			*result = true;
579 		else if (ISALLTRUE(a))
580 			*result = false;
581 		else if (ISALLTRUE(b))
582 			*result = false;
583 		else
584 		{
585 			int32		i;
586 			BITVECP		sa = GETSIGN(a),
587 						sb = GETSIGN(b);
588 
589 			*result = true;
590 			LOOPBYTE
591 			{
592 				if (sa[i] != sb[i])
593 				{
594 					*result = false;
595 					break;
596 				}
597 			}
598 		}
599 	}
600 	else
601 	{							/* a and b ISARRKEY */
602 		int32		lena = ARRNELEM(a),
603 					lenb = ARRNELEM(b);
604 
605 		if (lena != lenb)
606 			*result = false;
607 		else
608 		{
609 			trgm	   *ptra = GETARR(a),
610 					   *ptrb = GETARR(b);
611 			int32		i;
612 
613 			*result = true;
614 			for (i = 0; i < lena; i++)
615 				if (CMPTRGM(ptra + i, ptrb + i))
616 				{
617 					*result = false;
618 					break;
619 				}
620 		}
621 	}
622 
623 	PG_RETURN_POINTER(result);
624 }
625 
626 static int32
sizebitvec(BITVECP sign)627 sizebitvec(BITVECP sign)
628 {
629 	int32		size = 0,
630 				i;
631 
632 	LOOPBYTE
633 		size += number_of_ones[(unsigned char) sign[i]];
634 	return size;
635 }
636 
637 static int
hemdistsign(BITVECP a,BITVECP b)638 hemdistsign(BITVECP a, BITVECP b)
639 {
640 	int			i,
641 				diff,
642 				dist = 0;
643 
644 	LOOPBYTE
645 	{
646 		diff = (unsigned char) (a[i] ^ b[i]);
647 		dist += number_of_ones[diff];
648 	}
649 	return dist;
650 }
651 
652 static int
hemdist(TRGM * a,TRGM * b)653 hemdist(TRGM *a, TRGM *b)
654 {
655 	if (ISALLTRUE(a))
656 	{
657 		if (ISALLTRUE(b))
658 			return 0;
659 		else
660 			return SIGLENBIT - sizebitvec(GETSIGN(b));
661 	}
662 	else if (ISALLTRUE(b))
663 		return SIGLENBIT - sizebitvec(GETSIGN(a));
664 
665 	return hemdistsign(GETSIGN(a), GETSIGN(b));
666 }
667 
668 Datum
gtrgm_penalty(PG_FUNCTION_ARGS)669 gtrgm_penalty(PG_FUNCTION_ARGS)
670 {
671 	GISTENTRY  *origentry = (GISTENTRY *) PG_GETARG_POINTER(0); /* always ISSIGNKEY */
672 	GISTENTRY  *newentry = (GISTENTRY *) PG_GETARG_POINTER(1);
673 	float	   *penalty = (float *) PG_GETARG_POINTER(2);
674 	TRGM	   *origval = (TRGM *) DatumGetPointer(origentry->key);
675 	TRGM	   *newval = (TRGM *) DatumGetPointer(newentry->key);
676 	BITVECP		orig = GETSIGN(origval);
677 
678 	*penalty = 0.0;
679 
680 	if (ISARRKEY(newval))
681 	{
682 		char	   *cache = (char *) fcinfo->flinfo->fn_extra;
683 		TRGM	   *cachedVal = (TRGM *) (cache + MAXALIGN(sizeof(BITVEC)));
684 		Size		newvalsize = VARSIZE(newval);
685 		BITVECP		sign;
686 
687 		/*
688 		 * Cache the sign data across multiple calls with the same newval.
689 		 */
690 		if (cache == NULL ||
691 			VARSIZE(cachedVal) != newvalsize ||
692 			memcmp(cachedVal, newval, newvalsize) != 0)
693 		{
694 			char	   *newcache;
695 
696 			newcache = MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
697 										  MAXALIGN(sizeof(BITVEC)) +
698 										  newvalsize);
699 
700 			makesign((BITVECP) newcache, newval);
701 
702 			cachedVal = (TRGM *) (newcache + MAXALIGN(sizeof(BITVEC)));
703 			memcpy(cachedVal, newval, newvalsize);
704 
705 			if (cache)
706 				pfree(cache);
707 			fcinfo->flinfo->fn_extra = newcache;
708 			cache = newcache;
709 		}
710 
711 		sign = (BITVECP) cache;
712 
713 		if (ISALLTRUE(origval))
714 			*penalty = ((float) (SIGLENBIT - sizebitvec(sign))) / (float) (SIGLENBIT + 1);
715 		else
716 			*penalty = hemdistsign(sign, orig);
717 	}
718 	else
719 		*penalty = hemdist(origval, newval);
720 	PG_RETURN_POINTER(penalty);
721 }
722 
723 typedef struct
724 {
725 	bool		allistrue;
726 	BITVEC		sign;
727 } CACHESIGN;
728 
729 static void
fillcache(CACHESIGN * item,TRGM * key)730 fillcache(CACHESIGN *item, TRGM *key)
731 {
732 	item->allistrue = false;
733 	if (ISARRKEY(key))
734 		makesign(item->sign, key);
735 	else if (ISALLTRUE(key))
736 		item->allistrue = true;
737 	else
738 		memcpy((void *) item->sign, (void *) GETSIGN(key), sizeof(BITVEC));
739 }
740 
/* Negated cube of (a - b), converted to double and scaled by c */
#define WISH_F(a, b, c) \
	((double) (-(double) (((a) - (b)) * ((a) - (b)) * ((a) - (b))) * (c)))
742 typedef struct
743 {
744 	OffsetNumber pos;
745 	int32		cost;
746 } SPLITCOST;
747 
748 static int
comparecost(const void * a,const void * b)749 comparecost(const void *a, const void *b)
750 {
751 	if (((const SPLITCOST *) a)->cost == ((const SPLITCOST *) b)->cost)
752 		return 0;
753 	else
754 		return (((const SPLITCOST *) a)->cost > ((const SPLITCOST *) b)->cost) ? 1 : -1;
755 }
756 
757 
758 static int
hemdistcache(CACHESIGN * a,CACHESIGN * b)759 hemdistcache(CACHESIGN *a, CACHESIGN *b)
760 {
761 	if (a->allistrue)
762 	{
763 		if (b->allistrue)
764 			return 0;
765 		else
766 			return SIGLENBIT - sizebitvec(b->sign);
767 	}
768 	else if (b->allistrue)
769 		return SIGLENBIT - sizebitvec(a->sign);
770 
771 	return hemdistsign(a->sign, b->sign);
772 }
773 
774 Datum
gtrgm_picksplit(PG_FUNCTION_ARGS)775 gtrgm_picksplit(PG_FUNCTION_ARGS)
776 {
777 	GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
778 	OffsetNumber maxoff = entryvec->n - 1;
779 	GIST_SPLITVEC *v = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
780 	OffsetNumber k,
781 				j;
782 	TRGM	   *datum_l,
783 			   *datum_r;
784 	BITVECP		union_l,
785 				union_r;
786 	int32		size_alpha,
787 				size_beta;
788 	int32		size_waste,
789 				waste = -1;
790 	int32		nbytes;
791 	OffsetNumber seed_1 = 0,
792 				seed_2 = 0;
793 	OffsetNumber *left,
794 			   *right;
795 	BITVECP		ptr;
796 	int			i;
797 	CACHESIGN  *cache;
798 	SPLITCOST  *costvector;
799 
800 	/* cache the sign data for each existing item */
801 	cache = (CACHESIGN *) palloc(sizeof(CACHESIGN) * (maxoff + 1));
802 	for (k = FirstOffsetNumber; k <= maxoff; k = OffsetNumberNext(k))
803 		fillcache(&cache[k], GETENTRY(entryvec, k));
804 
805 	/* now find the two furthest-apart items */
806 	for (k = FirstOffsetNumber; k < maxoff; k = OffsetNumberNext(k))
807 	{
808 		for (j = OffsetNumberNext(k); j <= maxoff; j = OffsetNumberNext(j))
809 		{
810 			size_waste = hemdistcache(&(cache[j]), &(cache[k]));
811 			if (size_waste > waste)
812 			{
813 				waste = size_waste;
814 				seed_1 = k;
815 				seed_2 = j;
816 			}
817 		}
818 	}
819 
820 	/* just in case we didn't make a selection ... */
821 	if (seed_1 == 0 || seed_2 == 0)
822 	{
823 		seed_1 = 1;
824 		seed_2 = 2;
825 	}
826 
827 	/* initialize the result vectors */
828 	nbytes = maxoff * sizeof(OffsetNumber);
829 	v->spl_left = left = (OffsetNumber *) palloc(nbytes);
830 	v->spl_right = right = (OffsetNumber *) palloc(nbytes);
831 	v->spl_nleft = 0;
832 	v->spl_nright = 0;
833 
834 	/* form initial .. */
835 	if (cache[seed_1].allistrue)
836 	{
837 		datum_l = (TRGM *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
838 		SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
839 		datum_l->flag = SIGNKEY | ALLISTRUE;
840 	}
841 	else
842 	{
843 		datum_l = (TRGM *) palloc(CALCGTSIZE(SIGNKEY, 0));
844 		SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY, 0));
845 		datum_l->flag = SIGNKEY;
846 		memcpy((void *) GETSIGN(datum_l), (void *) cache[seed_1].sign, sizeof(BITVEC));
847 	}
848 	if (cache[seed_2].allistrue)
849 	{
850 		datum_r = (TRGM *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
851 		SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
852 		datum_r->flag = SIGNKEY | ALLISTRUE;
853 	}
854 	else
855 	{
856 		datum_r = (TRGM *) palloc(CALCGTSIZE(SIGNKEY, 0));
857 		SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY, 0));
858 		datum_r->flag = SIGNKEY;
859 		memcpy((void *) GETSIGN(datum_r), (void *) cache[seed_2].sign, sizeof(BITVEC));
860 	}
861 
862 	union_l = GETSIGN(datum_l);
863 	union_r = GETSIGN(datum_r);
864 
865 	/* sort before ... */
866 	costvector = (SPLITCOST *) palloc(sizeof(SPLITCOST) * maxoff);
867 	for (j = FirstOffsetNumber; j <= maxoff; j = OffsetNumberNext(j))
868 	{
869 		costvector[j - 1].pos = j;
870 		size_alpha = hemdistcache(&(cache[seed_1]), &(cache[j]));
871 		size_beta = hemdistcache(&(cache[seed_2]), &(cache[j]));
872 		costvector[j - 1].cost = abs(size_alpha - size_beta);
873 	}
874 	qsort((void *) costvector, maxoff, sizeof(SPLITCOST), comparecost);
875 
876 	for (k = 0; k < maxoff; k++)
877 	{
878 		j = costvector[k].pos;
879 		if (j == seed_1)
880 		{
881 			*left++ = j;
882 			v->spl_nleft++;
883 			continue;
884 		}
885 		else if (j == seed_2)
886 		{
887 			*right++ = j;
888 			v->spl_nright++;
889 			continue;
890 		}
891 
892 		if (ISALLTRUE(datum_l) || cache[j].allistrue)
893 		{
894 			if (ISALLTRUE(datum_l) && cache[j].allistrue)
895 				size_alpha = 0;
896 			else
897 				size_alpha = SIGLENBIT - sizebitvec(
898 													(cache[j].allistrue) ? GETSIGN(datum_l) : GETSIGN(cache[j].sign)
899 					);
900 		}
901 		else
902 			size_alpha = hemdistsign(cache[j].sign, GETSIGN(datum_l));
903 
904 		if (ISALLTRUE(datum_r) || cache[j].allistrue)
905 		{
906 			if (ISALLTRUE(datum_r) && cache[j].allistrue)
907 				size_beta = 0;
908 			else
909 				size_beta = SIGLENBIT - sizebitvec(
910 												   (cache[j].allistrue) ? GETSIGN(datum_r) : GETSIGN(cache[j].sign)
911 					);
912 		}
913 		else
914 			size_beta = hemdistsign(cache[j].sign, GETSIGN(datum_r));
915 
916 		if (size_alpha < size_beta + WISH_F(v->spl_nleft, v->spl_nright, 0.1))
917 		{
918 			if (ISALLTRUE(datum_l) || cache[j].allistrue)
919 			{
920 				if (!ISALLTRUE(datum_l))
921 					MemSet((void *) GETSIGN(datum_l), 0xff, sizeof(BITVEC));
922 			}
923 			else
924 			{
925 				ptr = cache[j].sign;
926 				LOOPBYTE
927 					union_l[i] |= ptr[i];
928 			}
929 			*left++ = j;
930 			v->spl_nleft++;
931 		}
932 		else
933 		{
934 			if (ISALLTRUE(datum_r) || cache[j].allistrue)
935 			{
936 				if (!ISALLTRUE(datum_r))
937 					MemSet((void *) GETSIGN(datum_r), 0xff, sizeof(BITVEC));
938 			}
939 			else
940 			{
941 				ptr = cache[j].sign;
942 				LOOPBYTE
943 					union_r[i] |= ptr[i];
944 			}
945 			*right++ = j;
946 			v->spl_nright++;
947 		}
948 	}
949 
950 	v->spl_ldatum = PointerGetDatum(datum_l);
951 	v->spl_rdatum = PointerGetDatum(datum_r);
952 
953 	PG_RETURN_POINTER(v);
954 }
955