1 /*
2   This file is part of MADNESS.
3 
4   Copyright (C) 2007,2010 Oak Ridge National Laboratory
5 
6   This program is free software; you can redistribute it and/or modify
7   it under the terms of the GNU General Public License as published by
8   the Free Software Foundation; either version 2 of the License, or
9   (at your option) any later version.
10 
11   This program is distributed in the hope that it will be useful,
12   but WITHOUT ANY WARRANTY; without even the implied warranty of
13   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14   GNU General Public License for more details.
15 
16   You should have received a copy of the GNU General Public License
17   along with this program; if not, write to the Free Software
18   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 
20   For more information please contact:
21 
22   Robert J. Harrison
23   Oak Ridge National Laboratory
24   One Bethel Valley Road
25   P.O. Box 2008, MS-6367
26 
27   email: harrisonrj@ornl.gov
28   tel:   865-241-3937
29   fax:   865-572-0680
30 */
31 
32 #ifndef MADNESS_WORLD_WORLDHASHMAP_H__INCLUDED
33 #define MADNESS_WORLD_WORLDHASHMAP_H__INCLUDED
34 
35 /// \file worldhashmap.h
36 /// \brief Defines and implements a concurrent hashmap
37 
38 
// Why does this exist?  It's a bridge from where we are to where we
// want to be, which is a multithreaded environment probably
// based upon the Intel TBB.  Don't have the resources right now to
// bite off the entire TBB but we probably must in the future.
// This is a basic, functional-enough, fast-enough hash map with
// vague compatibility with the TBB API.
45 
#include <madness/world/worldmutex.h>
#include <madness/world/madness_exception.h>
#include <madness/world/worldhash.h>
#include <cstddef>
#include <iterator>
#include <map>
#include <new>
#include <stdio.h>
#include <type_traits>
#include <utility>
52 
53 namespace madness {
54 
// Forward declaration so that Hash_private::HashAccessor (below) can declare
// the map a friend before the map itself is defined. Declared once, with the
// same template parameter names as the definition.
template <class keyT, class valueT, class hashfunT>
class ConcurrentHashMap;
59 
60     namespace Hash_private {
61 
62         // A hashtable is an array of nbin bins.
63         // Each bin is a linked list of entries protected by a spinlock.
64         // Each entry holds a key+value pair, a read-write mutex, and a link to the next entry.
65 
66         template <typename keyT, typename valueT>
67         class entry : public madness::MutexReaderWriter {
68         public:
69             typedef std::pair<const keyT, valueT> datumT;
70             datumT datum;
71 
72             class entry<keyT,valueT> * volatile next;
73 
entry(const datumT & datum,entry<keyT,valueT> * next)74             entry(const datumT& datum, entry<keyT,valueT>* next)
75                     : datum(datum), next(next) {}
76         };
77 
78         template <class keyT, class valueT>
79         class bin : private madness::Spinlock {
80         private:
81             typedef entry<keyT,valueT> entryT;
82             typedef std::pair<const keyT, valueT> datumT;
83             // Could pad here to avoid false sharing of cache line but
84             // perhaps better to just use more bins
85         public:
86 
87             entryT* volatile p;
88             int volatile ninbin;
89 
bin()90             bin() : p(0),ninbin(0) {}
91 
~bin()92             ~bin() {
93                 clear();
94             }
95 
clear()96             void clear() {
97                 lock();             // BEGIN CRITICAL SECTION
98                 while (p) {
99                     entryT* n=p->next;
100                     delete p;
101                     p=n;
102                     ninbin--;
103                 }
104                 MADNESS_ASSERT(ninbin == 0);
105                 unlock();           // END CRITICAL SECTION
106             }
107 
find(const keyT & key,const int lockmode)108             entryT* find(const keyT& key, const int lockmode) const {
109                 bool gotlock;
110                 entryT* result;
111                 madness::MutexWaiter waiter;
112                 do {
113                     lock();             // BEGIN CRITICAL SECTION
114                     result = match(key);
115                     if (result) {
116                         gotlock = result->try_lock(lockmode);
117                     }
118                     else {
119                         gotlock = true;
120                     }
121                     unlock();           // END CRITICAL SECTION
122                     if (!gotlock) waiter.wait(); //cpu_relax();
123                 }
124                 while (!gotlock);
125 
126                 return result;
127             }
128 
insert(const datumT & datum,int lockmode)129             std::pair<entryT*,bool> insert(const datumT& datum, int lockmode) {
130                 bool gotlock;
131                 entryT* result;
132                 bool notfound;
133                 madness::MutexWaiter waiter;
134                 do {
135                     lock();             // BEGIN CRITICAL SECTION
136                     result = match(datum.first);
137                     notfound = !result;
138                     if (notfound) {
139                         result = p = new entryT(datum,p);
140                         ++ninbin;
141                     }
142                     gotlock = result->try_lock(lockmode);
143                     unlock();           // END CRITICAL SECTION
144                     if (!gotlock) waiter.wait(); //cpu_relax();
145                 }
146                 while (!gotlock);
147 
148                 return std::pair<entryT*,bool>(result,notfound);
149             }
150 
del(const keyT & key,int lockmode)151             bool del(const keyT& key, int lockmode) {
152                 bool status = false;
153                 lock();             // BEGIN CRITICAL SECTION
154                 for (entryT *t=p,*prev=0; t; prev=t,t=t->next) {
155                     if (t->datum.first == key) {
156                         if (prev) {
157                             prev->next = t->next;
158                         }
159                         else {
160                             p = t->next;
161                         }
162                         t->unlock(lockmode);
163                         delete t;
164                         --ninbin;
165                         status = true;
166                         break;
167                     }
168                 }
169                 unlock();           // END CRITICAL SECTION
170                 return status;
171             }
172 
size()173             std::size_t size() const {
174                 return ninbin;
175             };
176 
177         private:
match(const keyT & key)178             entryT* match(const keyT& key) const {
179                 entryT* t;
180                 for (t=p; t; t=t->next)
181                     if (t->datum.first == key) break;
182                 return t;
183             }
184 
185         };
186 
187         /// iterator for hash
188         template <class hashT> class HashIterator {
189         public:
190             typedef typename std::conditional<std::is_const<hashT>::value,
191                     typename std::add_const<typename hashT::entryT>::type,
192                     typename hashT::entryT>::type entryT;
193             typedef typename std::conditional<std::is_const<hashT>::value,
194                     typename std::add_const<typename hashT::datumT>::type,
195                     typename hashT::datumT>::type datumT;
196             typedef std::forward_iterator_tag iterator_category;
197             typedef datumT value_type;
198             typedef std::ptrdiff_t difference_type;
199             typedef datumT* pointer;
200             typedef datumT& reference;
201 
202         private:
203             hashT* h;               // Associated hash table
204             int bin;                // Current bin
205             entryT* entry;          // Current entry in bin ... zero means at end
206 
207             template <class otherHashT>
208             friend class HashIterator;
209 
210             /// If the entry is null (end of current bin) finds next non-empty bin
next_non_null_entry()211             void next_non_null_entry() {
212                 while (!entry) {
213                     ++bin;
214                     if ((unsigned) bin == h->nbins) {
215                         entry = 0;
216                         return;
217                     }
218                     entry = h->bins[bin].p;
219                 }
220                 return;
221             }
222 
223         public:
224 
225             /// Makes invalid iterator
HashIterator()226             HashIterator() : h(0), bin(-1), entry(0) {}
227 
228             /// Makes begin/end iterator
HashIterator(hashT * h,bool begin)229             HashIterator(hashT* h, bool begin)
230                     : h(h), bin(-1), entry(0) {
231                 if (begin) next_non_null_entry();
232             }
233 
234             /// Makes iterator to specific entry
HashIterator(hashT * h,int bin,entryT * entry)235             HashIterator(hashT* h, int bin, entryT* entry)
236                     : h(h), bin(bin), entry(entry) {}
237 
238             /// Copy constructor
HashIterator(const HashIterator & other)239             HashIterator(const HashIterator& other)
240                     : h(other.h), bin(other.bin), entry(other.entry) {}
241 
242             /// Implicit conversion of another hash type to this hash type
243 
244             /// This allows implicit conversion from hash types to const hash
245             /// types.
246             template <class otherHashT>
HashIterator(const HashIterator<otherHashT> & other)247             HashIterator(const HashIterator<otherHashT>& other)
248                     : h(other.h), bin(other.bin), entry(other.entry) {}
249 
250             HashIterator& operator++() {
251                 if (!entry) return *this;
252                 entry = entry->next;
253                 next_non_null_entry();
254                 return *this;
255             }
256 
257             HashIterator operator++(int) {
258                 HashIterator old(*this);
259                 operator++();
260                 return old;
261             }
262 
263             /// Difference between iterators \em only supported for this=start and other=end
264 
265             /// This exists to support construction of range for parallel iteration
266             /// over the entire container.
distance(const HashIterator & other)267             int distance(const HashIterator& other) const {
268                 MADNESS_ASSERT(h == other.h  &&  other == h->end()  &&  *this == h->begin());
269                 return h->size();
270             }
271 
272             /// Only positive increments are supported
273 
274             /// This exists to support splitting of range for parallel iteration.
advance(int n)275             void advance(int n) {
276                 if (n==0 || !entry) return;
277                 MADNESS_ASSERT(n>=0);
278 
279                 // Linear increment up to end of this bin
280                 while (n-- && (entry=entry->next)) {}
281                 next_non_null_entry();
282                 if (!entry) return; // end
283 
284                 if (n <= 0) return;
285 
286                 // If here, will point to first entry in
287                 // a bin ... determine which bin contains
288                 // our end point.
289                 while (unsigned(n) >= h->bins[bin].size()) {
290                     n -= h->bins[bin].size();
291                     ++bin;
292                     if (unsigned(bin) == h->nbins) {
293                         entry = 0;
294                         return; // end
295                     }
296                 }
297 
298                 entry = h->bins[bin].p;
299                 MADNESS_ASSERT(entry);
300 
301                 // Linear increment to target
302                 while (n--) entry=entry->next;
303 
304                 return;
305             }
306 
307 
308             bool operator==(const HashIterator& a) const {
309                 return entry==a.entry;
310             }
311 
312             bool operator!=(const HashIterator& a) const {
313                 return entry!=a.entry;
314             }
315 
316             reference operator*() const {
317                 MADNESS_ASSERT(entry);
318                 //if (!entry) throw "Hash iterator: operator*: at end";
319                 return entry->datum;
320             }
321 
322             pointer operator->() const {
323                 MADNESS_ASSERT(entry);
324                 //if (!entry) throw "Hash iterator: operator->: at end";
325                 return &entry->datum;
326             }
327         };
328 
329         template <class hashT, int lockmode>
330         class HashAccessor : private NO_DEFAULTS {
331             template <class a,class b,class c> friend class madness::ConcurrentHashMap;
332         public:
333             typedef typename std::conditional<std::is_const<hashT>::value,
334                     typename std::add_const<typename hashT::entryT>::type,
335                     typename hashT::entryT>::type entryT;
336             typedef typename std::conditional<std::is_const<hashT>::value,
337                     typename std::add_const<typename hashT::datumT>::type,
338                     typename hashT::datumT>::type datumT;
339             typedef datumT value_type;
340             typedef datumT* pointer;
341             typedef datumT& reference;
342 
343         private:
344             entryT* entry;
345             bool gotlock;
346 
347             /// Used by Hash to set entry (assumed that it has the lock already)
set(entryT * entry)348             void set(entryT* entry) {
349                 release();
350                 this->entry = entry;
351                 gotlock = true;
352             }
353 
354             /// Used by Hash after having already released lock and deleted entry
unset()355             void unset() {
356                 gotlock = false;
357                 entry = 0;
358             }
359 
convert_read_lock_to_write_lock()360             void convert_read_lock_to_write_lock() {
361                 if (entry) entry->convert_read_lock_to_write_lock();
362             }
363 
364 
365         public:
HashAccessor()366             HashAccessor() : entry(0), gotlock(false) {}
367 
HashAccessor(entryT * entry)368             HashAccessor(entryT* entry) : entry(entry), gotlock(true) {}
369 
370             datumT& operator*() const {
371                 if (!entry) MADNESS_EXCEPTION("Hash accessor: operator*: no value", 0);
372                 return entry->datum;
373             }
374 
375             datumT* operator->() const {
376                 if (!entry) MADNESS_EXCEPTION("Hash accessor: operator->: no value", 0);
377                 return &entry->datum;
378             }
379 
release()380             void release() {
381                 if (gotlock) {
382                     entry->unlock(lockmode);
383                     entry=0;
384                     gotlock = false;
385                 }
386             }
387 
~HashAccessor()388             ~HashAccessor() {
389                 release();
390             }
391         };
392 
393     } // End of namespace Hash_private
394 
395     template < class keyT, class valueT, class hashfunT = Hash<keyT> >
396     class ConcurrentHashMap {
397     public:
398         typedef ConcurrentHashMap<keyT,valueT,hashfunT> hashT;
399         typedef std::pair<const keyT,valueT> datumT;
400         typedef Hash_private::entry<keyT,valueT> entryT;
401         typedef Hash_private::bin<keyT,valueT> binT;
402         typedef Hash_private::HashIterator<hashT> iterator;
403         typedef Hash_private::HashIterator<const hashT> const_iterator;
404         typedef Hash_private::HashAccessor<hashT,entryT::WRITELOCK> accessor;
405         typedef Hash_private::HashAccessor<const hashT,entryT::READLOCK> const_accessor;
406 
407         friend class Hash_private::HashIterator<hashT>;
408         friend class Hash_private::HashIterator<const hashT>;
409 
410     protected:
411         const size_t nbins;         // Number of bins
412         binT* bins;                 // Array of bins
413 
414     private:
415         hashfunT hashfun;
416 
417         //unsigned int hash(const keyT& key) const {return hashfunT::hash(key)%nbins;}
418 
nbins_prime(int n)419         static int nbins_prime(int n) {
420             static const int primes[] = {11, 23, 31, 41, 53, 61, 71, 83, 101,
421                 131, 181, 239, 293, 359, 421, 557, 673, 821, 953, 1021, 1231,
422                 1531, 1747, 2069, 2543, 3011, 4003, 5011, 6073, 7013, 8053,
423                 9029, 9907, 17401, 27479, 37847, 48623, 59377, 70667, 81839,
424                 93199, 104759, 224759, 350411, 479951, 611969, 746791, 882391,
425                 1299743, 2750171, 4256257, 5800159, 7368811, 8960477, 10570871,
426                 12195269, 13834133};
427             static const int nprimes = sizeof(primes)/sizeof(int);
428             // n is a user provided estimate of the no. of elements to be put
429             // in the table.  Want to make the number of bins a prime number
430             // larger than this.
431             for (int i=0; i<nprimes; ++i) if (n<=primes[i]) return primes[i];
432             return primes[nprimes-1];
433         }
434 
hash_to_bin(const keyT & key)435         unsigned int hash_to_bin(const keyT& key) const {
436             return hashfun(key)%nbins;
437         }
438 
439     public:
440         ConcurrentHashMap(int n=1021, const hashfunT& hf = hashfunT())
nbins(hashT::nbins_prime (n))441                 : nbins(hashT::nbins_prime(n))
442                 , bins(new binT[nbins])
443                 , hashfun(hf) {}
444 
ConcurrentHashMap(const hashT & h)445         ConcurrentHashMap(const  hashT& h)
446                 : nbins(h.nbins)
447                 , bins(new binT[nbins])
448                 , hashfun(h.hashfun) {
449             *this = h;
450         }
451 
~ConcurrentHashMap()452         virtual ~ConcurrentHashMap() {
453             delete [] bins;
454         }
455 
456         hashT& operator=(const  hashT& h) {
457             if (this != &h) {
458                 this->clear();
459                 hashfun = h.hashfun;
460                 for (const_iterator p=h.begin(); p!=h.end(); ++p) {
461                     insert(*p);
462                 }
463             }
464             return *this;
465         }
466 
insert(const datumT & datum)467         std::pair<iterator,bool> insert(const datumT& datum) {
468             int bin = hash_to_bin(datum.first);
469             std::pair<entryT*,bool> result = bins[bin].insert(datum,entryT::NOLOCK);
470             return std::pair<iterator,bool>(iterator(this,bin,result.first),result.second);
471         }
472 
473         /// Returns true if new pair was inserted; false if key is already in the map and the datum was not inserted
insert(accessor & result,const datumT & datum)474         bool insert(accessor& result, const datumT& datum) {
475             result.release();
476             int bin = hash_to_bin(datum.first);
477             std::pair<entryT*,bool> r = bins[bin].insert(datum,entryT::WRITELOCK);
478             result.set(r.first);
479             return r.second;
480         }
481 
482         /// Returns true if new pair was inserted; false if key is already in the map and the datum was not inserted
insert(const_accessor & result,const datumT & datum)483         bool insert(const_accessor& result, const datumT& datum) {
484             result.release();
485             int bin = hash_to_bin(datum.first);
486             std::pair<entryT*,bool> r = bins[bin].insert(datum,entryT::READLOCK);
487             result.set(r.first);
488             return r.second;
489         }
490 
491         /// Returns true if new pair was inserted; false if key is already in the map
insert(accessor & result,const keyT & key)492         inline bool insert(accessor& result, const keyT& key) {
493             return insert(result, datumT(key,valueT()));
494         }
495 
496         /// Returns true if new pair was inserted; false if key is already in the map
insert(const_accessor & result,const keyT & key)497         inline bool insert(const_accessor& result, const keyT& key) {
498             return insert(result, datumT(key,valueT()));
499         }
500 
erase(const keyT & key)501         std::size_t erase(const keyT& key) {
502             if (bins[hash_to_bin(key)].del(key,entryT::NOLOCK)) return 1;
503             else return 0;
504         }
505 
erase(const iterator & it)506         void erase(const iterator& it) {
507             if (it == end()) MADNESS_EXCEPTION("ConcurrentHashMap: erase(iterator): at end", true);
508             erase(it->first);
509         }
510 
erase(accessor & item)511         void erase(accessor& item) {
512             bins[hash_to_bin(item->first)].del(item->first,entryT::WRITELOCK);
513             item.unset();
514         }
515 
erase(const_accessor & item)516         void erase(const_accessor& item) {
517             item.convert_read_lock_to_write_lock();
518             bins[hash_to_bin(item->first)].del(item->first,entryT::WRITELOCK);
519             item.unset();
520         }
521 
find(const keyT & key)522         iterator find(const keyT& key) {
523             int bin = hash_to_bin(key);
524             entryT* entry = bins[bin].find(key,entryT::NOLOCK);
525             if (!entry) return end();
526             else return iterator(this,bin,entry);
527         }
528 
find(const keyT & key)529         const_iterator find(const keyT& key) const {
530             int bin = hash_to_bin(key);
531             const entryT* entry = bins[bin].find(key,entryT::NOLOCK);
532             if (!entry) return end();
533             else return const_iterator(this,bin,entry);
534         }
535 
find(accessor & result,const keyT & key)536         bool find(accessor& result, const keyT& key) {
537             result.release();
538             int bin = hash_to_bin(key);
539             entryT* entry = bins[bin].find(key,entryT::WRITELOCK);
540             bool foundit = entry;
541             if (foundit) result.set(entry);
542             return foundit;
543         }
544 
find(const_accessor & result,const keyT & key)545         bool find(const_accessor& result, const keyT& key) const {
546             result.release();
547             int bin = hash_to_bin(key);
548             entryT* entry = bins[bin].find(key,entryT::READLOCK);
549             bool foundit = entry;
550             if (foundit) result.set(entry);
551             return foundit;
552         }
553 
clear()554         void clear() {
555             for (unsigned int i=0; i<nbins; ++i) bins[i].clear();
556         }
557 
size()558         size_t size() const {
559             size_t sum = 0;
560             for (size_t i=0; i<nbins; ++i) sum += bins[i].size();
561             return sum;
562         }
563 
564         valueT& operator[](const keyT& key) {
565             std::pair<iterator,bool> it = insert(datumT(key,valueT()));
566             return it.first->second;
567         }
568 
begin()569         iterator begin() {
570             return iterator(this,true);
571         }
572 
begin()573         const_iterator begin() const {
574             return cbegin();
575         }
576 
cbegin()577         const_iterator cbegin() const {
578             return const_iterator(this,true);
579         }
580 
end()581         iterator end() {
582             return iterator(this,false);
583         }
584 
end()585         const_iterator end() const {
586             return cend();
587         }
588 
cend()589         const_iterator cend() const {
590             return const_iterator(this,false);
591         }
592 
get_hash()593         hashfunT& get_hash() const { return hashfun; }
594 
print_stats()595         void print_stats() const {
596             for (unsigned int i=0; i<nbins; ++i) {
597                 if (i && (i%10)==0) printf("\n");
598                 printf("%8d", int(bins[i].size()));
599             }
600             printf("\n");
601         }
602     };
603 }
604 
605 namespace std {
606 
607     template <typename hashT, typename distT>
advance(madness::Hash_private::HashIterator<hashT> & it,const distT & dist)608     inline void advance( madness::Hash_private::HashIterator<hashT>& it, const distT& dist ) {
609         //std::cout << " in custom advance \n";
610         it.advance(dist);
611     }
612 
613     template <typename hashT>
distance(const madness::Hash_private::HashIterator<hashT> & it,const madness::Hash_private::HashIterator<hashT> & jt)614     inline int distance(const madness::Hash_private::HashIterator<hashT>& it, const madness::Hash_private::HashIterator<hashT>& jt) {
615         //std::cout << " in custom distance \n";
616         return it.distance(jt);
617     }
618 }
619 
620 #endif // MADNESS_WORLD_WORLDHASHMAP_H__INCLUDED
621