1 #if !defined(__UNIONFIND_H)
2 #define __UNIONFIND_H
3 
#include <atomic>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>
7 
8 /**
9  * Lock-free parallel disjoint set data structure (aka UNION-FIND)
10  * with path compression and union by rank
11  *
12  * Supports concurrent find(), same() and unite() calls as described
13  * in the paper
14  *
15  * "Wait-free Parallel Algorithms for the Union-Find Problem"
16  * by Richard J. Anderson and Heather Woll
17  *
18  * In addition, this class supports optimistic locking (try_lock/unlock)
19  * of disjoint sets and a combined unite+unlock operation.
20  *
21  * \author Wenzel Jakob
22  */
class DisjointSets {
public:
    /**
     * Create 'size' singleton sets: every element starts out as its own
     * representative (parent == itself), with rank 0 and the lock flag clear.
     *
     * Each mData entry packs the state of one element into 64 bits:
     *   bits  0..31 : parent index
     *   bits 32..62 : rank
     *   bit      63 : lock flag (see try_lock() / unlock())
     */
    DisjointSets(uint32_t size) : mData(size) {
        for (uint32_t i = 0; i < size; ++i)
            mData[i] = static_cast<uint64_t>(i);
    }

    /**
     * Return the representative (root) of the set currently containing 'id'.
     *
     * Uses path halving: every visited node is re-pointed to its grandparent
     * with a single compare-and-swap, and the traversal then jumps to that
     * grandparent. The CAS is a best-effort optimization only -- if another
     * thread modified the entry in the meantime it simply fails and the
     * traversal continues. The rank and lock bits of the entry are preserved.
     */
    uint32_t find(uint32_t id) const {
        while (id != parent(id)) {
            uint64_t value = mData[id];
            uint32_t new_parent = parent(static_cast<uint32_t>(value));
            uint64_t new_value =
                (value & 0xFFFFFFFF00000000ULL) | new_parent;
            /* Try to update parent (may fail, that's ok) */
            if (value != new_value)
                mData[id].compare_exchange_weak(value, new_value);
            id = new_parent;
        }
        return id;
    }

    /**
     * Return true if 'id1' and 'id2' are currently in the same set.
     *
     * When the two roots differ, the negative answer is only accepted if
     * id1's root is still a root after both lookups; otherwise a concurrent
     * unite() re-parented it and the check is repeated.
     */
    bool same(uint32_t id1, uint32_t id2) const {
        for (;;) {
            id1 = find(id1);
            id2 = find(id2);
            if (id1 == id2)
                return true;
            if (parent(id1) == id1)
                return false;
        }
    }

    /**
     * Merge the sets containing 'id1' and 'id2' and return the representative
     * of the resulting set (the common representative if they were already
     * in the same set).
     *
     * Union by rank: the root with the lower rank is linked below the other;
     * on equal ranks the root with the smaller index becomes the new
     * representative. The link is installed with a CAS whose expected value
     * has the lock flag clear, so the operation keeps retrying while the
     * child root is locked or is being modified concurrently.
     */
    uint32_t unite(uint32_t id1, uint32_t id2) {
        for (;;) {
            id1 = find(id1);
            id2 = find(id2);

            if (id1 == id2)
                return id1;

            uint32_t r1 = rank(id1), r2 = rank(id2);

            // Arrange things so that id1 becomes the child and id2 the root
            if (r1 > r2 || (r1 == r2 && id1 < id2)) {
                std::swap(r1, r2);
                std::swap(id1, id2);
            }

            uint64_t oldEntry = (static_cast<uint64_t>(r1) << 32) | id1;
            uint64_t newEntry = (static_cast<uint64_t>(r1) << 32) | id2;

            if (!mData[id1].compare_exchange_strong(oldEntry, newEntry))
                continue;

            if (r1 == r2) {
                oldEntry = (static_cast<uint64_t>(r2) << 32) | id2;
                newEntry = (static_cast<uint64_t>(r2 + 1) << 32) | id2;
                /* Try to update the rank (may fail, that's ok) */
                mData[id2].compare_exchange_weak(oldEntry, newEntry);
            }

            break;
        }
        return id2;
    }

