1 /*
2  * Copyright (C) 2016-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
16  *
17  * The code originated from https://github.com/fanf2/qp/blob/master/qp.c
18  * at revision 5f6d93753.
19  */
20 
#include <assert.h>
#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include "lib/trie.h"
26 
/*!
 * \brief Error codes used in the library.
 *
 * KNOT_EOK is zero; every other code is the negated value of the
 * corresponding <errno.h> constant, so failures are always negative.
 */
enum knot_error {
    KNOT_EOK = 0, /* success */

    /* Directly mapped error codes. */
    KNOT_ENOMEM = -ENOMEM, /* allocation failed */
    KNOT_EINVAL = -EINVAL, /* invalid argument */
    KNOT_ENOENT = -ENOENT, /* item not found */
};
36 
/*
 * Enable the hack on x86 and on any compiler that reports a little-endian
 * byte order.  Both __BYTE_ORDER__ and __ORDER_LITTLE_ENDIAN__ (note the
 * trailing underscores) have to be defined before they can be compared.
 */
#if defined(__i386) || defined(__x86_64) || defined(_M_IX86)        \
    || (defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) \
        && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)

/*!
 * \brief Use a pointer alignment hack to save memory.
 *
 * When on, isbranch() relies on the fact that in leaf_t the first pointer
 * is aligned on multiple of 4 bytes and that the flags bitfield is
 * overlaid over the lowest two bits of that pointer.
 * Neither is really guaranteed by the C standards; the second part should
 * be OK with x86_64 ABI and most likely any other little-endian platform.
 * It would be possible to manipulate the right bits portably, but it would
 * complicate the code nontrivially. C++ doesn't even guarantee type-punning.
 * In debug mode we check this works OK when creating a new trie instance.
 */
#define FLAGS_HACK 1
#else
#define FLAGS_HACK 0
#endif
57 
/*! Unsigned byte; used for key characters and (without FLAGS_HACK) node flags. */
typedef unsigned char byte;
/* Define `uint` unless a macro of that name already exists; the
 * self-referential #define makes a later `#ifndef uint` skip the typedef. */
#ifndef uint
typedef unsigned int uint;
#define uint uint
#endif
typedef uint bitmap_t; /*! Bit-maps, using the range of 1<<0 to 1<<16 (inclusive). */
64 
/*! \brief Key of a leaf: a length followed by the key bytes themselves. */
typedef struct {
    uint32_t len; // 32 bits are enough for key lengths; probably even 16 bits would be.
    uint8_t  chars[]; // flexible array member: `len` bytes allocated together with the header
} tkey_t;
69 
/*!
 * \brief Leaf of trie.
 *
 * Without FLAGS_HACK the leaf carries an explicit flags byte; with the hack
 * there is no such member and the key pointer is the first field.
 */
typedef struct {
#if !FLAGS_HACK
    byte flags;
#endif
    tkey_t*    key; /*!< The pointer must be aligned to 4-byte multiples! */
    trie_val_t val; /*!< The value stored under the key. */
} leaf_t;
78 
/*! \brief A trie node is either leaf_t or branch_t (forward declaration of the union). */
typedef union node node_t;
81 
/*!
 * \brief Branch node of trie.
 *
 * - The flags distinguish whether the node is a leaf_t (0), or a branch
 *   testing the more-important nibble (1) or the less-important one (2).
 * - It stores the index of the byte that the node tests.  The combined
 *   value (index*4 + flags) increases in branch nodes as you go deeper
 *   into the trie.  All the keys below a branch are identical up to the
 *   nibble identified by the branch.  Indices have to be stored because
 *   we skip any branch nodes that would have a single child.
 *   (Consequently, the skipped parts of key have to be validated in a leaf.)
 * - The bitmap indicates which subtries are present.  The present child nodes
 *   are stored in the twigs array (with no holes between them).
 * - To simplify storing keys that are prefixes of each other, the end-of-string
 *   position is treated as another nibble value, ordered before all others.
 *   That affects the bitmap and twigs fields.
 *
 * \note The branch nodes are never allocated individually, but they are
 *   always part of either the root node or the twigs array of the parent.
 */
typedef struct {
#if FLAGS_HACK
    /* Bit-fields packed into one 32-bit unit: 2 bits of flags, then 17 bits of bitmap. */
    uint32_t flags : 2,
        bitmap : 17; /*!< The first bitmap bit is for end-of-string child. */
#else
    byte     flags;
    uint32_t bitmap;
#endif
    uint32_t index; /*!< Index of the key byte tested by this branch. */
    node_t*  twigs; /*!< Array of the present children, one per set bitmap bit. */
} branch_t;
113 
/*! \brief Storage for one trie node; the flags decide which member is valid (see isbranch()). */
union node {
    leaf_t   leaf;
    branch_t branch;
};
118 
/*! \brief The trie itself: embedded root node, item count and allocator. */
struct trie {
    node_t    root; // undefined when weight == 0, see empty_root()
    size_t    weight; // number of leaves (stored keys)
    knot_mm_t mm; // allocator used for keys and twig arrays
};
124 
125 /* Included from other files */
126 
/** Release memory through free() without a const-cast at every call site. */
static inline void free_const(const void* what)
{
    void* writable = (void*)what;
    free(writable);
}
132 
/**
 * Allocate \a size bytes from the memory context \a mm.
 *
 * With a NULL context the request goes straight to malloc().
 */
