1 /*-------------------------------------------------------------------------
2  *
3  * network_gist.c
4  *	  GiST support for network types.
5  *
6  * The key thing to understand about this code is the definition of the
7  * "union" of a set of INET/CIDR values.  It works like this:
8  * 1. If the values are not all of the same IP address family, the "union"
9  * is a dummy value with family number zero, minbits zero, commonbits zero,
10  * address all zeroes.  Otherwise:
11  * 2. The union has the common IP address family number.
12  * 3. The union's minbits value is the smallest netmask length ("ip_bits")
13  * of all the input values.
14  * 4. Let C be the number of leading address bits that are in common among
15  * all the input values (C ranges from 0 to ip_maxbits for the family).
16  * 5. The union's commonbits value is C.
17  * 6. The union's address value is the same as the common prefix for its
18  * first C bits, and is zeroes to the right of that.  The physical width
19  * of the address value is ip_maxbits for the address family.
20  *
21  * In a leaf index entry (representing a single key), commonbits is equal to
22  * ip_maxbits for the address family, minbits is the same as the represented
23  * value's ip_bits, and the address is equal to the represented address.
24  * Although it may appear that we're wasting a byte by storing the union
25  * format and not just the represented INET/CIDR value in leaf keys, the
26  * extra byte is actually "free" because of alignment considerations.
27  *
28  * Note that this design tracks minbits and commonbits independently; in any
29  * given union value, either might be smaller than the other.  This does not
30  * help us much when descending the tree, because of the way inet comparison
31  * is defined: at non-leaf nodes we can't compare more than minbits bits
32  * even if we know them.  However, it greatly improves the quality of split
33  * decisions.  Preliminary testing suggests that searches are as much as
34  * twice as fast as for a simpler design in which a single field doubles as
35  * the common prefix length and the minimum ip_bits value.
36  *
37  * Portions Copyright (c) 1996-2016, PostgreSQL Global Development Group
38  * Portions Copyright (c) 1994, Regents of the University of California
39  *
40  *
41  * IDENTIFICATION
42  *	  src/backend/utils/adt/network_gist.c
43  *
44  *-------------------------------------------------------------------------
45  */
46 #include "postgres.h"
47 
48 #include <sys/socket.h>
49 
50 #include "access/gist.h"
51 #include "access/stratnum.h"
52 #include "utils/inet.h"
53 
/*
 * Operator strategy numbers used in the GiST inet_ops opclass
 *
 * Each INETSTRAT_* name is an alias for a generic strategy number from
 * access/stratnum.h; inet_gist_consistent() switches on these values.  As
 * implemented there, the test is "indexed value <op> query", e.g. SUB is
 * satisfied when the indexed value is a strictly smaller subnet of the query
 * and SUBEQ when it is a subnet or equal.
 */
#define INETSTRAT_OVERLAPS		RTOverlapStrategyNumber
#define INETSTRAT_EQ			RTEqualStrategyNumber
#define INETSTRAT_NE			RTNotEqualStrategyNumber
#define INETSTRAT_LT			RTLessStrategyNumber
#define INETSTRAT_LE			RTLessEqualStrategyNumber
#define INETSTRAT_GT			RTGreaterStrategyNumber
#define INETSTRAT_GE			RTGreaterEqualStrategyNumber
#define INETSTRAT_SUB			RTSubStrategyNumber
#define INETSTRAT_SUBEQ			RTSubEqualStrategyNumber
#define INETSTRAT_SUP			RTSuperStrategyNumber
#define INETSTRAT_SUPEQ			RTSuperEqualStrategyNumber
68 
69 
/*
 * Representation of a GiST INET/CIDR index key.  This is not identical to
 * INET/CIDR because we need to keep track of the length of the common address
 * prefix as well as the minimum netmask length.  However, as long as it
 * follows varlena header rules, the core GiST code won't know the difference.
 * For simplicity we always use 1-byte-header varlena format.
 *
 * Only the address bytes that are actually needed are counted in the varlena
 * length (see SET_GK_VARSIZE), so a stored key may be shorter than
 * sizeof(GistInetKey).  These keys are what the compress function produces
 * and what decompress hands back unchanged, so the field layout is
 * presumably part of the on-disk index format.  Treat it as fixed.
 */
