xref: /netbsd/common/lib/libc/gen/rpst.c (revision 87a9663c)
1 /*	$NetBSD: rpst.c,v 1.6 2009/05/26 22:37:50 yamt Exp $	*/
2 
3 /*-
4  * Copyright (c)2009 YAMAMOTO Takashi,
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26  * SUCH DAMAGE.
27  */
28 
29 /*
30  * radix priority search tree
31  *
32  * described in:
33  *	SIAM J. COMPUT.
34  *	Vol. 14, No. 2, May 1985
35  *	PRIORITY SEARCH TREES
36  *	EDWARD M. McCREIGHT
37  *
38  * ideas from linux:
39  *	- grow tree height on-demand.
40  *	- allow duplicated X values.  in that case, we act as a heap.
41  */
42 
43 #include <sys/cdefs.h>
44 
45 #if defined(_KERNEL)
46 __KERNEL_RCSID(0, "$NetBSD: rpst.c,v 1.6 2009/05/26 22:37:50 yamt Exp $");
47 #include <sys/param.h>
48 #else /* defined(_KERNEL) */
49 __RCSID("$NetBSD: rpst.c,v 1.6 2009/05/26 22:37:50 yamt Exp $");
50 #include <assert.h>
51 #include <stdbool.h>
52 #include <string.h>
53 #if 1
54 #define	KASSERT	assert
55 #else
56 #define	KASSERT(a)
57 #endif
58 #endif /* defined(_KERNEL) */
59 
60 #include <sys/rpst.h>
61 
62 /*
63  * rpst_init_tree: initialize a tree.
64  */
65 
66 void
67 rpst_init_tree(struct rpst_tree *t)
68 {
69 
70 	t->t_root = NULL;
71 	t->t_height = 0;
72 }
73 
/*
 * rpst_height2max: return the largest index a tree of the given height
 * can hold, ie. a value with the low (height + 1) bits set.
 *
 * height  max index
 * 0   ... 0x0000000000000001
 * 1   ... 0x0000000000000003
 * 2   ... 0x0000000000000007
 * 3   ... 0x000000000000000f
 *
 * 31  ... 0x00000000ffffffff
 *
 * 63  ... 0xffffffffffffffff
 *
 * => height must be smaller than 64.
 */