    /**
     * Try to lock the disjoint set identified by one of its elements.
     *
     * The parameter 'id' is updated to the representative found by find().
     * Returns true if this call set the lock flag on that representative.
     * Returns false if the set is already locked, if the representative was
     * re-parented concurrently, or if the CAS lost a race. Such failures can
     * be transient when there are concurrent operations.
     */
    bool try_lock(uint32_t &id) {
        const uint64_t lock_flag = 1ULL << 63;
        id = find(id);
        uint64_t value = mData[id];
        if ((value & lock_flag) || static_cast<uint32_t>(value) != id)
            return false;
        // On IA32/x64, a PAUSE instruction is recommended for CAS busy loops.
        // It is issued here, presumably because callers invoke try_lock()
        // repeatedly in such a loop.
        #if defined(__i386__) || defined(__amd64__)
            __asm__ __volatile__ ("pause\n");
        #endif
        return mData[id].compare_exchange_strong(value, value | lock_flag);
    }

    /// Atomically clear the lock flag of representative 'id' (set by try_lock()).
    void unlock(uint32_t id) {
        const uint64_t lock_flag = 1ULL << 63;
        mData[id] &= ~lock_flag;
    }

    /**
     * Return the representative index of the set that results from merging
     * the locked disjoint sets 'id1' and 'id2'. This is the same choice that
     * unite_unlock() makes: the higher rank wins, and on equal ranks the
     * smaller index wins.
     */
    uint32_t unite_index_locked(uint32_t id1, uint32_t id2) const {
        uint32_t r1 = rank(id1), r2 = rank(id2);
        return (r1 > r2 || (r1 == r2 && id1 < id2)) ? id1 : id2;
    }

    /**
     * Atomically unite two locked disjoint sets and unlock them. Assumes
     * that there are no other concurrent unite() calls involving the same
     * sets.
     *
     * The child root is stored pointing at the new representative. The
     * representative is stored with its rank incremented if both ranks were
     * equal. Both stores overwrite the whole entry, which also clears the
     * lock flags. Returns the new representative.
     */
    uint32_t unite_unlock(uint32_t id1, uint32_t id2) {
        uint32_t r1 = rank(id1), r2 = rank(id2);

        // Arrange things so that id1 becomes the child and id2 the root
        if (r1 > r2 || (r1 == r2 && id1 < id2)) {
            std::swap(r1, r2);
            std::swap(id1, id2);
        }

        mData[id1] = (static_cast<uint64_t>(r1) << 32) | id2;
        mData[id2] = (static_cast<uint64_t>(r2 + ((r1 == r2) ? 1 : 0)) << 32) | id2;

        return id2;
    }

    /// Number of elements managed by this structure.
    uint32_t size() const { return static_cast<uint32_t>(mData.size()); }

    /// Rank stored for element 'id' (bits 32..62; the lock flag is masked off).
    uint32_t rank(uint32_t id) const {
        return static_cast<uint32_t>(mData[id] >> 32) & 0x7FFFFFFFu;
    }

    /// Parent index stored for element 'id' (low 32 bits of its entry).
    uint32_t parent(uint32_t id) const {
        return static_cast<uint32_t>(mData[id]);
    }

    /// Print one "i: parent=P, rank=R" line per element.
    friend std::ostream &operator<<(std::ostream &os, const DisjointSets &f) {
        // uint32_t index matches parent()/rank() (no size_t narrowing), and
        // '\n' avoids flushing the stream once per element like std::endl did.
        for (uint32_t i = 0; i < f.size(); ++i)
            os << i << ": parent=" << f.parent(i) << ", rank=" << f.rank(i) << '\n';
        return os;
    }

    // One packed 64-bit entry per element (parent | rank | lock flag).
    // Mutable so that the const find() can perform path halving.
    mutable std::vector<std::atomic<uint64_t>> mData;
};
158 
159 #endif /* __UNIONFIND_H */
160