1 /*
2  * Copyright (c) 2015 Andrew Kelley
3  *
4  * This file is part of zig, which is MIT licensed.
5  * See http://opensource.org/licenses/MIT
6  */
7 
8 #ifndef ZIG_HASH_MAP_HPP
9 #define ZIG_HASH_MAP_HPP
10 
11 #include "util.hpp"
12 
13 #include <stdint.h>
14 
template<typename K>
struct MakePointer {
    // Non-pointer keys are handed to the hash/equality callbacks by
    // address, as a pointer to const, so large keys are never copied.
    using Type = const K *;
    static Type convert(const K &val) { return &val; }
};
22 
23 template<typename K>
24 struct MakePointer<K*> {
25     typedef K *Type;
convertMakePointer26     static Type convert(K * const &val) {
27         return val;
28     }
29 };
30 
31 template<typename K>
32 struct MakePointer<K const *> {
33     typedef K const *Type;
convertMakePointer34     static Type convert(K const * const &val) {
35         return val;
36     }
37 };
38 
// Open-addressed hash map using Robin Hood probing.
//
// Entries live in a flat array (`_entries`) in insertion order. A separate
// probe table (`_index_bytes`) stores 1-based indexes into that array
// (0 marks an empty slot), using the smallest integer width that can hold
// them (see capacity_index_size). Maps with fewer than 16 entries have no
// probe table at all and fall back to a linear scan.
//
// Template parameters:
//   K            - key type; passed to the callbacks via MakePointer<K>::Type.
//   V            - value type, stored by value in each Entry.
//   HashFunction - computes the 32-bit hash of a key.
//   EqualFn      - compares two keys for equality.
template<typename K, typename V,
    uint32_t (*HashFunction)(typename MakePointer<K>::Type key),
    bool (*EqualFn)(typename MakePointer<K>::Type a, typename MakePointer<K>::Type b)>
class HashMap {
public:
    // Prepares the map, pre-allocating room for `capacity` entries.
    // Must be called before any other member function.
    void init(int capacity) {
        init_capacity(capacity);
    }
    // Frees the entry array and the probe table. The map must not be used
    // again afterwards (unless re-init()ed).
    void deinit(void) {
        _entries.deinit();
        // NOTE(review): on the small-map path _index_bytes is nullptr and
        // _indexes_len is 0 -- assumes deallocate accepts that; confirm.
        heap::c_allocator.deallocate(_index_bytes,
                _indexes_len * capacity_index_size(_indexes_len));
    }

    struct Entry {
        uint32_t hash;                      // cached HashFunction(key)
        uint32_t distance_from_start_index; // probe distance from the hash's home slot
        K key;
        V value;
    };

    // Removes every entry while keeping the allocated capacity.
    void clear() {
        _entries.clear();
        // Zeroing the probe table marks every slot empty (0 == no entry).
        memset(_index_bytes, 0, _indexes_len * capacity_index_size(_indexes_len));
        _max_distance_from_start_index = 0;
        _modification_count += 1;
    }

    // Number of entries currently stored.
    size_t size() const {
        return _entries.length;
    }

