1 /*
2  * Copyright (c) 2015 Andrew Kelley
3  *
4  * This file is part of zig, which is MIT licensed.
5  * See http://opensource.org/licenses/MIT
6  */
7 
8 #ifndef ZIG_MEM_HASH_MAP_HPP
9 #define ZIG_MEM_HASH_MAP_HPP
10 
11 #include "mem.hpp"
12 
13 namespace mem {
14 
15 template<typename K, typename V, uint32_t (*HashFunction)(K key), bool (*EqualFn)(K a, K b)>
16 class HashMap {
17 public:
init(Allocator & allocator,int capacity)18     void init(Allocator& allocator, int capacity) {
19         init_capacity(allocator, capacity);
20     }
deinit(Allocator & allocator)21     void deinit(Allocator& allocator) {
22         allocator.deallocate(_entries, _capacity);
23     }
24 
25     struct Entry {
26         K key;
27         V value;
28         bool used;
29         int distance_from_start_index;
30     };
31 
clear()32     void clear() {
33         for (int i = 0; i < _capacity; i += 1) {
34             _entries[i].used = false;
35         }
36         _size = 0;
37         _max_distance_from_start_index = 0;
38         _modification_count += 1;
39     }
40 
size() const41     int size() const {
42         return _size;
43     }
44 
put(Allocator & allocator,const K & key,const V & value)45     void put(Allocator& allocator, const K &key, const V &value) {
46         _modification_count += 1;
47         internal_put(key, value);
48 
49         // if we get too full (60%), double the capacity
50         if (_size * 5 >= _capacity * 3) {
51             Entry *old_entries = _entries;
52             int old_capacity = _capacity;
53             init_capacity(allocator, _capacity * 2);
54             // dump all of the old elements into the new table
55             for (int i = 0; i < old_capacity; i += 1) {
56                 Entry *old_entry = &old_entries[i];
57                 if (old_entry->used)
58                     internal_put(old_entry->key, old_entry->value);
59             }
60             allocator.deallocate(old_entries, old_capacity);
61         }
62     }
63 
put_unique(Allocator & allocator,const K & key,const V & value)64     Entry *put_unique(Allocator& allocator, const K &key, const V &value) {
65         // TODO make this more efficient
66         Entry *entry = internal_get(key);
67         if (entry)
68             return entry;
69         put(allocator, key, value);
70         return nullptr;
71     }
72 
get(const K & key) const73     const V &get(const K &key) const {
74         Entry *entry = internal_get(key);
75         if (!entry)
76             zig_panic("key not found");
77         return entry->value;
78     }
79 
maybe_get(const K & key) const80     Entry *maybe_get(const K &key) const {
81         return internal_get(key);
82     }
83 
maybe_remove(const K & key)84     void maybe_remove(const K &key) {
85         if (maybe_get(key)) {
86             remove(key);
87         }
88     }
89 
remove(const K & key)90     void remove(const K &key) {
91         _modification_count += 1;
92         int start_index = key_to_index(key);
93         for (int roll_over = 0; roll_over <= _max_distance_from_start_index; roll_over += 1) {
94             int index = (start_index + roll_over) % _capacity;
95             Entry *entry = &_entries[index];
96 
97             if (!entry->used)
98                 zig_panic("key not found");
99 
100             if (!EqualFn(entry->key, key))
101                 continue;
102 
103             for (; roll_over < _capacity; roll_over += 1) {
104                 int next_index = (start_index + roll_over + 1) % _capacity;
105                 Entry *next_entry = &_entries[next_index];
106                 if (!next_entry->used || next_entry->distance_from_start_index == 0) {
107                     entry->used = false;
108                     _size -= 1;
109                     return;
110                 }
111                 *entry = *next_entry;
112                 entry->distance_from_start_index -= 1;
113                 entry = next_entry;
114             }
115             zig_panic("shifting everything in the table");
116         }
117         zig_panic("key not found");
118     }
119 
120     class Iterator {
121     public:
next()122         Entry *next() {
123             if (_inital_modification_count != _table->_modification_count)
124                 zig_panic("concurrent modification");
125             if (_count >= _table->size())
126                 return NULL;
127             for (; _index < _table->_capacity; _index += 1) {
128                 Entry *entry = &_table->_entries[_index];
129                 if (entry->used) {
130                     _index += 1;
131                     _count += 1;
132                     return entry;
133                 }
134             }
135             zig_panic("no next item");
136         }
137 
138     private:
139         const HashMap * _table;
140         // how many items have we returned
141         int _count = 0;
142         // iterator through the entry array
143         int _index = 0;
144         // used to detect concurrent modification
145         uint32_t _inital_modification_count;
Iterator(const HashMap * table)146         Iterator(const HashMap * table) :
147                 _table(table), _inital_modification_count(table->_modification_count) {
148         }
149         friend HashMap;
150     };
151 
152     // you must not modify the underlying HashMap while this iterator is still in use
entry_iterator() const153     Iterator entry_iterator() const {
154         return Iterator(this);
155     }
156 
157 private:
158     Entry *_entries;
159     int _capacity;
160     int _size;
161     int _max_distance_from_start_index;
162     // this is used to detect bugs where a hashtable is edited while an iterator is running.
163     uint32_t _modification_count;
164 
init_capacity(Allocator & allocator,int capacity)165     void init_capacity(Allocator& allocator, int capacity) {
166         _capacity = capacity;
167         _entries = allocator.allocate<Entry>(_capacity);
168         _size = 0;
169         _max_distance_from_start_index = 0;
170         for (int i = 0; i < _capacity; i += 1) {
171             _entries[i].used = false;
172         }
173     }
174 
internal_put(K key,V value)175     void internal_put(K key, V value) {
176         int start_index = key_to_index(key);
177         for (int roll_over = 0, distance_from_start_index = 0;
178                 roll_over < _capacity; roll_over += 1, distance_from_start_index += 1)
179         {
180             int index = (start_index + roll_over) % _capacity;
181             Entry *entry = &_entries[index];
182 
183             if (entry->used && !EqualFn(entry->key, key)) {
184                 if (entry->distance_from_start_index < distance_from_start_index) {
185                     // robin hood to the rescue
186                     Entry tmp = *entry;
187                     if (distance_from_start_index > _max_distance_from_start_index)
188                         _max_distance_from_start_index = distance_from_start_index;
189                     *entry = {
190                         key,
191                         value,
192                         true,
193                         distance_from_start_index,
194                     };
195                     key = tmp.key;
196                     value = tmp.value;
197                     distance_from_start_index = tmp.distance_from_start_index;
198                 }
199                 continue;
200             }
201 
202             if (!entry->used) {
203                 // adding an entry. otherwise overwriting old value with
204                 // same key
205                 _size += 1;
206             }
207 
208             if (distance_from_start_index > _max_distance_from_start_index)
209                 _max_distance_from_start_index = distance_from_start_index;
210             *entry = {
211                 key,
212                 value,
213                 true,
214                 distance_from_start_index,
215             };
216             return;
217         }
218         zig_panic("put into a full HashMap");
219     }
220 
221 
internal_get(const K & key) const222     Entry *internal_get(const K &key) const {
223         int start_index = key_to_index(key);
224         for (int roll_over = 0; roll_over <= _max_distance_from_start_index; roll_over += 1) {
225             int index = (start_index + roll_over) % _capacity;
226             Entry *entry = &_entries[index];
227 
228             if (!entry->used)
229                 return NULL;
230 
231             if (EqualFn(entry->key, key))
232                 return entry;
233         }
234         return NULL;
235     }
236 
key_to_index(const K & key) const237     int key_to_index(const K &key) const {
238         return (int)(HashFunction(key) % ((uint32_t)_capacity));
239     }
240 };
241 
242 } // namespace mem
243 
244 #endif
245