static inline void* mm_alloc(knot_mm_t* mm, size_t size)
{
    return mm ? mm->alloc(mm->ctx, size) : malloc(size);
}
140 
/**
 * Release \a what through the memory context \a mm.
 *
 * With a NULL context plain free() is used; a context that has no free
 * callback makes this a no-op.
 */
static inline void mm_free(knot_mm_t* mm, const void* what)
{
    if (!mm) {
        free_const(what);
        return;
    }
    if (mm->free)
        mm->free((void*)what);
}
149 
/** malloc() adapter matching the allocator callback signature; \a ctx is ignored. */
static void* mm_malloc(void* ctx, size_t n)
{
    void* block = malloc(n);
    (void)ctx;
    return block;
}
155 
/**
 * Resize a block to \a size bytes within the memory context \a mm.
 *
 * With a NULL context this is plain realloc().  Otherwise a new block is
 * allocated, min(prev_size, size) bytes are copied over and the old block
 * is released through the context.
 *
 * \return The new block, or NULL if the allocation failed (with a context,
 *         the old block is left untouched in that case).
 */
static void* mm_realloc(knot_mm_t* mm, void* what, size_t size, size_t prev_size)
{
    if (!mm)
        return realloc(what, size);

    void* fresh = mm->alloc(mm->ctx, size);
    if (fresh == NULL)
        return NULL;

    if (what) {
        size_t keep = size;
        if (prev_size < size)
            keep = prev_size;
        memcpy(fresh, what, keep);
    }
    mm_free(mm, what);
    return fresh;
}
174 
/** Set up \a mm as a context backed by malloc()/free() with no private data. */
static inline void mm_ctx_init(knot_mm_t* mm)
{
    mm->alloc = mm_malloc;
    mm->free  = free;
    mm->ctx   = NULL;
}
181 
182 /*! \brief Make the root node empty (debug-only). */
empty_root(node_t * root)183 static inline void empty_root(node_t* root)
184 {
185 #ifndef NDEBUG
186     *root = (node_t) { .branch = {
187                            .flags  = 3, // invalid value that fits
188                            .bitmap = 0,
189                            .index  = -1,
190                            .twigs  = NULL } };
191 #endif
192 }
193 
/*!
 * \brief Check that unportable code works OK (debug-only).
 *
 * With FLAGS_HACK on, a leaf whose key pointer has numeric value 1 must read
 * back as branch flags == 1, i.e. the flags really overlay the lowest bits
 * of the pointer.  The test pointer is produced by an integer cast rather
 * than by arithmetic on a null pointer, which is undefined behaviour.
 */
static void assert_portability(void)
{
#if FLAGS_HACK
    const node_t probe = { .leaf = {
                               .key = (tkey_t*)(uintptr_t)1,
                               .val = NULL } };
    assert(probe.branch.flags == 1);
    (void)probe; // silence unused-variable warnings when NDEBUG is defined
#endif
}
205 
/*!
 * \brief Propagate error codes.
 *
 * Evaluates \a x exactly once and returns its value from the enclosing
 * function if it differs from KNOT_EOK.  The argument is parenthesised so
 * that any expression can be passed safely.
 */
#define ERR_RETURN(x)                        \
    do {                                     \
        int err_code_ = (x);                 \
        if (unlikely(err_code_ != KNOT_EOK)) \
            return err_code_;                \
    } while (false)