static uint64_t
rpst_height2max(unsigned int height)
{

	KASSERT(height < 64);
	/*
	 * shifting the all-ones value right keeps exactly (height + 1)
	 * low bits and needs no special case for height 63.
	 */
	return UINT64_MAX >> (63 - height);
}
98 
99 /*
100  * rpst_level2mask: calculate the mask for the given level in the tree.
101  *
102  * the mask used to index root's children is level 0.
103  */
104 
105 static uint64_t
106 rpst_level2mask(const struct rpst_tree *t, unsigned int level)
107 {
108 	uint64_t mask;
109 
110 	if (t->t_height < level) {
111 		mask = 0;
112 	} else {
113 		mask = UINT64_C(1) << (t->t_height - level);
114 	}
115 	return mask;
116 }
117 
118 /*
119  * rpst_startmask: calculate the mask for the start of a search.
120  * (ie. the mask for the top-most bit)
121  */
122 
123 static uint64_t
124 rpst_startmask(const struct rpst_tree *t)
125 {
126 	const uint64_t mask = rpst_level2mask(t, 0);
127 
128 	KASSERT((mask | (mask - 1)) == rpst_height2max(t->t_height));
129 	return mask;
130 }
131 
132 /*
133  * rpst_update_parents: update n_parent of children
134  */
135 
136 static inline void
137 rpst_update_parents(struct rpst_node *n)
138 {
139 	int i;
140 
141 	for (i = 0; i < 2; i++) {
142 		if (n->n_children[i] != NULL) {
143 			n->n_children[i]->n_parent = n;
144 		}
145 	}
146 }
147 
148 /*
149  * rpst_enlarge_tree: enlarge tree so that 'index' can be stored
150  */
151 
152 static void
153 rpst_enlarge_tree(struct rpst_tree *t, uint64_t idx)
154 {
155 
156 	while (idx > rpst_height2max(t->t_height)) {
157 		struct rpst_node *n = t->t_root;
158 
159 		if (n != NULL) {
160 			rpst_remove_node(t, n);
161 			memset(&n->n_children, 0, sizeof(n->n_children));
162 			n->n_children[0] = t->t_root;
163 			t->t_root->n_parent = n;
164 			t->t_root = n;
165 			n->n_parent = NULL;
166 		}
167 		t->t_height++;
168 	}
169 }
170 
171 /*
172  * rpst_insert_node1: a helper for rpst_insert_node.
173  */
174 
175 static struct rpst_node *
176 rpst_insert_node1(struct rpst_node **where, struct rpst_node *n, uint64_t mask)
177 {
178 	struct rpst_node *parent;
179 	struct rpst_node *cur;
180 	unsigned int idx;
181 
182 	KASSERT((n->n_x & ((-mask) << 1)) == 0);
183 	parent = NULL;
184 next:
185 	cur = *where;
186 	if (cur == NULL) {
187 		n->n_parent = parent;
188 		memset(&n->n_children, 0, sizeof(n->n_children));
189 		*where = n;
190 		return NULL;
191 	}
192 	KASSERT(cur->n_parent == parent);
193 	if (n->n_y == cur->n_y && n->n_x == cur->n_x) {
194 		return cur;
195 	}
196 	if (n->n_y < cur->n_y) {
197 		/*
198 		 * swap cur and n.
199 		 * note that n is not in tree.
200 		 */
201 		memcpy(n->n_children, cur->n_children, sizeof(n->n_children));
202 		n->n_parent = cur->n_parent;
203 		rpst_update_parents(n);
204 		*where = n;
205 		n = cur;
206 		cur = *where;
207 	}
208 	KASSERT(*where == cur);
209 	idx = (n->n_x & mask) != 0;
210 	where = &cur->n_children[idx];
211 	parent = cur;
212 	KASSERT((*where) == NULL || ((((*where)->n_x & mask) != 0) == idx));
213 	KASSERT((*where) == NULL || (*where)->n_y >= cur->n_y);
214 	mask >>= 1;
215 	goto next;
216 }
217 
218 /*
219  * rpst_insert_node: insert a node into the tree.
220  *
221  * => return NULL on success.
222  * => if a duplicated node (a node with the same X,Y pair as ours) is found,
223  *    return the node.  in that case, the tree is intact.
224  */
225 
226 struct rpst_node *
227 rpst_insert_node(struct rpst_tree *t, struct rpst_node *n)
228 {
229 
230 	rpst_enlarge_tree(t, n->n_x);
231 	return rpst_insert_node1(&t->t_root, n, rpst_startmask(t));
232 }
233 
234 /*
235  * rpst_find_pptr: find a pointer to the given node.
236  *
237  * also, return the parent node via parentp.  (NULL for the root node.)
238  */
239 
240 static inline struct rpst_node **
241 rpst_find_pptr(struct rpst_tree *t, struct rpst_node *n,
242     struct rpst_node **parentp)
243 {
244 	struct rpst_node * const parent = n->n_parent;
245 	unsigned int i;
246 
247 	*parentp = parent;
248 	if (parent == NULL) {
249 		return &t->t_root;
250 	}
251 	for (i = 0; i < 2 - 1; i++) {
252 		if (parent->n_children[i] == n) {
253 			break;
254 		}
255 	}
256 	KASSERT(parent->n_children[i] == n);
257 	return &parent->n_children[i];
258 }
259 
260 /*
261  * rpst_remove_node_at: remove a node at *where.
262  */
263 
264 static void
265 rpst_remove_node_at(struct rpst_node *parent, struct rpst_node **where,
266     struct rpst_node *cur)
267 {
268 	struct rpst_node *tmp[2];
269 	struct rpst_node *selected;
270 	unsigned int selected_idx = 0; /* XXX gcc */
271 	unsigned int i;
272 
273 	KASSERT(cur != NULL);
274 	KASSERT(parent == cur->n_parent);
275 next:
276 	selected = NULL;
277 	for (i = 0; i < 2; i++) {
278 		struct rpst_node *c;
279 
280 		c = cur->n_children[i];
281 		KASSERT(c == NULL || c->n_parent == cur);
282 		if (selected == NULL || (c != NULL && c->n_y < selected->n_y)) {
283 			selected = c;
284 			selected_idx = i;
285 		}
286 	}
287 	/*
288 	 * now we have:
289 	 *
290 	 *      parent
291 	 *          \ <- where
292 	 *           cur
293 	 *           / \
294 	 *          A  selected
295 	 *              / \
296 	 *             B   C
297 	 */
298 	*where = selected;
299 	if (selected == NULL) {
300 		return;
301 	}
302 	/*
303 	 * swap selected->n_children and cur->n_children.
304 	 */
305 	memcpy(tmp, selected->n_children, sizeof(tmp));
306 	memcpy(selected->n_children, cur->n_children, sizeof(tmp));
307 	memcpy(cur->n_children, tmp, sizeof(tmp));
308 	rpst_update_parents(cur);
309 	rpst_update_parents(selected);
310 	selected->n_parent = parent;
311 	/*
312 	 *      parent
313 	 *          \ <- where
314 	 *          selected
315 	 *           / \
316 	 *          A  selected
317 	 *
318 	 *              cur
319 	 *              / \
320 	 *             B   C
321 	 */
322 	where = &selected->n_children[selected_idx];
323 	/*
324 	 *      parent
325 	 *          \
326 	 *          selected
327 	 *           / \ <- where
328 	 *          A  selected (*)
329 	 *
330 	 *              cur (**)
331 	 *              / \
332 	 *             B   C
333 	 *
334 	 * (*) this 'selected' will be overwritten in the next iteration.
335 	 * (**) cur->n_parent is bogus.
336 	 */
337 	parent = selected;
338 	goto next;
339 }
340 
/*
 * rpst_remove_node: take a node out of the tree.
 *
 * looks up the pointer which refers to the node and lets
 * rpst_remove_node_at do the actual unlinking.
 */