typedef struct GistInetKey
{
	uint8		va_header;		/* varlena header --- don't touch directly */
	unsigned char family;		/* PGSQL_AF_INET, PGSQL_AF_INET6, or zero */
	unsigned char minbits;		/* minimum number of bits in netmask */
	unsigned char commonbits;	/* number of common prefix bits in addresses */
	unsigned char ipaddr[16];	/* up to 128 bits of common address */
} GistInetKey;

/* Convert between Datum and GistInetKey pointer (no detoasting is done). */
#define DatumGetInetKeyP(X) ((GistInetKey *) DatumGetPointer(X))
#define InetKeyPGetDatum(X) PointerGetDatum(X)
88 
/*
 * Access macros; not really exciting, but we use these for notational
 * consistency with access to INET/CIDR values.  Note that family-zero values
 * are stored with 4 bytes of address, not 16.
 */
#define gk_ip_family(gkptr)		((gkptr)->family)
#define gk_ip_minbits(gkptr)	((gkptr)->minbits)
#define gk_ip_commonbits(gkptr) ((gkptr)->commonbits)
#define gk_ip_addr(gkptr)		((gkptr)->ipaddr)
/* Address width in bits; any family other than INET6 (including 0) is 32. */
#define ip_family_maxbits(fam)	((fam) == PGSQL_AF_INET6 ? 128 : 32)

/* These require that the family field has been set: */
/* Number of address bytes stored: 16 for INET6, otherwise (INET or 0) 4. */
#define gk_ip_addrsize(gkptr) \
	(gk_ip_family(gkptr) == PGSQL_AF_INET6 ? 16 : 4)
#define gk_ip_maxbits(gkptr) \
	ip_family_maxbits(gk_ip_family(gkptr))
/* Varlena length = fixed header fields plus only the used address bytes. */
#define SET_GK_VARSIZE(dst) \
	SET_VARSIZE_SHORT(dst, offsetof(GistInetKey, ipaddr) + gk_ip_addrsize(dst))
107 
108 
/*
 * The GiST query consistency check
 *
 * fmgr arguments: 0 = GISTENTRY whose key is a GistInetKey, 1 = the inet/cidr
 * query value, 2 = strategy number, 3 = subtype OID (unused), 4 = output flag
 * telling the caller whether a recheck is needed (always set false here).
 *
 * Returns whether "indexed value <op> query" may hold for the operator
 * selected by the strategy.  On a leaf the answer is exact; on an inner page
 * true means "some child might match, descend".
 *
 * An unknown strategy reaches the elog at the bottom unless an earlier check
 * has already decided the result.  NOTE(review): on an inner page, an
 * assert-enabled build will trip the Assert(GIST_LEAF) before Check 4 instead.
 */