213 
214 /*!
215  * \brief Count the number of set bits.
216  *
217  * \TODO This implementation may be relatively slow on some HW.
218  */
bitmap_weight(bitmap_t w)219 static uint bitmap_weight(bitmap_t w)
220 {
221     assert((w & ~((1 << 17) - 1)) == 0); // using the least-important 17 bits
222     return __builtin_popcount(w);
223 }
224 
225 /*! \brief Only keep the lowest bit in the bitmap (least significant -> twigs[0]). */
bitmap_lowest_bit(bitmap_t w)226 static bitmap_t bitmap_lowest_bit(bitmap_t w)
227 {
228     assert((w & ~((1 << 17) - 1)) == 0); // using the least-important 17 bits
229     return 1 << __builtin_ctz(w);
230 }
231 
232 /*! \brief Test flags to determine type of this node. */
isbranch(const node_t * t)233 static bool isbranch(const node_t* t)
234 {
235     uint f = t->branch.flags;
236     assert(f <= 2);
237     return f != 0;
238 }
239 
240 /*! \brief Make a bitmask for testing a branch bitmap. */
nibbit(byte k,uint flags)241 static bitmap_t nibbit(byte k, uint flags)
242 {
243     uint shift  = (2 - flags) << 2;
244     uint nibble = (k >> shift) & 0xf;
245     return 1 << (nibble + 1 /*because of prefix keys*/);
246 }
247 
248 /*! \brief Extract a nibble from a key and turn it into a bitmask. */
twigbit(const node_t * t,const uint8_t * key,uint32_t len)249 static bitmap_t twigbit(const node_t* t, const uint8_t* key, uint32_t len)
250 {
251     assert(isbranch(t));
252     uint i = t->branch.index;
253 
254     if (i >= len)
255         return 1 << 0; // leaf position
256 
257     return nibbit((byte)key[i], t->branch.flags);
258 }
259 
260 /*! \brief Test if a branch node has a child indicated by a bitmask. */
hastwig(const node_t * t,bitmap_t bit)261 static bool hastwig(const node_t* t, bitmap_t bit)
262 {
263     assert(isbranch(t));
264     return t->branch.bitmap & bit;
265 }
266 
267 /*! \brief Compute offset of an existing child in a branch node. */
twigoff(const node_t * t,bitmap_t b)268 static uint twigoff(const node_t* t, bitmap_t b)
269 {
270     assert(isbranch(t));
271     return bitmap_weight(t->branch.bitmap & (b - 1));
272 }
273 
274 /*! \brief Get pointer to a particular child of a branch node. */
twig(node_t * t,uint i)275 static node_t* twig(node_t* t, uint i)
276 {
277     assert(isbranch(t));
278     return &t->branch.twigs[i];
279 }
280 
/*!
 * \brief For a branch node, compute offset of a child and child count.
 *
 * Stores twigoff(t, b) into \a off and the number of children of \a t into
 * \a max; \a t is evaluated twice.
 *
 * Having this separate might be meaningful for performance optimization.
 */
#define TWIGOFFMAX(off, max, t, b)                 \
    do {                                           \
        (off) = twigoff((t), (b));                 \
        (max) = bitmap_weight((t)->branch.bitmap); \
    } while (0)
291 
/*!
 * \brief Simple string comparator.
 *
 * Compares the common prefix bytewise (as unsigned bytes); if it is equal,
 * the shorter key orders first.
 *
 * \param k1,k1_len  First key; the pointer may be NULL when the length is 0.
 * \param k2,k2_len  Second key; the pointer may be NULL when the length is 0.
 *
 * \return Negative, zero or positive value if k1 is less than, equal to or
 *         greater than k2; exactly -1 / 1 when only the lengths differ.
 */
