xref: /netbsd/sys/net/npf/lpm.c (revision 4c02f958)
1 /*-
2  * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24  * SUCH DAMAGE.
25  */
26 
27 /*
28  * Longest Prefix Match (LPM) library supporting IPv4 and IPv6.
29  *
30  * Algorithm:
31  *
 * Each prefix length gets its own hash map and all prefix lengths in
 * use are recorded in a bitmap.  On a lookup, we perform a linear scan
 * of the hash maps, iterating through the prefix lengths in use only.
 * Usually, only a few distinct prefix lengths are used and such a simple
 * algorithm is very efficient.  With many distinct IPv6 prefix lengths,
 * the linear scan might become a bottleneck.
37  */
38 
39 #if defined(_KERNEL)
40 #include <sys/cdefs.h>
41 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.6 2019/06/12 14:36:32 christos Exp $");
42 
43 #include <sys/param.h>
44 #include <sys/types.h>
45 #include <sys/malloc.h>
46 #include <sys/kmem.h>
47 #else
48 #include <sys/socket.h>
49 #include <arpa/inet.h>
50 
51 #include <stdio.h>
52 #include <stdlib.h>
53 #include <stdbool.h>
54 #include <stddef.h>
55 #include <string.h>
56 #include <strings.h>
57 #include <errno.h>
58 #include <assert.h>
59 #define kmem_alloc(a, b) malloc(a)
60 #define kmem_free(a, b) free(a)
61 #define kmem_zalloc(a, b) calloc(a, 1)
62 #endif
63 
64 #include "lpm.h"
65 
/* Longest supported prefix length in bits; also the top index of prefix[]. */
#define	LPM_MAX_PREFIX		(128)
/* Number of 32-bit words needed to hold LPM_MAX_PREFIX bits. */
#define	LPM_MAX_WORDS		(LPM_MAX_PREFIX >> 5)
/* Convert an address length in bytes into a count of 32-bit words. */
#define	LPM_TO_WORDS(x)		((x) >> 2)
/* Slack added to the item count when sizing a hash table on insert. */
#define	LPM_HASH_STEP		(8)
/* Index into defvals[]: 0 for a 4-byte address, 1 for a 16-byte one. */
#define	LPM_LEN_IDX(len)	((len) >> 4)

/* ASSERT() maps to assert() in DEBUG builds and expands to nothing otherwise. */
#ifdef DEBUG
#define	ASSERT			assert
#else
#define	ASSERT(x)
#endif
77 
/*
 * lpm_ent: one stored prefix.  Entries that fall into the same hash
 * bucket are chained through 'next'; the key bytes are stored inline
 * after the fixed part of the structure.
 */
typedef struct lpm_ent {
	struct lpm_ent *next;	/* next entry in the same bucket */
	void *		val;	/* value supplied by the caller on insert */
	unsigned	len;	/* length of key[] in bytes */
	uint8_t		key[];	/* the masked address (prefix) bytes */
} lpm_ent_t;
84 
/*
 * lpm_hmap_t: a chained hash table holding the entries of one prefix
 * length.  It starts out empty (all fields zero) and is sized on the
 * first insert.
 */
typedef struct {
	unsigned	hashsize;	/* number of buckets; 0 or a power of two */
	unsigned	nitems;		/* item count used to size the table */
	lpm_ent_t **	bucket;		/* array of 'hashsize' chain heads */
} lpm_hmap_t;
90 
/*
 * lpm: the LPM table.
 *
 * The bit for prefix length P lives in bitmask[(P - 1) >> 5] at position
 * 0x80000000 >> ((P - 1) & 31), so within a word a longer prefix
 * occupies a less significant bit.
 */
struct lpm {
	uint32_t	bitmask[LPM_MAX_WORDS];	/* prefix lengths inserted so far */
	int		flags;			/* allocation flags given to lpm_create() */
	void *		defvals[2];		/* values for the zero-length prefix, by LPM_LEN_IDX() */
	lpm_hmap_t	prefix[LPM_MAX_PREFIX + 1];	/* one hash map per prefix length */
};
97 
/*
 * All-zero address handed to the destructor as the key when lpm_clear()
 * disposes of the zero-length prefix (default) values.
 */