Datum
inet_gist_consistent(PG_FUNCTION_ARGS)
{
	GISTENTRY  *ent = (GISTENTRY *) PG_GETARG_POINTER(0);
	inet	   *query = PG_GETARG_INET_PP(1);
	StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);

	/* Oid		subtype = PG_GETARG_OID(3); */
	bool	   *recheck = (bool *) PG_GETARG_POINTER(4);
	GistInetKey *key = DatumGetInetKeyP(ent->key);
	int			minbits,
				order;

	/* All operators served by this function are exact. */
	*recheck = false;

	/*
	 * Check 0: mixed families
	 *
	 * A family-zero key is the union of keys from more than one address
	 * family (see the head of this file), so its children could match
	 * anything.  This can only happen on an inner index page.
	 */
	if (gk_ip_family(key) == 0)
	{
		Assert(!GIST_LEAF(ent));
		PG_RETURN_BOOL(true);
	}

	/*
	 * Check 1: different families
	 *
	 * When the key's family differs from the query's, no containment,
	 * overlap or equality strategy can be satisfied.  Ordering between
	 * different families is decided by family number alone, so the ordering
	 * strategies reduce to comparing family numbers, and <> always holds.
	 */
	if (gk_ip_family(key) != ip_family(query))
	{
		switch (strategy)
		{
			case INETSTRAT_LT:
			case INETSTRAT_LE:
				if (gk_ip_family(key) < ip_family(query))
					PG_RETURN_BOOL(true);
				break;

			case INETSTRAT_GE:
			case INETSTRAT_GT:
				if (gk_ip_family(key) > ip_family(query))
					PG_RETURN_BOOL(true);
				break;

			case INETSTRAT_NE:
				PG_RETURN_BOOL(true);
		}
		/* For all other cases, we can be sure there is no match */
		PG_RETURN_BOOL(false);
	}

	/*
	 * Check 2: network bit count
	 *
	 * Network bit count (ip_bits) helps to check leaves for sub network and
	 * sup network operators.  At non-leaf nodes, we know every child value
	 * has ip_bits >= gk_ip_minbits(key), so we can avoid descending in some
	 * cases too.
	 */
	switch (strategy)
	{
		case INETSTRAT_SUB:
			/* A strict subnet needs strictly more netmask bits. */
			if (GIST_LEAF(ent) && gk_ip_minbits(key) <= ip_bits(query))
				PG_RETURN_BOOL(false);
			break;

		case INETSTRAT_SUBEQ:
			if (GIST_LEAF(ent) && gk_ip_minbits(key) < ip_bits(query))
				PG_RETURN_BOOL(false);
			break;

		case INETSTRAT_SUPEQ:
		case INETSTRAT_EQ:
			/* Valid for inner pages too: children never have fewer bits. */
			if (gk_ip_minbits(key) > ip_bits(query))
				PG_RETURN_BOOL(false);
			break;

		case INETSTRAT_SUP:
			if (gk_ip_minbits(key) >= ip_bits(query))
				PG_RETURN_BOOL(false);
			break;
	}

	/*
	 * Check 3: common network bits
	 *
	 * Compare available common prefix bits to the query, but not beyond
	 * either the query's netmask or the minimum netmask among the represented
	 * values.  If these bits don't match the query, we have our answer (and
	 * may or may not need to descend, depending on the operator).  If they do
	 * match, and we are not at a leaf, we descend in all cases.
	 *
	 * Note this is the final check for operators that only consider the
	 * network part of the address.
	 */
	minbits = Min(gk_ip_commonbits(key), gk_ip_minbits(key));
	minbits = Min(minbits, ip_bits(query));

	order = bitncmp(gk_ip_addr(key), ip_addr(query), minbits);

	switch (strategy)
	{
		case INETSTRAT_SUB:
		case INETSTRAT_SUBEQ:
		case INETSTRAT_OVERLAPS:
		case INETSTRAT_SUPEQ:
		case INETSTRAT_SUP:
			PG_RETURN_BOOL(order == 0);

		case INETSTRAT_LT:
		case INETSTRAT_LE:
			if (order > 0)
				PG_RETURN_BOOL(false);
			if (order < 0 || !GIST_LEAF(ent))
				PG_RETURN_BOOL(true);
			break;

		case INETSTRAT_EQ:
			if (order != 0)
				PG_RETURN_BOOL(false);
			if (!GIST_LEAF(ent))
				PG_RETURN_BOOL(true);
			break;

		case INETSTRAT_GE:
		case INETSTRAT_GT:
			if (order < 0)
				PG_RETURN_BOOL(false);
			if (order > 0 || !GIST_LEAF(ent))
				PG_RETURN_BOOL(true);
			break;

		case INETSTRAT_NE:
			if (order != 0 || !GIST_LEAF(ent))
				PG_RETURN_BOOL(true);
			break;
	}

	/*
	 * Remaining checks are only for leaves and basic comparison strategies.
	 * See network_cmp_internal() in network.c for the implementation we need
	 * to match.  Note that in a leaf key, commonbits should equal the address
	 * length, so we compared the whole network parts above.
	 */
	Assert(GIST_LEAF(ent));

	/*
	 * Check 4: network bit count
	 *
	 * Next step is to compare netmask widths.  In a leaf key, minbits is the
	 * represented value's own ip_bits.
	 */
	switch (strategy)
	{
		case INETSTRAT_LT:
		case INETSTRAT_LE:
			if (gk_ip_minbits(key) < ip_bits(query))
				PG_RETURN_BOOL(true);
			if (gk_ip_minbits(key) > ip_bits(query))
				PG_RETURN_BOOL(false);
			break;

		case INETSTRAT_EQ:
			if (gk_ip_minbits(key) != ip_bits(query))
				PG_RETURN_BOOL(false);
			break;

		case INETSTRAT_GE:
		case INETSTRAT_GT:
			if (gk_ip_minbits(key) > ip_bits(query))
				PG_RETURN_BOOL(true);
			if (gk_ip_minbits(key) < ip_bits(query))
				PG_RETURN_BOOL(false);
			break;

		case INETSTRAT_NE:
			if (gk_ip_minbits(key) != ip_bits(query))
				PG_RETURN_BOOL(true);
			break;
	}

	/*
	 * Check 5: whole address
	 *
	 * Netmask bit counts are the same, so check all the address bits.
	 */
	order = bitncmp(gk_ip_addr(key), ip_addr(query), gk_ip_maxbits(key));

	switch (strategy)
	{
		case INETSTRAT_LT:
			PG_RETURN_BOOL(order < 0);

		case INETSTRAT_LE:
			PG_RETURN_BOOL(order <= 0);

		case INETSTRAT_EQ:
			PG_RETURN_BOOL(order == 0);

		case INETSTRAT_GE:
			PG_RETURN_BOOL(order >= 0);

		case INETSTRAT_GT:
			PG_RETURN_BOOL(order > 0);

		case INETSTRAT_NE:
			PG_RETURN_BOOL(order != 0);
	}

	/* Only strategies not handled by any of the switches above get here. */
	elog(ERROR, "unknown strategy for inet GiST");
	PG_RETURN_BOOL(false);		/* keep compiler quiet */
}
328 
329 /*
330  * Calculate parameters of the union of some GistInetKeys.
331  *
332  * Examine the keys in elements m..n inclusive of the GISTENTRY array,
333  * and compute these output parameters:
334  * *minfamily_p = minimum IP address family number
335  * *maxfamily_p = maximum IP address family number
336  * *minbits_p = minimum netmask width
337  * *commonbits_p = number of leading bits in common among the addresses
338  *
339  * minbits and commonbits are forced to zero if there's more than one
340  * address family.
341  */
342 static void
calc_inet_union_params(GISTENTRY * ent,int m,int n,int * minfamily_p,int * maxfamily_p,int * minbits_p,int * commonbits_p)343 calc_inet_union_params(GISTENTRY *ent,
344 					   int m, int n,
345 					   int *minfamily_p,
346 					   int *maxfamily_p,
347 					   int *minbits_p,
348 					   int *commonbits_p)
349 {
350 	int			minfamily,
351 				maxfamily,
352 				minbits,
353 				commonbits;
354 	unsigned char *addr;
355 	GistInetKey *tmp;
356 	int			i;
357 
358 	/* Must be at least one key. */
359 	Assert(m <= n);
360 
361 	/* Initialize variables using the first key. */
362 	tmp = DatumGetInetKeyP(ent[m].key);
363 	minfamily = maxfamily = gk_ip_family(tmp);
364 	minbits = gk_ip_minbits(tmp);
365 	commonbits = gk_ip_commonbits(tmp);
366 	addr = gk_ip_addr(tmp);
367 
368 	/* Scan remaining keys. */
369 	for (i = m + 1; i <= n; i++)
370 	{
371 		tmp = DatumGetInetKeyP(ent[i].key);
372 
373 		/* Determine range of family numbers */
374 		if (minfamily > gk_ip_family(tmp))
375 			minfamily = gk_ip_family(tmp);
376 		if (maxfamily < gk_ip_family(tmp))
377 			maxfamily = gk_ip_family(tmp);
378 
379 		/* Find minimum minbits */
380 		if (minbits > gk_ip_minbits(tmp))
381 			minbits = gk_ip_minbits(tmp);
382 
383 		/* Find minimum number of bits in common */
384 		if (commonbits > gk_ip_commonbits(tmp))
385 			commonbits = gk_ip_commonbits(tmp);
386 		if (commonbits > 0)
387 			commonbits = bitncommon(addr, gk_ip_addr(tmp), commonbits);
388 	}
389 
390 	/* Force minbits/commonbits to zero if more than one family. */
391 	if (minfamily != maxfamily)
392 		minbits = commonbits = 0;
393 
394 	*minfamily_p = minfamily;
395 	*maxfamily_p = maxfamily;
396 	*minbits_p = minbits;
397 	*commonbits_p = commonbits;
398 }
399 
400 /*
401  * Same as above, but the GISTENTRY elements to examine are those with
402  * indices listed in the offsets[] array.
403  */
404 static void
calc_inet_union_params_indexed(GISTENTRY * ent,OffsetNumber * offsets,int noffsets,int * minfamily_p,int * maxfamily_p,int * minbits_p,int * commonbits_p)405 calc_inet_union_params_indexed(GISTENTRY *ent,
406 							   OffsetNumber *offsets, int noffsets,
407 							   int *minfamily_p,
408 							   int *maxfamily_p,
409 							   int *minbits_p,
410 							   int *commonbits_p)
411 {
412 	int			minfamily,
413 				maxfamily,
414 				minbits,
415 				commonbits;
416 	unsigned char *addr;
417 	GistInetKey *tmp;
418 	int			i;
419 
420 	/* Must be at least one key. */
421 	Assert(noffsets > 0);
422 
423 	/* Initialize variables using the first key. */
424 	tmp = DatumGetInetKeyP(ent[offsets[0]].key);
425 	minfamily = maxfamily = gk_ip_family(tmp);
426 	minbits = gk_ip_minbits(tmp);
427 	commonbits = gk_ip_commonbits(tmp);
428 	addr = gk_ip_addr(tmp);
429 
430 	/* Scan remaining keys. */
431 	for (i = 1; i < noffsets; i++)
432 	{
433 		tmp = DatumGetInetKeyP(ent[offsets[i]].key);
434 
435 		/* Determine range of family numbers */
436 		if (minfamily > gk_ip_family(tmp))
437 			minfamily = gk_ip_family(tmp);
438 		if (maxfamily < gk_ip_family(tmp))
439 			maxfamily = gk_ip_family(tmp);
440 
441 		/* Find minimum minbits */
442 		if (minbits > gk_ip_minbits(tmp))
443 			minbits = gk_ip_minbits(tmp);
444 
445 		/* Find minimum number of bits in common */
446 		if (commonbits > gk_ip_commonbits(tmp))
447 			commonbits = gk_ip_commonbits(tmp);
448 		if (commonbits > 0)
449 			commonbits = bitncommon(addr, gk_ip_addr(tmp), commonbits);
450 	}
451 
452 	/* Force minbits/commonbits to zero if more than one family. */
453 	if (minfamily != maxfamily)
454 		minbits = commonbits = 0;
455 
456 	*minfamily_p = minfamily;
457 	*maxfamily_p = maxfamily;
458 	*minbits_p = minbits;
459 	*commonbits_p = commonbits;
460 }
461 
462 /*
463  * Construct a GistInetKey representing a union value.
464  *
465  * Inputs are the family/minbits/commonbits values to use, plus a pointer to
466  * the address field of one of the union inputs.  (Since we're going to copy
467  * just the bits-in-common, it doesn't matter which one.)
468  */
469 static GistInetKey *
build_inet_union_key(int family,int minbits,int commonbits,unsigned char * addr)470 build_inet_union_key(int family, int minbits, int commonbits,
471 					 unsigned char *addr)
472 {
473 	GistInetKey *result;
474 
475 	/* Make sure any unused bits are zeroed. */
476 	result = (GistInetKey *) palloc0(sizeof(GistInetKey));
477 
478 	gk_ip_family(result) = family;
479 	gk_ip_minbits(result) = minbits;
480 	gk_ip_commonbits(result) = commonbits;
481 
482 	/* Clone appropriate bytes of the address. */
483 	if (commonbits > 0)
484 		memcpy(gk_ip_addr(result), addr, (commonbits + 7) / 8);
485 
486 	/* Clean any unwanted bits in the last partial byte. */
487 	if (commonbits % 8 != 0)
488 		gk_ip_addr(result)[commonbits / 8] &= ~(0xFF >> (commonbits % 8));
489 
490 	/* Set varlena header correctly. */
491 	SET_GK_VARSIZE(result);
492 
493 	return result;
494 }
495 
496 
497 /*
498  * The GiST union function
499  *
500  * See comments at head of file for the definition of the union.
501  */
502 Datum
inet_gist_union(PG_FUNCTION_ARGS)503 inet_gist_union(PG_FUNCTION_ARGS)
504 {
505 	GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
506 	GISTENTRY  *ent = entryvec->vector;
507 	int			minfamily,
508 				maxfamily,
509 				minbits,
510 				commonbits;
511 	unsigned char *addr;
512 	GistInetKey *tmp,
513 			   *result;
514 
515 	/* Determine parameters of the union. */
516 	calc_inet_union_params(ent, 0, entryvec->n - 1,
517 						   &minfamily, &maxfamily,
518 						   &minbits, &commonbits);
519 
520 	/* If more than one family, emit family number zero. */
521 	if (minfamily != maxfamily)
522 		minfamily = 0;
523 
524 	/* Initialize address using the first key. */
525 	tmp = DatumGetInetKeyP(ent[0].key);
526 	addr = gk_ip_addr(tmp);
527 
528 	/* Construct the union value. */
529 	result = build_inet_union_key(minfamily, minbits, commonbits, addr);
530 
531 	PG_RETURN_POINTER(result);
532 }
533 
534 /*
535  * The GiST compress function
536  *
537  * Convert an inet value to GistInetKey.
538  */
539 Datum
inet_gist_compress(PG_FUNCTION_ARGS)540 inet_gist_compress(PG_FUNCTION_ARGS)
541 {
542 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
543 	GISTENTRY  *retval;
544 
545 	if (entry->leafkey)
546 	{
547 		retval = palloc(sizeof(GISTENTRY));
548 		if (DatumGetPointer(entry->key) != NULL)
549 		{
550 			inet	   *in = DatumGetInetPP(entry->key);
551 			GistInetKey *r;
552 
553 			r = (GistInetKey *) palloc0(sizeof(GistInetKey));
554 
555 			gk_ip_family(r) = ip_family(in);
556 			gk_ip_minbits(r) = ip_bits(in);
557 			gk_ip_commonbits(r) = gk_ip_maxbits(r);
558 			memcpy(gk_ip_addr(r), ip_addr(in), gk_ip_addrsize(r));
559 			SET_GK_VARSIZE(r);
560 
561 			gistentryinit(*retval, PointerGetDatum(r),
562 						  entry->rel, entry->page,
563 						  entry->offset, FALSE);
564 		}
565 		else
566 		{
567 			gistentryinit(*retval, (Datum) 0,
568 						  entry->rel, entry->page,
569 						  entry->offset, FALSE);
570 		}
571 	}
572 	else
573 		retval = entry;
574 	PG_RETURN_POINTER(retval);
575 }
576 
577 /*
578  * The GiST decompress function
579  *
580  * do not do anything --- we just use the stored GistInetKey as-is.
581  */
582 Datum
inet_gist_decompress(PG_FUNCTION_ARGS)583 inet_gist_decompress(PG_FUNCTION_ARGS)
584 {
585 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
586 
587 	PG_RETURN_POINTER(entry);
588 }
589 
590 /*
591  * The GiST fetch function
592  *
593  * Reconstruct the original inet datum from a GistInetKey.
594  */
595 Datum
inet_gist_fetch(PG_FUNCTION_ARGS)596 inet_gist_fetch(PG_FUNCTION_ARGS)
597 {
598 	GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
599 	GistInetKey *key = DatumGetInetKeyP(entry->key);
600 	GISTENTRY  *retval;
601 	inet	   *dst;
602 
603 	dst = (inet *) palloc0(sizeof(inet));
604 
605 	ip_family(dst) = gk_ip_family(key);
606 	ip_bits(dst) = gk_ip_minbits(key);
607 	memcpy(ip_addr(dst), gk_ip_addr(key), ip_addrsize(dst));
608 	SET_INET_VARSIZE(dst);
609 
610 	retval = palloc(sizeof(GISTENTRY));
611 	gistentryinit(*retval, InetPGetDatum(dst), entry->rel, entry->page,
612 				  entry->offset, FALSE);
613 
614 	PG_RETURN_POINTER(retval);
615 }
616 
617 /*
618  * The GiST page split penalty function
619  *
620  * Charge a large penalty if address family doesn't match, or a somewhat
621  * smaller one if the new value would degrade the union's minbits
622  * (minimum netmask width).  Otherwise, penalty is inverse of the
623  * new number of common address bits.
624  */
625 Datum
inet_gist_penalty(PG_FUNCTION_ARGS)626 inet_gist_penalty(PG_FUNCTION_ARGS)
627 {
628 	GISTENTRY  *origent = (GISTENTRY *) PG_GETARG_POINTER(0);
629 	GISTENTRY  *newent = (GISTENTRY *) PG_GETARG_POINTER(1);
630 	float	   *penalty = (float *) PG_GETARG_POINTER(2);
631 	GistInetKey *orig = DatumGetInetKeyP(origent->key),
632 			   *new = DatumGetInetKeyP(newent->key);
633 	int			commonbits;
634 
635 	if (gk_ip_family(orig) == gk_ip_family(new))
636 	{
637 		if (gk_ip_minbits(orig) <= gk_ip_minbits(new))
638 		{
639 			commonbits = bitncommon(gk_ip_addr(orig), gk_ip_addr(new),
640 									Min(gk_ip_commonbits(orig),
641 										gk_ip_commonbits(new)));
642 			if (commonbits > 0)
643 				*penalty = 1.0f / commonbits;
644 			else
645 				*penalty = 2;
646 		}
647 		else
648 			*penalty = 3;
649 	}
650 	else
651 		*penalty = 4;
652 
653 	PG_RETURN_POINTER(penalty);
654 }
655 
656 /*
657  * The GiST PickSplit method
658  *
659  * There are two ways to split. First one is to split by address families,
660  * if there are multiple families appearing in the input.
661  *
662  * The second and more common way is to split by addresses. To achieve this,
663  * determine the number of leading bits shared by all the keys, then split on
664  * the next bit.  (We don't currently consider the netmask widths while doing
665  * this; should we?)  If we fail to get a nontrivial split that way, split
666  * 50-50.
667  */
668 Datum
inet_gist_picksplit(PG_FUNCTION_ARGS)669 inet_gist_picksplit(PG_FUNCTION_ARGS)
670 {
671 	GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
672 	GIST_SPLITVEC *splitvec = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
673 	GISTENTRY  *ent = entryvec->vector;
674 	int			minfamily,
675 				maxfamily,
676 				minbits,
677 				commonbits;
678 	unsigned char *addr;
679 	GistInetKey *tmp,
680 			   *left_union,
681 			   *right_union;
682 	int			maxoff,
683 				nbytes;
684 	OffsetNumber i,
685 			   *left,
686 			   *right;
687 
688 	maxoff = entryvec->n - 1;
689 	nbytes = (maxoff + 1) * sizeof(OffsetNumber);
690 
691 	left = (OffsetNumber *) palloc(nbytes);
692 	right = (OffsetNumber *) palloc(nbytes);
693 
694 	splitvec->spl_left = left;
695 	splitvec->spl_right = right;
696 
697 	splitvec->spl_nleft = 0;
698 	splitvec->spl_nright = 0;
699 
700 	/* Determine parameters of the union of all the inputs. */
701 	calc_inet_union_params(ent, FirstOffsetNumber, maxoff,
702 						   &minfamily, &maxfamily,
703 						   &minbits, &commonbits);
704 
705 	if (minfamily != maxfamily)
706 	{
707 		/* Multiple families, so split by family. */
708 		for (i = FirstOffsetNumber; i <= maxoff; i = OffsetNumberNext(i))
709 		{
710 			/*
711 			 * If there's more than 2 families, all but maxfamily go into the
712 			 * left union.  This could only happen if the inputs include some
713 			 * IPv4, some IPv6, and some already-multiple-family unions.
714 			 */
715 			tmp = DatumGetInetKeyP(ent[i].key);
716 			if (gk_ip_family(tmp) != maxfamily)
717 				left[splitvec->spl_nleft++] = i;
718 			else
719 				right[splitvec->spl_nright++] = i;
720 		}
721 	}
722 	else
723 	{
724 		/*
725 		 * Split on the next bit after the common bits.  If that yields a
726 		 * trivial split, try the next bit position to the right.  Repeat till
727 		 * success; or if we run out of bits, do an arbitrary 50-50 split.
728 		 */
729 		int			maxbits = ip_family_maxbits(minfamily);
730 
731 		while (commonbits < maxbits)
732 		{
733 			/* Split using the commonbits'th bit position. */
734 			int			bitbyte = commonbits / 8;
735 			int			bitmask = 0x80 >> (commonbits % 8);
736 
737 			splitvec->spl_nleft = splitvec->spl_nright = 0;
738 
739 			for (i = FirstOffsetNumber; i <= maxoff; i = OffsetNumberNext(i))
740 			{
741 				tmp = DatumGetInetKeyP(ent[i].key);
742 				addr = gk_ip_addr(tmp);
743 				if ((addr[bitbyte] & bitmask) == 0)
744 					left[splitvec->spl_nleft++] = i;
745 				else
746 					right[splitvec->spl_nright++] = i;
747 			}
748 
749 			if (splitvec->spl_nleft > 0 && splitvec->spl_nright > 0)
750 				break;			/* success */
751 			commonbits++;
752 		}
753 
754 		if (commonbits >= maxbits)
755 		{
756 			/* Failed ... do a 50-50 split. */
757 			splitvec->spl_nleft = splitvec->spl_nright = 0;
758 
759 			for (i = FirstOffsetNumber; i <= maxoff / 2; i = OffsetNumberNext(i))
760 			{
761 				left[splitvec->spl_nleft++] = i;
762 			}
763 			for (; i <= maxoff; i = OffsetNumberNext(i))
764 			{
765 				right[splitvec->spl_nright++] = i;
766 			}
767 		}
768 	}
769 
770 	/*
771 	 * Compute the union value for each side from scratch.  In most cases we
772 	 * could approximate the union values with what we already know, but this
773 	 * ensures that each side has minbits and commonbits set as high as
774 	 * possible.
775 	 */
776 	calc_inet_union_params_indexed(ent, left, splitvec->spl_nleft,
777 								   &minfamily, &maxfamily,
778 								   &minbits, &commonbits);
779 	if (minfamily != maxfamily)
780 		minfamily = 0;
781 	tmp = DatumGetInetKeyP(ent[left[0]].key);
782 	addr = gk_ip_addr(tmp);
783 	left_union = build_inet_union_key(minfamily, minbits, commonbits, addr);
784 	splitvec->spl_ldatum = PointerGetDatum(left_union);
785 
786 	calc_inet_union_params_indexed(ent, right, splitvec->spl_nright,
787 								   &minfamily, &maxfamily,
788 								   &minbits, &commonbits);
789 	if (minfamily != maxfamily)
790 		minfamily = 0;
791 	tmp = DatumGetInetKeyP(ent[right[0]].key);
792 	addr = gk_ip_addr(tmp);
793 	right_union = build_inet_union_key(minfamily, minbits, commonbits, addr);
794 	splitvec->spl_rdatum = PointerGetDatum(right_union);
795 
796 	PG_RETURN_POINTER(splitvec);
797 }
798 
799 /*
800  * The GiST equality function
801  */
802 Datum
inet_gist_same(PG_FUNCTION_ARGS)803 inet_gist_same(PG_FUNCTION_ARGS)
804 {
805 	GistInetKey *left = DatumGetInetKeyP(PG_GETARG_DATUM(0));
806 	GistInetKey *right = DatumGetInetKeyP(PG_GETARG_DATUM(1));
807 	bool	   *result = (bool *) PG_GETARG_POINTER(2);
808 
809 	*result = (gk_ip_family(left) == gk_ip_family(right) &&
810 			   gk_ip_minbits(left) == gk_ip_minbits(right) &&
811 			   gk_ip_commonbits(left) == gk_ip_commonbits(right) &&
812 			   memcmp(gk_ip_addr(left), gk_ip_addr(right),
813 					  gk_ip_addrsize(left)) == 0);
814 
815 	PG_RETURN_POINTER(result);
816 }
817