static int key_cmp(const uint8_t* k1, uint32_t k1_len, const uint8_t* k2, uint32_t k2_len)
{
    const uint32_t common = k1_len < k2_len ? k1_len : k2_len;

    /* memcmp() requires valid pointers even for a zero count, so skip it then. */
    if (common > 0) {
        int ret = memcmp(k1, k2, common);
        if (ret != 0) {
            return ret;
        }
    }

    /* Key string is equal, compare lengths. */
    if (k1_len == k2_len) {
        return 0;
    }
    return k1_len < k2_len ? -1 : 1;
}
309 
trie_create(knot_mm_t * mm)310 trie_t* trie_create(knot_mm_t* mm)
311 {
312     assert_portability();
313     trie_t* trie = mm_alloc(mm, sizeof(trie_t));
314     if (trie != NULL) {
315         empty_root(&trie->root);
316         trie->weight = 0;
317         if (mm != NULL)
318             trie->mm = *mm;
319         else
320             mm_ctx_init(&trie->mm);
321     }
322     return trie;
323 }
324 
325 /*! \brief Free anything under the trie node, except for the passed pointer itself. */
clear_trie(node_t * trie,knot_mm_t * mm)326 static void clear_trie(node_t* trie, knot_mm_t* mm)
327 {
328     if (!isbranch(trie)) {
329         mm_free(mm, trie->leaf.key);
330     } else {
331         branch_t* b   = &trie->branch;
332         int       len = bitmap_weight(b->bitmap);
333         int       i;
334         for (i = 0; i < len; ++i)
335             clear_trie(b->twigs + i, mm);
336         mm_free(mm, b->twigs);
337     }
338 }
339 
/*! \brief Destroy the trie including all keys; NULL is accepted and ignored. */
void trie_free(trie_t* tbl)
{
    if (tbl != NULL) {
        if (tbl->weight != 0)
            clear_trie(&tbl->root, &tbl->mm);
        mm_free(&tbl->mm, tbl);
    }
}
348 
/*! \brief Remove every item from the trie, keeping the trie object usable. */
void trie_clear(trie_t* tbl)
{
    assert(tbl);
    if (tbl->weight != 0) {
        clear_trie(&tbl->root, &tbl->mm);
        empty_root(&tbl->root);
        tbl->weight = 0;
    }
}
358 
trie_weight(const trie_t * tbl)359 size_t trie_weight(const trie_t* tbl)
360 {
361     assert(tbl);
362     return tbl->weight;
363 }
364 
/** Result of a leaf search; all members are zero when nothing was found. */
struct found {
    leaf_t*   l; /**< the found leaf (NULL if not found) */
    branch_t* p; /**< the leaf's parent (if exists) */
    bitmap_t  b; /**< bit-mask with a single bit marking l under p */
};
370 /** Search trie for an item with the given key (equality only). */
find_equal(trie_t * tbl,const uint8_t * key,uint32_t len)371 static struct found find_equal(trie_t* tbl, const uint8_t* key, uint32_t len)
372 {
373     assert(tbl);
374     struct found ret0;
375     memset(&ret0, 0, sizeof(ret0));
376     if (!tbl->weight)
377         return ret0;
378     /* Current node and parent while descending (returned values basically). */
379     node_t*   t = &tbl->root;
380     branch_t* p = NULL;
381     bitmap_t  b = 0;
382     while (isbranch(t)) {
383         __builtin_prefetch(t->branch.twigs);
384         b = twigbit(t, key, len);
385         if (!hastwig(t, b))
386             return ret0;
387         p = &t->branch;
388         t = twig(t, twigoff(t, b));
389     }
390     if (key_cmp(key, len, t->leaf.key->chars, t->leaf.key->len) != 0)
391         return ret0;
392     return (struct found) {
393         .l = &t->leaf,
394         .p = p,
395         .b = b,
396     };
397 }
398 /** Find item with the first key (lexicographical order). */
find_first(trie_t * tbl)399 static struct found find_first(trie_t* tbl)
400 {
401     assert(tbl);
402     if (!tbl->weight) {
403         struct found ret0;
404         memset(&ret0, 0, sizeof(ret0));
405         return ret0;
406     }
407     /* Current node and parent while descending (returned values basically). */
408     node_t*   t = &tbl->root;
409     branch_t* p = NULL;
410     while (isbranch(t)) {
411         p = &t->branch;
412         t = &p->twigs[0];
413     }
414     return (struct found) {
415         .l = &t->leaf,
416         .p = p,
417         .b = p ? bitmap_lowest_bit(p->bitmap) : 0,
418     };
419 }
420 
trie_get_try(trie_t * tbl,const uint8_t * key,uint32_t len)421 trie_val_t* trie_get_try(trie_t* tbl, const uint8_t* key, uint32_t len)
422 {
423     struct found found = find_equal(tbl, key, len);
424     return found.l ? &found.l->val : NULL;
425 }
426 
trie_get_first(trie_t * tbl,uint8_t ** key,uint32_t * len)427 trie_val_t* trie_get_first(trie_t* tbl, uint8_t** key, uint32_t* len)
428 {
429     struct found found = find_first(tbl);
430     if (!found.l)
431         return NULL;
432     if (key)
433         *key = found.l->key->chars;
434     if (len)
435         *len = found.l->key->len;
436     return &found.l->val;
437 }
438 
/*!
 * \brief Stack of nodes, storing a path down a trie.
 *
 * The structure also serves directly as the public trie_it_t type,
 * in which case it always points to the current leaf, unless we've finished
 * (i.e. it->len == 0).
 *
 * \a stack points at \a stack_init until more room is needed; only then is
 * it replaced by a heap block.
 */