static const uint32_t zero_address[LPM_MAX_WORDS];
99 
100 lpm_t *
lpm_create(int flags)101 lpm_create(int flags)
102 {
103 	lpm_t *lpm = kmem_zalloc(sizeof(*lpm), KM_SLEEP);
104 	lpm->flags = flags;
105 	return lpm;
106 }
107 
108 void
lpm_clear(lpm_t * lpm,lpm_dtor_t dtor,void * arg)109 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg)
110 {
111 	for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) {
112 		lpm_hmap_t *hmap = &lpm->prefix[n];
113 
114 		if (!hmap->hashsize) {
115 			KASSERT(!hmap->bucket);
116 			continue;
117 		}
118 		for (unsigned i = 0; i < hmap->hashsize; i++) {
119 			lpm_ent_t *entry = hmap->bucket[i];
120 
121 			while (entry) {
122 				lpm_ent_t *next = entry->next;
123 
124 				if (dtor) {
125 					dtor(arg, entry->key,
126 					    entry->len, entry->val);
127 				}
128 				kmem_free(entry,
129 				    offsetof(lpm_ent_t, key[entry->len]));
130 				entry = next;
131 			}
132 		}
133 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
134 		hmap->bucket = NULL;
135 		hmap->hashsize = 0;
136 		hmap->nitems = 0;
137 	}
138 	if (dtor) {
139 		dtor(arg, zero_address, 4, lpm->defvals[0]);
140 		dtor(arg, zero_address, 16, lpm->defvals[1]);
141 	}
142 	memset(lpm->bitmask, 0, sizeof(lpm->bitmask));
143 	memset(lpm->defvals, 0, sizeof(lpm->defvals));
144 }
145 
146 void
lpm_destroy(lpm_t * lpm)147 lpm_destroy(lpm_t *lpm)
148 {
149 	lpm_clear(lpm, NULL, NULL);
150 	kmem_free(lpm, sizeof(*lpm));
151 }
152 
/*
 * fnv1a_hash: 32-bit Fowler-Noll-Vo hash, FNV-1a variant.
 *
 * => Hashes 'len' bytes starting at 'buf'; each byte is XORed into the
 *    state, which is then multiplied by the FNV prime.
 * => A zero-length input yields the FNV offset basis.
 */
static uint32_t
fnv1a_hash(const void *buf, size_t len)
{
	const uint8_t *bytes = buf;
	uint32_t h = 2166136261UL;	/* FNV offset basis */

	for (size_t off = 0; off < len; off++) {
		h = (h ^ bytes[off]) * 16777619U;	/* FNV prime */
	}
	return h;
}
168 
169 static bool
hashmap_rehash(lpm_hmap_t * hmap,unsigned size,int flags)170 hashmap_rehash(lpm_hmap_t *hmap, unsigned size, int flags)
171 {
172 	lpm_ent_t **bucket;
173 	unsigned hashsize;
174 
175 	for (hashsize = 1; hashsize < size; hashsize <<= 1) {
176 		continue;
177 	}
178 	bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), flags);
179 	if (bucket == NULL)
180 		return false;
181 	for (unsigned n = 0; n < hmap->hashsize; n++) {
182 		lpm_ent_t *list = hmap->bucket[n];
183 
184 		while (list) {
185 			lpm_ent_t *entry = list;
186 			uint32_t hash = fnv1a_hash(entry->key, entry->len);
187 			const unsigned i = hash & (hashsize - 1);
188 
189 			list = entry->next;
190 			entry->next = bucket[i];
191 			bucket[i] = entry;
192 		}
193 	}
194 	if (hmap->bucket)
195 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
196 	hmap->bucket = bucket;
197 	hmap->hashsize = hashsize;
198 	return true;
199 }
200 
201 static lpm_ent_t *
hashmap_insert(lpm_hmap_t * hmap,const void * key,size_t len,int flags)202 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len, int flags)
203 {
204 	const unsigned target = hmap->nitems + LPM_HASH_STEP;
205 	const size_t entlen = offsetof(lpm_ent_t, key[len]);
206 	uint32_t hash, i;
207 	lpm_ent_t *entry;
208 
209 	if (hmap->hashsize < target && !hashmap_rehash(hmap, target, flags)) {
210 		return NULL;
211 	}
212 
213 	hash = fnv1a_hash(key, len);
214 	i = hash & (hmap->hashsize - 1);
215 	entry = hmap->bucket[i];
216 	while (entry) {
217 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
218 			return entry;
219 		}
220 		entry = entry->next;
221 	}
222 
223 	if ((entry = kmem_alloc(entlen, flags)) != NULL) {
224 		memcpy(entry->key, key, len);
225 		entry->next = hmap->bucket[i];
226 		entry->len = len;
227 
228 		hmap->bucket[i] = entry;
229 		hmap->nitems++;
230 	}
231 	return entry;
232 }
233 
234 static lpm_ent_t *
hashmap_lookup(lpm_hmap_t * hmap,const void * key,size_t len)235 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len)
236 {
237 	const uint32_t hash = fnv1a_hash(key, len);
238 	const unsigned i = hash & (hmap->hashsize - 1);
239 	lpm_ent_t *entry;
240 
241 	if (hmap->hashsize == 0) {
242 		return NULL;
243 	}
244 	entry = hmap->bucket[i];
245 
246 	while (entry) {
247 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
248 			return entry;
249 		}
250 		entry = entry->next;
251 	}
252 	return NULL;
253 }
254 
255 static int
hashmap_remove(lpm_hmap_t * hmap,const void * key,size_t len)256 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len)
257 {
258 	const uint32_t hash = fnv1a_hash(key, len);
259 	const unsigned i = hash & (hmap->hashsize - 1);
260 	lpm_ent_t *prev = NULL, *entry;
261 
262 	if (hmap->hashsize == 0) {
263 		return -1;
264 	}
265 	entry = hmap->bucket[i];
266 
267 	while (entry) {
268 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
269 			if (prev) {
270 				prev->next = entry->next;
271 			} else {
272 				hmap->bucket[i] = entry->next;
273 			}
274 			kmem_free(entry, offsetof(lpm_ent_t, key[len]));
275 			return 0;
276 		}
277 		prev = entry;
278 		entry = entry->next;
279 	}
280 	return -1;
281 }
282 
/*
 * compute_prefix: mask the address down to its first 'preflen' bits.
 *
 * => 'addr' holds 'nwords' 32-bit words in network byte order; it may
 *    be unaligned, in which case it is copied to an aligned buffer.
 * => 'prefix' receives 'nwords' words: the leading 'preflen' bits of
 *    the address, followed by zero bits.
 */