void
rpst_remove_node(struct rpst_tree *t, struct rpst_node *n)
{
	struct rpst_node *parent;
	struct rpst_node ** const where = rpst_find_pptr(t, n, &parent);

	rpst_remove_node_at(parent, where, n);
}
354 
355 static bool __unused
356 rpst_iterator_match_p(const struct rpst_node *n, const struct rpst_iterator *it)
357 {
358 
359 	if (n->n_y > it->it_max_y) {
360 		return false;
361 	}
362 	if (n->n_x < it->it_min_x) {
363 		return false;
364 	}
365 	if (n->n_x > it->it_max_x) {
366 		return false;
367 	}
368 	return true;
369 }
370 
371 struct rpst_node *
372 rpst_iterate_first(struct rpst_tree *t, uint64_t max_y, uint64_t min_x,
373     uint64_t max_x, struct rpst_iterator *it)
374 {
375 	struct rpst_node *n;
376 
377 	KASSERT(min_x <= max_x);
378 	n = t->t_root;
379 	if (n == NULL || n->n_y > max_y) {
380 		return NULL;
381 	}
382 	it->it_tree = t;
383 	it->it_cur = n;
384 	it->it_idx = (min_x & rpst_startmask(t)) != 0;
385 	it->it_level = 0;
386 	it->it_max_y = max_y;
387 	it->it_min_x = min_x;
388 	it->it_max_x = max_x;
389 	return rpst_iterate_next(it);
390 }
391 
392 static inline unsigned int
393 rpst_node_on_edge_p(const struct rpst_node *n, uint64_t val, uint64_t mask)
394 {
395 
396 	return ((n->n_x ^ val) & ((-mask) << 1)) == 0;
397 }
398 
/*
 * rpst_maxidx: return the highest child index of 'n' worth visiting
 * for the upper bound max_x.
 *
 * => unless the node is on the max_x edge, both children qualify.
 */