typedef struct trie_it {
    node_t** stack; /*!< The stack; malloc is used directly instead of mm. */
    uint32_t len; /*!< Current length of the stack. */
    uint32_t alen; /*!< Allocated/available length of the stack. */
    /*! \brief Initial storage for \a stack; it should fit in many use cases. */
    node_t* stack_init[60];
} nstack_t;
453 
454 /*! \brief Create a node stack containing just the root (or empty). */
ns_init(nstack_t * ns,trie_t * tbl)455 static void ns_init(nstack_t* ns, trie_t* tbl)
456 {
457     assert(tbl);
458     ns->stack = ns->stack_init;
459     ns->alen  = sizeof(ns->stack_init) / sizeof(ns->stack_init[0]);
460     if (tbl->weight) {
461         ns->len      = 1;
462         ns->stack[0] = &tbl->root;
463     } else {
464         ns->len = 0;
465     }
466 }
467 
468 /*! \brief Free inside of the stack, i.e. not the passed pointer itself. */
ns_cleanup(nstack_t * ns)469 static void ns_cleanup(nstack_t* ns)
470 {
471     assert(ns && ns->stack);
472     if (likely(ns->stack == ns->stack_init))
473         return;
474     free(ns->stack);
475 #ifndef NDEBUG
476     ns->stack = NULL;
477     ns->alen  = 0;
478 #endif
479 }
480 
481 /*! \brief Allocate more space for the stack. */
ns_longer_alloc(nstack_t * ns)482 static int ns_longer_alloc(nstack_t* ns)
483 {
484     ns->alen *= 2;
485     size_t   new_size = sizeof(nstack_t) + ns->alen * sizeof(node_t*);
486     node_t** st;
487     if (ns->stack == ns->stack_init) {
488         st = malloc(new_size);
489         if (st != NULL)
490             memcpy(st, ns->stack, ns->len * sizeof(node_t*));
491     } else {
492         st = realloc(ns->stack, new_size);
493     }
494     if (st == NULL)
495         return KNOT_ENOMEM;
496     ns->stack = st;
497     return KNOT_EOK;
498 }
499 
500 /*! \brief Ensure the node stack can be extended by one. */
ns_longer(nstack_t * ns)501 static inline int ns_longer(nstack_t* ns)
502 {
503     // get a longer stack if needed
504     if (likely(ns->len < ns->alen))
505         return KNOT_EOK;
506     return ns_longer_alloc(ns); // hand-split the part suitable for inlining
507 }
508 
/*!
 * \brief Find the "branching point" as if searching for a key.
 *
 *  The whole path to the point is kept on the passed stack;
 *  always at least the root will remain on the top of it.
 *  Beware: the precise semantics of this function is rather tricky.
 *  The top of the stack will contain: the corresponding leaf if exact match is found;
 *  or the immediate node below a branching-point-on-edge or the branching-point itself.
 *
 *  \param ns     Stack holding at least one node; extended and shortened here.
 *  \param key    The searched key.
 *  \param len    Length of the searched key.
 *  \param info   Set position of the point of first mismatch (in index and flags);
 *                flags == 0 signals that an equal key was found.
 *  \param first  Set the value of the first non-matching character (from trie),
 *                optionally; end-of-string character has value -256 (that's why it's int).
 *                Note: the character is converted to *unsigned* char (i.e. 0..255),
 *                as that's the ordering used in the trie.
 *
 *  \return KNOT_EOK or KNOT_ENOMEM.
 */