static inline void
compute_prefix(const unsigned nwords, const uint32_t *addr,
    unsigned preflen, uint32_t *prefix)
{
	uint32_t aligned[4];
	const uint32_t *src = addr;
	unsigned remaining = preflen;

	if (((uintptr_t)addr & 3) != 0) {
		/* Unaligned address: just copy for now. */
		memcpy(aligned, addr, nwords * 4);
		src = aligned;
	}

	for (unsigned w = 0; w < nwords; w++) {
		if (remaining >= 32) {
			/* The whole word belongs to the prefix. */
			prefix[w] = src[w];
			remaining -= 32;
		} else if (remaining > 0) {
			/* Partial word: keep the top 'remaining' bits. */
			prefix[w] = src[w] &
			    htonl(0xffffffff << (32 - remaining));
			remaining = 0;
		} else {
			/* Past the end of the prefix. */
			prefix[w] = 0;
		}
	}
}
313 
314 /*
315  * lpm_insert: insert the CIDR into the LPM table.
316  *
317  * => Returns zero on success and -1 on failure.
318  */
319 int
lpm_insert(lpm_t * lpm,const void * addr,size_t len,unsigned preflen,void * val)320 lpm_insert(lpm_t *lpm, const void *addr,
321     size_t len, unsigned preflen, void *val)
322 {
323 	const unsigned nwords = LPM_TO_WORDS(len);
324 	uint32_t prefix[LPM_MAX_WORDS];
325 	lpm_ent_t *entry;
326 	KASSERT(len == 4 || len == 16);
327 
328 	if (preflen == 0) {
329 		/* 0-length prefix is a special case. */
330 		lpm->defvals[LPM_LEN_IDX(len)] = val;
331 		return 0;
332 	}
333 	compute_prefix(nwords, addr, preflen, prefix);
334 	entry = hashmap_insert(&lpm->prefix[preflen], prefix, len, lpm->flags);
335 	if (entry) {
336 		const unsigned n = --preflen >> 5;
337 		lpm->bitmask[n] |= 0x80000000U >> (preflen & 31);
338 		entry->val = val;
339 		return 0;
340 	}
341 	return -1;
342 }
343 
344 /*
345  * lpm_remove: remove the specified prefix.
346  */
347 int
lpm_remove(lpm_t * lpm,const void * addr,size_t len,unsigned preflen)348 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
349 {
350 	const unsigned nwords = LPM_TO_WORDS(len);
351 	uint32_t prefix[LPM_MAX_WORDS];
352 	KASSERT(len == 4 || len == 16);
353 
354 	if (preflen == 0) {
355 		lpm->defvals[LPM_LEN_IDX(len)] = NULL;
356 		return 0;
357 	}
358 	compute_prefix(nwords, addr, preflen, prefix);
359 	return hashmap_remove(&lpm->prefix[preflen], prefix, len);
360 }
361 
362 /*
363  * lpm_lookup: find the longest matching prefix given the IP address.
364  *
365  * => Returns the associated value on success or NULL on failure.
366  */
367 void *
lpm_lookup(lpm_t * lpm,const void * addr,size_t len)368 lpm_lookup(lpm_t *lpm, const void *addr, size_t len)
369 {
370 	const unsigned nwords = LPM_TO_WORDS(len);
371 	unsigned i, n = nwords;
372 	uint32_t prefix[LPM_MAX_WORDS];
373 
374 	while (n--) {
375 		uint32_t bitmask = lpm->bitmask[n];
376 
377 		while ((i = ffs(bitmask)) != 0) {
378 			const unsigned preflen = (32 * n) + (32 - --i);
379 			lpm_hmap_t *hmap = &lpm->prefix[preflen];
380 			lpm_ent_t *entry;
381 
382 			compute_prefix(nwords, addr, preflen, prefix);
383 			entry = hashmap_lookup(hmap, prefix, len);
384 			if (entry) {
385 				return entry->val;
386 			}
387 			bitmask &= ~(1U << i);
388 		}
389 	}
390 	return lpm->defvals[LPM_LEN_IDX(len)];
391 }
392 
393 /*
394  * lpm_lookup_prefix: return the value associated with a prefix
395  *
396  * => Returns the associated value on success or NULL on failure.
397  */
398 void *
lpm_lookup_prefix(lpm_t * lpm,const void * addr,size_t len,unsigned preflen)399 lpm_lookup_prefix(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
400 {
401 	const unsigned nwords = LPM_TO_WORDS(len);
402 	uint32_t prefix[LPM_MAX_WORDS];
403 	lpm_ent_t *entry;
404 	KASSERT(len == 4 || len == 16);
405 
406 	if (preflen == 0) {
407 		return lpm->defvals[LPM_LEN_IDX(len)];
408 	}
409 	compute_prefix(nwords, addr, preflen, prefix);
410 	entry = hashmap_lookup(&lpm->prefix[preflen], prefix, len);
411 	if (entry) {
412 		return entry->val;
413 	}
414 	return NULL;
415 }
416 
417 #if !defined(_KERNEL)
/*
 * lpm_strtobin: convert CIDR string to the binary IP address and mask.
 *
 * => 'cidr' is an IPv4 or IPv6 address, optionally followed by a slash
 *    and a decimal prefix length, e.g. "10.0.0.0/8" or "2001:db8::/32".
 * => Without a prefix length, the full address width is used
 *    (32 for IPv4, 128 for IPv6).
 * => The address will be in the network byte order; 'addr' must have
 *    room for 16 bytes.
 * => Returns 0 on success or -1 on failure: the string is too long, the
 *    address does not parse, or the prefix length is not a plain decimal
 *    number within the address width.  The outputs 'len' and 'preflen'
 *    are only written on success.
 */
int
lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen)
{
	/* Room for the longest textual address plus a "/128" suffix. */
	char buf[INET6_ADDRSTRLEN + 4];
	const size_t cidrlen = strlen(cidr);
	unsigned long plen = 0;
	bool has_plen = false;
	size_t alen;
	char *slash;

	if (cidrlen >= sizeof(buf)) {
		/* Refuse rather than silently truncate the input. */
		return -1;
	}
	memcpy(buf, cidr, cidrlen + 1);

	if ((slash = strchr(buf, '/')) != NULL) {
		const char *digits = slash + 1;
		char *end;

		/*
		 * strtoul() would accept leading blanks and a sign;
		 * insist on a digit and on consuming the whole suffix.
		 */
		if (*digits < '0' || *digits > '9') {
			return -1;
		}
		errno = 0;
		plen = strtoul(digits, &end, 10);
		if (errno != 0 || *end != '\0') {
			return -1;
		}
		*slash = '\0';
		has_plen = true;
	}

	if (inet_pton(AF_INET6, buf, addr) == 1) {
		alen = 16;
	} else if (inet_pton(AF_INET, buf, addr) == 1) {
		alen = 4;
	} else {
		return -1;
	}

	if (!has_plen) {
		plen = alen * 8;
	} else if (plen > alen * 8) {
		return -1;
	}
	*len = alen;
	*preflen = (unsigned)plen;
	return 0;
}
453 #endif
454