    // Inserts key/value, overwriting the value if the key already exists.
    void put(const K &key, const V &value) {
        _modification_count += 1;

        // This allows us to take a pointer to an entry in `internal_put` which
        // will not become a dead pointer when the array list is appended.
        _entries.ensure_capacity(_entries.length + 1);

        if (_index_bytes == nullptr) {
            if (_entries.length < 16) {
                // Small-map path: just append; lookups are linear scans.
                _entries.append({HashFunction(MakePointer<K>::convert(key)), 0, key, value});
                return;
            } else {
                // Crossing the small-map threshold: build the first probe
                // table (32 one-byte slots) from the existing entries.
                _indexes_len = 32;
                _index_bytes = heap::c_allocator.allocate<uint8_t>(_indexes_len);
                _max_distance_from_start_index = 0;
                for (size_t i = 0; i < _entries.length; i += 1) {
                    Entry *entry = &_entries.items[i];
                    put_index(entry, i, _index_bytes);
                }
                return internal_put(key, value, _index_bytes);
            }
        }

        // if we would get too full (60%), double the indexes size
        if ((_entries.length + 1) * 5 >= _indexes_len * 3) {
            // The old probe table can be discarded outright; it is rebuilt
            // below from _entries, which owns all the data.
            heap::c_allocator.deallocate(_index_bytes,
                    _indexes_len * capacity_index_size(_indexes_len));
            _indexes_len *= 2;
            size_t sz = capacity_index_size(_indexes_len);
            // This zero initializes the bytes, setting them all empty.
            _index_bytes = heap::c_allocator.allocate<uint8_t>(_indexes_len * sz);
            _max_distance_from_start_index = 0;
            for (size_t i = 0; i < _entries.length; i += 1) {
                Entry *entry = &_entries.items[i];
                // Dispatch on the element width chosen for the new size.
                switch (sz) {
                    case 1:
                        put_index(entry, i, (uint8_t*)_index_bytes);
                        continue;
                    case 2:
                        put_index(entry, i, (uint16_t*)_index_bytes);
                        continue;
                    case 4:
                        put_index(entry, i, (uint32_t*)_index_bytes);
                        continue;
                    default:
                        put_index(entry, i, (size_t*)_index_bytes);
                        continue;
                }
            }
        }

        // Dispatch the insertion on the probe table's element width.
        switch (capacity_index_size(_indexes_len)) {
            case 1: return internal_put(key, value, (uint8_t*)_index_bytes);
            case 2: return internal_put(key, value, (uint16_t*)_index_bytes);
            case 4: return internal_put(key, value, (uint32_t*)_index_bytes);
            default: return internal_put(key, value, (size_t*)_index_bytes);
        }
    }

    // If `key` exists, returns its entry untouched (no overwrite);
    // otherwise inserts key/value and returns nullptr.
    Entry *put_unique(const K &key, const V &value) {
        // TODO make this more efficient
        Entry *entry = internal_get(key);
        if (entry)
            return entry;
        put(key, value);
        return nullptr;
    }

    // Returns the value for `key`; panics if the key is not present.
    const V &get(const K &key) const {
        Entry *entry = internal_get(key);
        if (!entry)
            zig_panic("key not found");
        return entry->value;
    }

    // Returns the entry for `key`, or nullptr if not present.
    Entry *maybe_get(const K &key) const {
        return internal_get(key);
    }

    // Removes `key`; panics if it is not present. Always returns true.
    bool remove(const K &key) {
        bool deleted_something = maybe_remove(key);
        if (!deleted_something)
            zig_panic("key not found");
        return deleted_something;
    }

    // Removes `key` if present; returns whether anything was removed.
    bool maybe_remove(const K &key) {
        _modification_count += 1;
        if (_index_bytes == nullptr) {
            // Small-map path: linear scan, then O(1) removal by swapping
            // the last entry into the vacated slot.
            uint32_t hash = HashFunction(MakePointer<K>::convert(key));
            for (size_t i = 0; i < _entries.length; i += 1) {
                if (_entries.items[i].hash == hash && EqualFn(MakePointer<K>::convert(_entries.items[i].key), MakePointer<K>::convert(key))) {
                    _entries.swap_remove(i);
                    return true;
                }
            }
            return false;
        }
        switch (capacity_index_size(_indexes_len)) {
            case 1: return internal_remove(key, (uint8_t*)_index_bytes);
            case 2: return internal_remove(key, (uint16_t*)_index_bytes);
            case 4: return internal_remove(key, (uint32_t*)_index_bytes);
            default: return internal_remove(key, (size_t*)_index_bytes);
        }
    }

    // Walks entries in insertion order; panics if the map is modified
    // while iteration is in progress.
    class Iterator {
    public:
        // Returns the next entry, or nullptr when iteration is finished.
        Entry *next() {
            if (_inital_modification_count != _table->_modification_count)
                zig_panic("concurrent modification");
            if (_index >= _table->_entries.length)
                return nullptr;
            Entry *entry = &_table->_entries.items[_index];
            _index += 1;
            return entry;
        }
    private:
        const HashMap * _table;
        // iterator through the entry array
        size_t _index = 0;
        // used to detect concurrent modification
        uint32_t _inital_modification_count;
        Iterator(const HashMap * table) :
                _table(table), _inital_modification_count(table->_modification_count) {
        }
        friend HashMap;
    };