static int ns_find_branch(nstack_t* ns, const uint8_t* key, uint32_t len,
    branch_t* info, int* first)
{
    assert(ns && ns->len && info);
    // First find some leaf with longest matching prefix.
    while (isbranch(ns->stack[ns->len - 1])) {
        ERR_RETURN(ns_longer(ns));
        node_t* t = ns->stack[ns->len - 1];
        __builtin_prefetch(t->branch.twigs);
        bitmap_t b = twigbit(t, key, len);
        // Even if our key is missing from this branch we need to
        // keep iterating down to a leaf. It doesn't matter which
        // twig we choose since the keys are all the same up to this
        // index. Note that blindly using twigoff(t, b) can cause
        // an out-of-bounds index if it equals twigmax(t).
        uint i               = hastwig(t, b) ? twigoff(t, b) : 0;
        ns->stack[ns->len++] = twig(t, i);
    }
    tkey_t* lkey = ns->stack[ns->len - 1]->leaf.key;
    // Find index of the first char that differs.
    uint32_t index = 0;
    while (index < MIN(len, lkey->len)) {
        if (key[index] != lkey->chars[index])
            break;
        else
            ++index;
    }
    info->index = index;
    if (first)
        *first = lkey->len > index ? (unsigned char)lkey->chars[index] : -256;
    // Find flags: which half-byte has matched.
    uint flags;
    if (index == len && len == lkey->len) { // found equivalent key
        info->flags = flags = 0;
        goto success;
    }
    if (likely(index < MIN(len, lkey->len))) {
        // Both keys have a byte here: a difference in the high nibble gives 1, otherwise 2.
        byte k2 = (byte)lkey->chars[index];
        byte k1 = (byte)key[index];
        flags   = ((k1 ^ k2) & 0xf0) ? 1 : 2;
    } else { // one is prefix of another
        flags = 1;
    }
    info->flags = flags;
    // now go up the trie from the current leaf
    branch_t* t;
    do {
        if (unlikely(ns->len == 1))
            goto success; // only the root stays on the stack
        t = (branch_t*)ns->stack[ns->len - 2];
        // Stop once the parent tests a position strictly before the mismatch.
        if (t->index < index || (t->index == index && t->flags < flags))
            goto success;
        --ns->len;
    } while (true);
success:
#ifndef NDEBUG // invariants on successful return
    assert(ns->len);
    if (isbranch(ns->stack[ns->len - 1])) {
        t = &ns->stack[ns->len - 1]->branch;
        assert(t->index > index || (t->index == index && t->flags >= flags));
    }
    if (ns->len > 1) {
        t = &ns->stack[ns->len - 2]->branch;
        assert(t->index < index || (t->index == index && (t->flags < flags || (t->flags == 1 && flags == 0))));
    }
#endif
    return KNOT_EOK;
}
594 
595 /*!
596  * \brief Advance the node stack to the last leaf in the subtree.
597  *
598  * \return KNOT_EOK or KNOT_ENOMEM.
599  */
ns_last_leaf(nstack_t * ns)600 static int ns_last_leaf(nstack_t* ns)
601 {
602     assert(ns);
603     do {
604         ERR_RETURN(ns_longer(ns));
605         node_t* t = ns->stack[ns->len - 1];
606         if (!isbranch(t))
607             return KNOT_EOK;
608         int lasti = bitmap_weight(t->branch.bitmap) - 1;
609         assert(lasti >= 0);
610         ns->stack[ns->len++] = twig(t, lasti);
611     } while (true);
612 }
613 
614 /*!
615  * \brief Advance the node stack to the first leaf in the subtree.
616  *
617  * \return KNOT_EOK or KNOT_ENOMEM.
618  */
ns_first_leaf(nstack_t * ns)619 static int ns_first_leaf(nstack_t* ns)
620 {
621     assert(ns && ns->len);
622     do {
623         ERR_RETURN(ns_longer(ns));
624         node_t* t = ns->stack[ns->len - 1];
625         if (!isbranch(t))
626             return KNOT_EOK;
627         ns->stack[ns->len++] = twig(t, 0);
628     } while (true);
629 }
630 
631 /*!
632  * \brief Advance the node stack to the leaf that is previous to the current node.
633  *
634  * \note Prefix leaf under the current node DOES count (if present; perhaps questionable).
635  * \return KNOT_EOK on success, KNOT_ENOENT on not-found, or possibly KNOT_ENOMEM.
636  */
ns_prev_leaf(nstack_t * ns)637 static int ns_prev_leaf(nstack_t* ns)
638 {
639     assert(ns && ns->len > 0);
640 
641     node_t* t = ns->stack[ns->len - 1];
642     if (hastwig(t, 1 << 0)) { // the prefix leaf
643         t = twig(t, 0);
644         ERR_RETURN(ns_longer(ns));
645         ns->stack[ns->len++] = t;
646         return KNOT_EOK;
647     }
648 
649     do {
650         if (ns->len < 2)
651             return KNOT_ENOENT; // root without empty key has no previous leaf
652         t              = ns->stack[ns->len - 1];
653         node_t* p      = ns->stack[ns->len - 2];
654         int     pindex = t - p->branch.twigs; // index in parent via pointer arithmetic
655         assert(pindex >= 0 && pindex <= 16);
656         if (pindex > 0) { // t isn't the first child -> go down the previous one
657             ns->stack[ns->len - 1] = twig(p, pindex - 1);
658             return ns_last_leaf(ns);
659         }
660         // we've got to go up again
661         --ns->len;
662     } while (true);
663 }
664 
665 /*!
666  * \brief Advance the node stack to the leaf that is successor to the current node.
667  *
668  * \note Prefix leaf or anything else under the current node DOES count.
669  * \return KNOT_EOK on success, KNOT_ENOENT on not-found, or possibly KNOT_ENOMEM.
670  */
ns_next_leaf(nstack_t * ns)671 static int ns_next_leaf(nstack_t* ns)
672 {
673     assert(ns && ns->len > 0);
674 
675     node_t* t = ns->stack[ns->len - 1];
676     if (isbranch(t))
677         return ns_first_leaf(ns);
678     do {
679         if (ns->len < 2)
680             return KNOT_ENOENT; // not found, as no more parent is available
681         t              = ns->stack[ns->len - 1];
682         node_t* p      = ns->stack[ns->len - 2];
683         int     pindex = t - p->branch.twigs; // index in parent via pointer arithmetic
684         assert(pindex >= 0 && pindex <= 16);
685         int pcount = bitmap_weight(p->branch.bitmap);
686         if (pindex + 1 < pcount) { // t isn't the last child -> go down the next one
687             ns->stack[ns->len - 1] = twig(p, pindex + 1);
688             return ns_first_leaf(ns);
689         }
690         // we've got to go up again
691         --ns->len;
692     } while (true);
693 }
694 
/*!
 * \brief Search the trie for the given key or, failing that, its closest predecessor.
 *
 * \param tbl  The trie.
 * \param key  The searched key.
 * \param len  Length of the key.
 * \param val  Output: pointer to the value that was found; set to NULL on
 *             any failure.
 *
 * \return KNOT_EOK for an exact match, 1 when a preceding key was used instead,
 *         KNOT_ENOENT when no key is less than or equal, or KNOT_ENOMEM.
 */