static inline uint64_t
rpst_maxidx(const struct rpst_node *n, uint64_t max_x, uint64_t mask)
{

	if (!rpst_node_on_edge_p(n, max_x, mask)) {
		return 1;
	}
	return (max_x & mask) != 0;
}
409 
/*
 * rpst_minidx: return the lowest child index of 'n' worth visiting
 * for the lower bound min_x.
 *
 * => unless the node is on the min_x edge, both children qualify.
 */

static inline uint64_t
rpst_minidx(const struct rpst_node *n, uint64_t min_x, uint64_t mask)
{

	if (!rpst_node_on_edge_p(n, min_x, mask)) {
		return 0;
	}
	return (min_x & mask) != 0;
}
420 
421 struct rpst_node *
422 rpst_iterate_next(struct rpst_iterator *it)
423 {
424 	struct rpst_tree *t;
425 	struct rpst_node *n;
426 	struct rpst_node *next;
427 	const uint64_t max_y = it->it_max_y;
428 	const uint64_t min_x = it->it_min_x;
429 	const uint64_t max_x = it->it_max_x;
430 	unsigned int idx;
431 	unsigned int maxidx;
432 	unsigned int level;
433 	uint64_t mask;
434 
435 	t = it->it_tree;
436 	n = it->it_cur;
437 	idx = it->it_idx;
438 	level = it->it_level;
439 	mask = rpst_level2mask(t, level);
440 	maxidx = rpst_maxidx(n, max_x, mask);
441 	KASSERT(n == t->t_root || rpst_iterator_match_p(n, it));
442 next:
443 	KASSERT(mask == rpst_level2mask(t, level));
444 	KASSERT(idx >= rpst_minidx(n, min_x, mask));
445 	KASSERT(maxidx == rpst_maxidx(n, max_x, mask));
446 	KASSERT(idx <= maxidx + 2);
447 	KASSERT(n != NULL);
448 #if 0
449 	printf("%s: cur=%p, idx=%u maxidx=%u level=%u mask=%" PRIx64 "\n",
450 	    __func__, (void *)n, idx, maxidx, level, mask);
451 #endif
452 	if (idx == maxidx + 1) { /* visit the current node */
453 		idx++;
454 		if (min_x <= n->n_x && n->n_x <= max_x) {
455 			it->it_tree = t;
456 			it->it_cur = n;
457 			it->it_idx = idx;
458 			it->it_level = level;
459 			KASSERT(rpst_iterator_match_p(n, it));
460 			return n; /* report */
461 		}
462 		goto next;
463 	} else if (idx == maxidx + 2) { /* back to the parent */
464 		struct rpst_node **where;
465 
466 		where = rpst_find_pptr(t, n, &next);
467 		if (next == NULL) {
468 			KASSERT(level == 0);
469 			KASSERT(t->t_root == n);
470 			KASSERT(&t->t_root == where);
471 			return NULL; /* done */
472 		}
473 		KASSERT(level > 0);
474 		level--;
475 		n = next;
476 		mask = rpst_level2mask(t, level);
477 		maxidx = rpst_maxidx(n, max_x, mask);
478 		idx = where - n->n_children + 1;
479 		KASSERT(idx < 2 + 1);
480 		goto next;
481 	}
482 	/* go to a child */
483 	KASSERT(idx < 2);
484 	next = n->n_children[idx];
485 	if (next == NULL || next->n_y > max_y) {
486 		idx++;
487 		goto next;
488 	}
489 	KASSERT(next->n_parent == n);
490 	KASSERT(next->n_y >= n->n_y);
491 	level++;
492 	mask >>= 1;
493 	n = next;
494 	idx = rpst_minidx(n, min_x, mask);
495 	maxidx = rpst_maxidx(n, max_x, mask);
496 #if 0
497 	printf("%s: visit %p idx=%u level=%u mask=%llx\n",
498 	    __func__, n, idx, level, mask);
499 #endif
500 	goto next;
501 }
502 
503 #if defined(UNITTEST)
504 #include <sys/time.h>
505 
506 #include <inttypes.h>
507 #include <stdio.h>
508 #include <stdlib.h>
509 
510 static void
511 rpst_dump_node(const struct rpst_node *n, unsigned int depth)
512 {
513 	unsigned int i;
514 
515 	for (i = 0; i < depth; i++) {
516 		printf("  ");
517 	}
518 	printf("[%u]", depth);
519 	if (n == NULL) {
520 		printf("NULL\n");
521 		return;
522 	}
523 	printf("%p x=%" PRIx64 "(%" PRIu64 ") y=%" PRIx64 "(%" PRIu64 ")\n",
524 	    (const void *)n, n->n_x, n->n_x, n->n_y, n->n_y);
525 	for (i = 0; i < 2; i++) {
526 		rpst_dump_node(n->n_children[i], depth + 1);
527 	}
528 }
529 
530 static void
531 rpst_dump_tree(const struct rpst_tree *t)
532 {
533 
534 	printf("pst %p height=%u\n", (const void *)t, t->t_height);
535 	rpst_dump_node(t->t_root, 0);
536 }
537 
538 struct testnode {
539 	struct rpst_node n;
540 	struct testnode *next;
541 	bool failed;
542 	bool found;
543 };
544 
545 struct rpst_tree t;
546 struct testnode *h = NULL;
547 
/*
 * tvdiff: return tv1 - tv2 in microseconds.
 *
 * both timestamps are converted to microseconds in uintmax_t before
 * subtracting, so that no multiplication is carried out in the
 * (possibly narrower, signed) type of tv_sec.
 *
 * => tv1 is expected not to be earlier than tv2.
 */