    // you must not modify the underlying HashMap while this iterator is still in use
    Iterator entry_iterator() const {
        return Iterator(this);
    }

private:
    // Maintains insertion order.
    ZigList<Entry> _entries;
    // If _indexes_len is less than 2**8, this is an array of uint8_t.
    // If _indexes_len is less than 2**16, it is an array of uint16_t.
    // If _indexes_len is less than 2**32, it is an array of uint32_t.
    // Otherwise it is size_t.
    // It's off by 1. 0 means empty slot, 1 means index 0, etc.
    uint8_t *_index_bytes;
    // This is the number of indexes. When indexes are bytes, it equals number of bytes.
    // When indexes are uint16_t, _indexes_len is half the number of bytes.
    size_t _indexes_len;

    // Longest probe distance of any stored entry; bounds lookup/remove scans.
    size_t _max_distance_from_start_index;
    // This is used to detect bugs where a hashtable is edited while an iterator is running.
    uint32_t _modification_count;

    // Shared implementation of init(): allocates the entry array and, for
    // capacities of 16 or more, a zero-initialized probe table.
    void init_capacity(size_t capacity) {
        _entries = {};
        _entries.ensure_capacity(capacity);
        _indexes_len = 0;
        if (capacity >= 16) {
            // So that at capacity it will only be 60% full.
            _indexes_len = capacity * 5 / 3;
            size_t sz = capacity_index_size(_indexes_len);
            // This zero initializes _index_bytes which sets them all to empty.
            _index_bytes = heap::c_allocator.allocate<uint8_t>(_indexes_len * sz);
        } else {
            // Small-map path: no probe table until 16 entries accumulate.
            _index_bytes = nullptr;
        }

        _max_distance_from_start_index = 0;
        _modification_count = 0;
    }

    // Width in bytes of one probe-table element: the smallest unsigned
    // type that can hold any 1-based entry index for `len` slots.
    static size_t capacity_index_size(size_t len) {
        if (len < UINT8_MAX)
            return 1;
        if (len < UINT16_MAX)
            return 2;
        if (len < UINT32_MAX)
            return 4;
        return sizeof(size_t);
    }