int trie_get_leq(trie_t* tbl, const uint8_t* key, uint32_t len, trie_val_t** val)
{
    assert(tbl && val);
    *val = NULL; // so on failure we can just return;
    if (tbl->weight == 0)
        return KNOT_ENOENT;
    { // Intentionally un-indented; until end of function, to bound cleanup attr.
        // First find a key with longest-matching prefix
        __attribute__((cleanup(ns_cleanup)))
        nstack_t ns_local;
        ns_init(&ns_local, tbl);
        nstack_t* ns = &ns_local;
        branch_t  bp;
        int       un_leaf; // first unmatched character in the leaf
        ERR_RETURN(ns_find_branch(ns, key, len, &bp, &un_leaf));
        // first unmatched character in the key, using the same -256 end-of-string marker
        int     un_key = bp.index < len ? (unsigned char)key[bp.index] : -256;
        node_t* t      = ns->stack[ns->len - 1];
        if (bp.flags == 0) { // found exact match
            *val = &t->leaf.val;
            return KNOT_EOK;
        }
        // Get t: the last node on matching path
        if (isbranch(t) && t->branch.index == bp.index && t->branch.flags == bp.flags) {
            // t is OK
        } else {
            // the top of the stack was the first unmatched node -> step up
            if (ns->len == 1) {
                // root was unmatched already
                if (un_key < un_leaf)
                    return KNOT_ENOENT;
                ERR_RETURN(ns_last_leaf(ns));
                goto success;
            }
            --ns->len;
            t = ns->stack[ns->len - 1];
        }
        // Now we re-do the first "non-matching" step in the trie
        // but try the previous child if key was less (it may not exist)
        bitmap_t b = twigbit(t, key, len);
        int      i = hastwig(t, b)
                    ? twigoff(t, b) - (un_key < un_leaf)
                    : twigoff(t, b) - 1 /*twigoff returns successor when !hastwig*/;
        if (i >= 0) {
            ERR_RETURN(ns_longer(ns));
            ns->stack[ns->len++] = twig(t, i);
            ERR_RETURN(ns_last_leaf(ns));
        } else {
            ERR_RETURN(ns_prev_leaf(ns));
        }
    success:
        assert(!isbranch(ns->stack[ns->len - 1]));
        *val = &ns->stack[ns->len - 1]->leaf.val;
        return 1; // a predecessor was found, not an exact match
    }
}
750 
751 /*! \brief Initialize a new leaf, copying the key, and returning failure code. */
mk_leaf(node_t * leaf,const uint8_t * key,uint32_t len,knot_mm_t * mm)752 static int mk_leaf(node_t* leaf, const uint8_t* key, uint32_t len, knot_mm_t* mm)
753 {
754     tkey_t* k = mm_alloc(mm, sizeof(tkey_t) + len);
755 #if FLAGS_HACK
756     assert(((uintptr_t)k) % 4 == 0); // we need an aligned pointer
757 #endif
758     if (unlikely(!k))
759         return KNOT_ENOMEM;
760     k->len = len;
761     memcpy(k->chars, key, len);
762     leaf->leaf = (leaf_t)
763     {
764 #if !FLAGS_HACK
765         .flags = 0,
766 #endif
767         .val = NULL,
768         .key = k
769     };
770     return KNOT_EOK;
771 }
772 
/*!
 * \brief Search the trie for \a key, inserting it if it is not present yet.
 *
 * \param tbl  The trie.
 * \param key  The key; it gets copied on insertion.
 * \param len  Length of the key.
 *
 * \return Pointer to the value stored under the key (NULL-valued for a new
 *         item), or NULL if an allocation failed.
 */
