#if !defined(__UNIONFIND_H)
#define __UNIONFIND_H

#include <atomic>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

/**
 * Lock-free parallel disjoint set data structure (aka UNION-FIND)
 * with path compression and union by rank
 *
 * Supports concurrent find(), same() and unite() calls as described
 * in the paper
 *
 * "Wait-free Parallel Algorithms for the Union-Find Problem"
 * by Richard J. Anderson and Heather Woll
 *
 * In addition, this class supports optimistic locking (try_lock/unlock)
 * of disjoint sets and a combined unite+unlock operation.
 *
 * Each element is packed into a single 64-bit atomic word:
 *   bits  0..31  parent index
 *   bits 32..62  rank
 *   bit  63      lock flag (set by try_lock, cleared by unlock/unite_unlock)
 *
 * \author Wenzel Jakob
 */
class DisjointSets {
public:
    /// Create 'size' singleton sets: element i is its own parent, with rank 0
    DisjointSets(uint32_t size) : mData(size) {
        for (uint32_t i = 0; i < size; ++i)
            mData[i].store(i);
    }

    /**
     * Return the representative (root) of the set containing 'id'.
     *
     * Uses path halving: every visited element is redirected to its
     * grandparent with a single CAS. The rank and lock bits of the visited
     * element are preserved. A failed CAS only skips that compression step.
     */
    uint32_t find(uint32_t id) const {
        while (id != parent(id)) {
            uint64_t value = mData[id];
            uint32_t new_parent = parent(static_cast<uint32_t>(value));
            uint64_t new_value =
                (value & 0xFFFFFFFF00000000ULL) | new_parent;
            /* Try to update parent (may fail, that's ok) */
            if (value != new_value)
                mData[id].compare_exchange_weak(value, new_value);
            id = new_parent;
        }
        return id;
    }

    /**
     * Return true if 'id1' and 'id2' belong to the same set.
     *
     * If the two roots differ but id1's root has meanwhile been linked below
     * another element (a concurrent unite), the check is repeated. Otherwise
     * that race could report a false negative.
     */
    bool same(uint32_t id1, uint32_t id2) const {
        for (;;) {
            id1 = find(id1);
            id2 = find(id2);
            if (id1 == id2)
                return true;
            if (parent(id1) == id1)
                return false;
        }
    }

    /**
     * Merge the sets containing 'id1' and 'id2' and return the representative
     * of the merged set.
     *
     * The root of lower rank is linked below the other root. On equal ranks,
     * the root with the larger index is linked below. Linking is a CAS that
     * expects the root to still be an unlocked root of the observed rank. If
     * it fails (concurrent modification, or the set is currently locked via
     * try_lock) the whole operation is retried.
     */
    uint32_t unite(uint32_t id1, uint32_t id2) {
        for (;;) {
            id1 = find(id1);
            id2 = find(id2);

            if (id1 == id2)
                return id1;

            uint32_t r1 = rank(id1), r2 = rank(id2);

            // Arrange things so that id1 is the root that gets linked below id2
            if (r1 > r2 || (r1 == r2 && id1 < id2)) {
                std::swap(r1, r2);
                std::swap(id1, id2);
            }

            uint64_t oldEntry = (static_cast<uint64_t>(r1) << 32) | id1;
            uint64_t newEntry = (static_cast<uint64_t>(r1) << 32) | id2;

            if (!mData[id1].compare_exchange_strong(oldEntry, newEntry))
                continue;

            if (r1 == r2) {
                oldEntry = (static_cast<uint64_t>(r2) << 32) | id2;
                newEntry = (static_cast<uint64_t>(r2 + 1) << 32) | id2;
                /* Try to update the rank (may fail under contention, that's
                   ok). A strong CAS is used because there is no retry loop
                   here: a weak CAS may fail spuriously even without any
                   contention and leave the rank stale. */
                mData[id2].compare_exchange_strong(oldEntry, newEntry);
            }

            return id2;
        }
    }

    /**
     * Try to lock the disjoint set identified by one of its elements (this
     * can occasionally fail when there are concurrent operations). The
     * parameter 'id' is updated to store the current representative ID of
     * the set. Returns true if the lock was acquired.
     */
    bool try_lock(uint32_t &id) {
        const uint64_t lock_flag = 1ULL << 63;
        id = find(id);
        uint64_t value = mData[id];
        if ((value & lock_flag) || static_cast<uint32_t>(value) != id)
            return false;
        // On IA32/x64, a PAUSE instruction is recommended for CAS busy loops
#if defined(__i386__) || defined(__amd64__)
        __asm__ __volatile__ ("pause\n");
#endif
        return mData[id].compare_exchange_strong(value, value | lock_flag);
    }

    /// Clear the lock flag of element 'id' (presumably the representative obtained from try_lock)
    void unlock(uint32_t id) {
        const uint64_t lock_flag = 1ULL << 63;
        mData[id] &= ~lock_flag;
    }

    /**
     * Return the representative index of the set that results from merging
     * locked disjoint sets 'id1' and 'id2' (i.e. the value unite_unlock()
     * would return for the same arguments)
     */
    uint32_t unite_index_locked(uint32_t id1, uint32_t id2) const {
        uint32_t r1 = rank(id1), r2 = rank(id2);
        return (r1 > r2 || (r1 == r2 && id1 < id2)) ? id1 : id2;
    }

    /**
     * Atomically unite two locked disjoint sets and unlock them. Assumes
     * that there are no other concurrent unite() calls involving the same
     * sets, and that 'id1' and 'id2' are distinct representatives.
     * Both entries are overwritten with plain stores, which also clears
     * their lock flags.
     */
    uint32_t unite_unlock(uint32_t id1, uint32_t id2) {
        uint32_t r1 = rank(id1), r2 = rank(id2);

        if (r1 > r2 || (r1 == r2 && id1 < id2)) {
            std::swap(r1, r2);
            std::swap(id1, id2);
        }

        mData[id1] = (static_cast<uint64_t>(r1) << 32) | id2;
        mData[id2] = (static_cast<uint64_t>(r2 + ((r1 == r2) ? 1 : 0)) << 32) | id2;

        return id2;
    }

    /// Number of elements
    uint32_t size() const { return static_cast<uint32_t>(mData.size()); }

    /// Rank stored for element 'id' (the lock flag is masked out)
    uint32_t rank(uint32_t id) const {
        return static_cast<uint32_t>(mData[id] >> 32) & 0x7FFFFFFFu;
    }

    /// Parent index stored for element 'id' (equals 'id' for a root)
    uint32_t parent(uint32_t id) const {
        return static_cast<uint32_t>(mData[id]);
    }

    /// Print one "i: parent=..., rank=..." line per element
    friend std::ostream &operator<<(std::ostream &os, const DisjointSets &f) {
        for (uint32_t i = 0; i < f.size(); ++i)
            os << i << ": parent=" << f.parent(i) << ", rank=" << f.rank(i) << '\n';
        return os;
    }

    mutable std::vector<std::atomic<uint64_t>> mData;
};

#endif /* __UNIONFIND_H */