    // Robin Hood insertion: probe forward from the hash's home slot; claim
    // the first empty slot, overwrite on key match, or displace an entry
    // that sits closer to its own home slot and keep re-placing it.
    // Precondition: put() already called _entries.ensure_capacity.
    template <typename I>
    void internal_put(const K &key, const V &value, I *indexes) {
        uint32_t hash = HashFunction(MakePointer<K>::convert(key));
        uint32_t distance_from_start_index = 0;
        size_t start_index = hash_to_index(hash);
        for (size_t roll_over = 0; roll_over < _indexes_len;
                roll_over += 1, distance_from_start_index += 1)
        {
            size_t index_index = (start_index + roll_over) % _indexes_len;
            I index_data = indexes[index_index];
            if (index_data == 0) {
                // Empty slot: append the entry and point the slot at it
                // (stored index is 1-based; 0 means empty).
                _entries.append_assuming_capacity({ hash, distance_from_start_index, key, value });
                indexes[index_index] = _entries.length;
                if (distance_from_start_index > _max_distance_from_start_index)
                    _max_distance_from_start_index = distance_from_start_index;
                return;
            }
            // This pointer survives the following append because we call
            // _entries.ensure_capacity before internal_put.
            Entry *entry = &_entries.items[index_data - 1];
            if (entry->hash == hash && EqualFn(MakePointer<K>::convert(entry->key), MakePointer<K>::convert(key))) {
                // Key already present: overwrite the entry in place.
                *entry = {hash, distance_from_start_index, key, value};
                if (distance_from_start_index > _max_distance_from_start_index)
                    _max_distance_from_start_index = distance_from_start_index;
                return;
            }
            if (entry->distance_from_start_index < distance_from_start_index) {
                // In this case, we did not find the item. We will put a new entry.
                // However, we will use this index for the new entry, and move
                // the previous index down the line, to keep the _max_distance_from_start_index
                // as small as possible.
                _entries.append_assuming_capacity({ hash, distance_from_start_index, key, value });
                indexes[index_index] = _entries.length;
                if (distance_from_start_index > _max_distance_from_start_index)
                    _max_distance_from_start_index = distance_from_start_index;

                distance_from_start_index = entry->distance_from_start_index;

                // Find somewhere to put the index we replaced by shifting
                // following indexes backwards.
                roll_over += 1;
                distance_from_start_index += 1;
                for (; roll_over < _indexes_len; roll_over += 1, distance_from_start_index += 1) {
                    size_t index_index = (start_index + roll_over) % _indexes_len;
                    I next_index_data = indexes[index_index];
                    if (next_index_data == 0) {
                        // Displaced index lands in this empty slot; done.
                        if (distance_from_start_index > _max_distance_from_start_index)
                            _max_distance_from_start_index = distance_from_start_index;
                        entry->distance_from_start_index = distance_from_start_index;
                        indexes[index_index] = index_data;
                        return;
                    }
                    Entry *next_entry = &_entries.items[next_index_data - 1];
                    if (next_entry->distance_from_start_index < distance_from_start_index) {
                        // Swap: the displaced index takes this slot and we
                        // continue re-homing this slot's previous occupant.
                        if (distance_from_start_index > _max_distance_from_start_index)
                            _max_distance_from_start_index = distance_from_start_index;
                        entry->distance_from_start_index = distance_from_start_index;
                        indexes[index_index] = index_data;
                        distance_from_start_index = next_entry->distance_from_start_index;
                        entry = next_entry;
                        index_data = next_index_data;
                    }
                }
                zig_unreachable();
            }
        }
        zig_unreachable();
    }

    // Inserts a probe-table slot for an entry that already lives in
    // _entries at `entry_index` (used while (re)building the table).
    // Same Robin Hood displacement as internal_put, but with no
    // key-equality check since all entries are known to be distinct.
    template <typename I>
    void put_index(Entry *entry, size_t entry_index, I *indexes) {
        size_t start_index = hash_to_index(entry->hash);
        size_t index_data = entry_index + 1; // stored 1-based; 0 means empty
        for (size_t roll_over = 0, distance_from_start_index = 0;
                roll_over < _indexes_len; roll_over += 1, distance_from_start_index += 1)
        {
            size_t index_index = (start_index + roll_over) % _indexes_len;
            size_t next_index_data = indexes[index_index];
            if (next_index_data == 0) {
                // Empty slot: place the index here.
                if (distance_from_start_index > _max_distance_from_start_index)
                    _max_distance_from_start_index = distance_from_start_index;
                entry->distance_from_start_index = distance_from_start_index;
                indexes[index_index] = index_data;
                return;
            }
            Entry *next_entry = &_entries.items[next_index_data - 1];
            if (next_entry->distance_from_start_index < distance_from_start_index) {
                // Displace the richer occupant and keep re-homing it.
                if (distance_from_start_index > _max_distance_from_start_index)
                    _max_distance_from_start_index = distance_from_start_index;
                entry->distance_from_start_index = distance_from_start_index;
                indexes[index_index] = index_data;
                distance_from_start_index = next_entry->distance_from_start_index;
                entry = next_entry;
                index_data = next_index_data;
            }
        }
        zig_unreachable();
    }

    // Lookup: linear scan on the small-map path, otherwise dispatch on
    // the probe table's element width. Returns nullptr if absent.
    Entry *internal_get(const K &key) const {
        if (_index_bytes == nullptr) {
            uint32_t hash = HashFunction(MakePointer<K>::convert(key));
            for (size_t i = 0; i < _entries.length; i += 1) {
                if (_entries.items[i].hash == hash && EqualFn(MakePointer<K>::convert(_entries.items[i].key), MakePointer<K>::convert(key))) {
                    return &_entries.items[i];
                }
            }
            return nullptr;
        }
        switch (capacity_index_size(_indexes_len)) {
            case 1: return internal_get2(key, (uint8_t*)_index_bytes);
            case 2: return internal_get2(key, (uint16_t*)_index_bytes);
            case 4: return internal_get2(key, (uint32_t*)_index_bytes);
            default: return internal_get2(key, (size_t*)_index_bytes);
        }
    }