static uintmax_t
tvdiff(const struct timeval *tv1, const struct timeval *tv2)
{
	const uintmax_t usec1 =
	    (uintmax_t)tv1->tv_sec * 1000000 + (uintmax_t)tv1->tv_usec;
	const uintmax_t usec2 =
	    (uintmax_t)tv2->tv_sec * 1000000 + (uintmax_t)tv2->tv_usec;

	return usec1 - usec2;
}
555 
556 static unsigned int
557 query(uint64_t max_y, uint64_t min_x, uint64_t max_x)
558 {
559 	struct testnode *n;
560 	struct rpst_node *rn;
561 	struct rpst_iterator it;
562 	struct timeval start;
563 	struct timeval end;
564 	unsigned int done;
565 
566 	printf("quering max_y=%" PRIu64 " min_x=%" PRIu64 " max_x=%" PRIu64
567 	    "\n",
568 	    max_y, min_x, max_x);
569 	done = 0;
570 	gettimeofday(&start, NULL);
571 	for (rn = rpst_iterate_first(&t, max_y, min_x, max_x, &it);
572 	    rn != NULL;
573 	    rn = rpst_iterate_next(&it)) {
574 		done++;
575 #if 0
576 		printf("found %p x=%" PRIu64 " y=%" PRIu64 "\n",
577 		    (void *)rn, rn->n_x, rn->n_y);
578 #endif
579 		n = (void *)rn;
580 		assert(!n->found);
581 		n->found = true;
582 	}
583 	gettimeofday(&end, NULL);
584 	printf("%u nodes found in %ju usecs\n", done,
585 	    tvdiff(&end, &start));
586 
587 	gettimeofday(&start, NULL);
588 	for (n = h; n != NULL; n = n->next) {
589 		assert(n->failed ||
590 		    n->found == rpst_iterator_match_p(&n->n, &it));
591 		n->found = false;
592 	}
593 	gettimeofday(&end, NULL);
594 	printf("(linear search took %ju usecs)\n", tvdiff(&end, &start));
595 	return done;
596 }
597 
598 int
599 main(int argc, char *argv[])
600 {
601 	struct testnode *n;
602 	unsigned int i;
603 	struct rpst_iterator it;
604 	struct timeval start;
605 	struct timeval end;
606 	uint64_t min_y = UINT64_MAX;
607 	uint64_t max_y = 0;
608 	uint64_t min_x = UINT64_MAX;
609 	uint64_t max_x = 0;
610 	uint64_t w;
611 	unsigned int done;
612 	unsigned int fail;
613 	unsigned int num = 500000;
614 
615 	rpst_init_tree(&t);
616 	rpst_dump_tree(&t);
617 	assert(NULL == rpst_iterate_first(&t, UINT64_MAX, 0, UINT64_MAX, &it));
618 
619 	for (i = 0; i < num; i++) {
620 		n = malloc(sizeof(*n));
621 		if (i > 499000) {
622 			n->n.n_x = 10;
623 			n->n.n_y = random();
624 		} else if (i > 400000) {
625 			n->n.n_x = i;
626 			n->n.n_y = random();
627 		} else {
628 			n->n.n_x = random();
629 			n->n.n_y = random();
630 		}
631 		if (n->n.n_y < min_y) {
632 			min_y = n->n.n_y;
633 		}
634 		if (n->n.n_y > max_y) {
635 			max_y = n->n.n_y;
636 		}
637 		if (n->n.n_x < min_x) {
638 			min_x = n->n.n_x;
639 		}
640 		if (n->n.n_x > max_x) {
641 			max_x = n->n.n_x;
642 		}
643 		n->found = false;
644 		n->failed = false;
645 		n->next = h;
646 		h = n;
647 	}
648 
649 	done = 0;
650 	fail = 0;
651 	gettimeofday(&start, NULL);
652 	for (n = h; n != NULL; n = n->next) {
653 		struct rpst_node *o;
654 #if 0
655 		printf("insert %p x=%" PRIu64 " y=%" PRIu64 "\n",
656 		    n, n->n.n_x, n->n.n_y);
657 #endif
658 		o = rpst_insert_node(&t, &n->n);
659 		if (o == NULL) {
660 			done++;
661 		} else {
662 			n->failed = true;
663 			fail++;
664 		}
665 	}
666 	gettimeofday(&end, NULL);
667 	printf("%u nodes inserted and %u insertion failed in %ju usecs\n",
668 	    done, fail,
669 	    tvdiff(&end, &start));
670 
671 	assert(min_y == 0 || 0 == query(min_y - 1, 0, UINT64_MAX));
672 	assert(max_x == UINT64_MAX ||
673 	    0 == query(UINT64_MAX, max_x + 1, UINT64_MAX));
674 	assert(min_x == 0 || 0 == query(UINT64_MAX, 0, min_x - 1));
675 
676 	done = query(max_y, min_x, max_x);
677 	assert(done == num - fail);
678 
679 	done = query(UINT64_MAX, 0, UINT64_MAX);
680 	assert(done == num - fail);
681 
682 	w = max_x - min_x;
683 	query(max_y / 2, min_x, max_x);
684 	query(max_y, min_x + w / 2, max_x);
685 	query(max_y / 2, min_x + w / 2, max_x);
686 	query(max_y / 2, min_x, max_x - w / 2);
687 	query(max_y / 2, min_x + w / 3, max_x - w / 3);
688 	query(max_y - 1, min_x + 1, max_x - 1);
689 	query(UINT64_MAX, 10, 10);
690 
691 	done = 0;
692 	gettimeofday(&start, NULL);
693 	for (n = h; n != NULL; n = n->next) {
694 		if (n->failed) {
695 			continue;
696 		}
697 #if 0
698 		printf("remove %p x=%" PRIu64 " y=%" PRIu64 "\n",
699 		    n, n->n.n_x, n->n.n_y);
700 #endif
701 		rpst_remove_node(&t, &n->n);
702 		done++;
703 	}
704 	gettimeofday(&end, NULL);
705 	printf("%u nodes removed in %ju usecs\n", done,
706 	    tvdiff(&end, &start));
707 
708 	rpst_dump_tree(&t);
709 }
710 #endif /* defined(UNITTEST) */
711