trie_val_t* trie_get_ins(trie_t* tbl, const uint8_t* key, uint32_t len)
{
    assert(tbl);
    // First leaf in an empty tbl?
    if (unlikely(!tbl->weight)) {
        if (unlikely(mk_leaf(&tbl->root, key, len, &tbl->mm)))
            return NULL;
        ++tbl->weight;
        return &tbl->root.leaf.val;
    }
    { // Intentionally un-indented; until end of function, to bound cleanup attr.
        // Find the branching-point
        __attribute__((cleanup(ns_cleanup)))
        nstack_t ns_local;
        ns_init(&ns_local, tbl);
        nstack_t* ns = &ns_local;
        branch_t  bp; // branch-point: index and flags signifying the longest common prefix
        int       k2; // the first unmatched character in the leaf
        if (unlikely(ns_find_branch(ns, key, len, &bp, &k2)))
            return NULL;
        node_t* t = ns->stack[ns->len - 1];
        if (bp.flags == 0) // the same key was already present
            return &t->leaf.val;
        node_t leaf;
        if (unlikely(mk_leaf(&leaf, key, len, &tbl->mm)))
            return NULL;

        if (isbranch(t) && bp.index == t->branch.index && bp.flags == t->branch.flags) {
            // The node t needs a new leaf child.
            bitmap_t b1 = twigbit(t, key, len);
            assert(!hastwig(t, b1));
            uint s, m;
            TWIGOFFMAX(s, m, t, b1); // new child position and original child count
            node_t* twigs = mm_realloc(&tbl->mm, t->branch.twigs,
                sizeof(node_t) * (m + 1), sizeof(node_t) * m);
            if (unlikely(!twigs))
                goto err_leaf;
            // Shift the children from position s on by one slot to make room.
            memmove(twigs + s + 1, twigs + s, sizeof(node_t) * (m - s));
            twigs[s]        = leaf;
            t->branch.twigs = twigs;
            t->branch.bitmap |= b1;
            ++tbl->weight;
            return &twigs[s].leaf.val;
        } else {
            // We need to insert a new binary branch with leaf at *t.
            // Note: it works the same for the case where we insert above root t.
#ifndef NDEBUG
            if (ns->len > 1) {
                node_t* pt = ns->stack[ns->len - 2];
                assert(hastwig(pt, twigbit(pt, key, len)));
            }
#endif
            node_t* twigs = mm_alloc(&tbl->mm, sizeof(node_t) * 2);
            if (unlikely(!twigs))
                goto err_leaf;
            node_t t2                = *t; // Save before overwriting t.
            t->branch.flags          = bp.flags;
            t->branch.index          = bp.index;
            t->branch.twigs          = twigs;
            // b1 marks the new key, b2 the pre-existing subtree (k2 == -256 means end-of-string).
            bitmap_t b1              = twigbit(t, key, len);
            bitmap_t b2              = unlikely(k2 == -256) ? (1 << 0) : nibbit(k2, bp.flags);
            t->branch.bitmap         = b1 | b2;
            *twig(t, twigoff(t, b1)) = leaf;
            *twig(t, twigoff(t, b2)) = t2;
            ++tbl->weight;
            return &twig(t, twigoff(t, b1))->leaf.val;
        };
    err_leaf:
        // Undo mk_leaf(): the key copy is the only thing allocated so far.
        mm_free(&tbl->mm, leaf.leaf.key);
        return NULL;
    }
}
845 
846 /*! \brief Apply a function to every trie_val_t*, in order; a recursive solution. */
apply_trie(node_t * t,int (* f)(trie_val_t *,void *),void * d)847 static int apply_trie(node_t* t, int (*f)(trie_val_t*, void*), void* d)
848 {
849     assert(t);
850     if (!isbranch(t))
851         return f(&t->leaf.val, d);
852     int child_count = bitmap_weight(t->branch.bitmap);
853     int i;
854     for (i = 0; i < child_count; ++i)
855         ERR_RETURN(apply_trie(twig(t, i), f, d));
856     return KNOT_EOK;
857 }
858 
/*!
 * \brief Call \a f(value, d) on every value in the trie, in key order.
 *
 * \return KNOT_EOK (also for an empty trie), or the first non-KNOT_EOK value
 *         returned by \a f.
 */
int trie_apply(trie_t* tbl, int (*f)(trie_val_t*, void*), void* d)
{
    assert(tbl && f);
    return tbl->weight ? apply_trie(&tbl->root, f, d) : KNOT_EOK;
}
866 
867 /* These are all thin wrappers around static Tns* functions. */
trie_it_begin(trie_t * tbl)868 trie_it_t* trie_it_begin(trie_t* tbl)
869 {
870     assert(tbl);
871     trie_it_t* it = malloc(sizeof(nstack_t));
872     if (!it)
873         return NULL;
874     ns_init(it, tbl);
875     if (it->len == 0) // empty tbl
876         return it;
877     if (ns_first_leaf(it)) {
878         ns_cleanup(it);
879         free(it);
880         return NULL;
881     }
882     return it;
883 }
884 
/*!
 * \brief Move the iterator to the next leaf.
 *
 * If there is no next leaf (or advancing fails), the iterator becomes finished.
 */
void trie_it_next(trie_it_t* it)
{
    assert(it && it->len);
    const int ret = ns_next_leaf(it);
    if (ret != KNOT_EOK)
        it->len = 0;
}
891 
/*! \brief Tell whether the iterator has run past the last leaf (its stack is empty). */
bool trie_it_finished(trie_it_t* it)
{
    assert(it);
    return !it->len;
}
897 
/*! \brief Release the iterator including its stack; NULL is accepted and ignored. */
void trie_it_free(trie_it_t* it)
{
    if (it != NULL) {
        ns_cleanup(it);
        free(it);
    }
}
905 
/*!
 * \brief Return the key of the leaf the iterator points to.
 *
 * \param it   Iterator that is not finished.
 * \param len  Optional output: length of the key.
 *
 * \return Pointer to the key bytes kept inside the trie.
 */
const uint8_t* trie_it_key(trie_it_t* it, size_t* len)
{
    assert(it && it->len);
    node_t* top = it->stack[it->len - 1];
    assert(!isbranch(top));

    tkey_t* stored = top->leaf.key;
    if (len != NULL)
        *len = stored->len;
    return stored->chars;
}
916 
trie_it_val(trie_it_t * it)917 trie_val_t* trie_it_val(trie_it_t* it)
918 {
919     assert(it && it->len);
920     node_t* t = it->stack[it->len - 1];
921     assert(!isbranch(t));
922     return &t->leaf.val;
923 }
924