    // Probe-table lookup: scan from the home slot, stopping at the first
    // empty slot or after _max_distance_from_start_index probes (no entry
    // is ever stored further from home than that).
    template <typename I>
    Entry *internal_get2(const K &key, I *indexes) const {
        uint32_t hash = HashFunction(MakePointer<K>::convert(key));
        size_t start_index = hash_to_index(hash);
        for (size_t roll_over = 0; roll_over <= _max_distance_from_start_index; roll_over += 1) {
            size_t index_index = (start_index + roll_over) % _indexes_len;
            size_t index_data = indexes[index_index];
            if (index_data == 0)
                return nullptr;

            Entry *entry = &_entries.items[index_data - 1];
            if (entry->hash == hash && EqualFn(MakePointer<K>::convert(entry->key), MakePointer<K>::convert(key)))
                return entry;
        }
        return nullptr;
    }

    // Maps a hash to its home slot in the probe table.
    size_t hash_to_index(uint32_t hash) const {
        return ((size_t)hash) % _indexes_len;
    }

    // Probe-table removal with backward-shift deletion: after unlinking
    // the entry, following indexes are shifted one slot toward their home
    // slots until an empty slot or a distance-0 entry is reached, so no
    // tombstones are needed.
    template <typename I>
    bool internal_remove(const K &key, I *indexes) {
        uint32_t hash = HashFunction(MakePointer<K>::convert(key));
        size_t start_index = hash_to_index(hash);
        for (size_t roll_over = 0; roll_over <= _max_distance_from_start_index; roll_over += 1) {
            size_t index_index = (start_index + roll_over) % _indexes_len;
            size_t index_data = indexes[index_index];
            if (index_data == 0)
                return false;

            size_t index = index_data - 1;
            Entry *entry = &_entries.items[index];
            if (entry->hash != hash || !EqualFn(MakePointer<K>::convert(entry->key), MakePointer<K>::convert(key)))
                continue;

            size_t prev_index = index_index;
            _entries.swap_remove(index);
            if (_entries.length > 0 && _entries.length != index) {
                // Because of the swap remove, now we need to update the index that was
                // pointing to the last entry and is now pointing to this removed item slot.
                update_entry_index(_entries.length, index, indexes);
            }

            // Now we have to shift over the following indexes.
            roll_over += 1;
            for (; roll_over < _indexes_len; roll_over += 1) {
                size_t next_index = (start_index + roll_over) % _indexes_len;
                if (indexes[next_index] == 0) {
                    // Run ends at an empty slot: clear the hole and finish.
                    indexes[prev_index] = 0;
                    return true;
                }
                Entry *next_entry = &_entries.items[indexes[next_index] - 1];
                if (next_entry->distance_from_start_index == 0) {
                    // An entry already at its home slot must not move.
                    indexes[prev_index] = 0;
                    return true;
                }
                // Shift this index one slot back toward its home slot.
                indexes[prev_index] = indexes[next_index];
                prev_index = next_index;
                next_entry->distance_from_start_index -= 1;
            }
            zig_unreachable();
        }
        return false;
    }

    // After swap_remove moved the last entry into position
    // `new_entry_index`, find the probe-table slot that still stores the
    // moved entry's old (1-based) index and repoint it at the new one.
    template <typename I>
    void update_entry_index(size_t old_entry_index, size_t new_entry_index, I *indexes) {
        size_t start_index = hash_to_index(_entries.items[new_entry_index].hash);
        for (size_t roll_over = 0; roll_over <= _max_distance_from_start_index; roll_over += 1) {
            size_t index_index = (start_index + roll_over) % _indexes_len;
            if (indexes[index_index] == old_entry_index + 1) {
                indexes[index_index] = new_entry_index + 1;
                return;
            }
        }
        zig_unreachable();
    }
};
446 #endif
447