OILS / vendor / souffle / datastructure / UnionFind.h View on Github | oils.pub

357 lines, 161 significant
1/*
2 * Souffle - A Datalog Compiler
3 * Copyright (c) 2017 The Souffle Developers. All rights reserved
4 * Licensed under the Universal Permissive License v 1.0 as shown at:
5 * - https://opensource.org/licenses/UPL
6 * - <souffle root>/licenses/SOUFFLE-UPL.txt
7 */
8
9/************************************************************************
10 *
11 * @file UnionFind.h
12 *
13 * Defines a union-find data-structure
14 *
15 ***********************************************************************/
16
17#pragma once
18
19#include "souffle/datastructure/LambdaBTree.h"
20#include "souffle/datastructure/PiggyList.h"
21#include "souffle/utility/MiscUtil.h"
22#include <atomic>
23#include <cstddef>
24#include <cstdint>
25#include <functional>
26#include <utility>
27
28namespace souffle {
29
// Branch-prediction hints built on the GCC/Clang __builtin_expect intrinsic.
// They evaluate to x and only tell the compiler which outcome to expect.
#define unlikely(x) __builtin_expect((x), 0)
#define likely(x) __builtin_expect((x), 1)

// Rank of a node in the union-find forest.
using rank_t = uint8_t;
/* Index of a node. Only the upper (64 - split_size) = 56 bits of a block hold the parent,
 * so this is effectively a uint56_t: do not store 2^56 or more elements. */
using parent_t = uint64_t;

// Number of low-order bits of a block that hold the rank.
constexpr uint8_t split_size = 8u;

// block_t packs a node's state into one word: the parent index in the upper 56 bits
// and the rank in the lower split_size (8) bits.
using block_t = uint64_t;
// block_t & rank_mask extracts the rank.
// The literal is typed as block_t (not unsigned long, which is 32 bits wide on LLP64
// platforms) so that the mask stays correct for any split_size below 64.
constexpr block_t rank_mask = (block_t{1} << split_size) - 1;

static_assert(split_size < 64, "the rank must leave room for the parent index inside a block_t");
static_assert(split_size <= 8 * sizeof(rank_t), "rank_t must be able to hold split_size bits");
45
46/**
47 * Structure that emulates a Disjoint Set, i.e. a data structure that supports efficient union-find operations
48 */
49class DisjointSet {
50 template <typename TupleType>
51 friend class EquivalenceRelation;
52
53 PiggyList<std::atomic<block_t>> a_blocks;
54
55public:
56 DisjointSet() = default;
57
58 // copy ctor
59 DisjointSet(DisjointSet& other) = delete;
60 // move ctor
61 DisjointSet(DisjointSet&& other) = delete;
62
63 // copy assign ctor
64 DisjointSet& operator=(DisjointSet& ds) = delete;
65 // move assign ctor
66 DisjointSet& operator=(DisjointSet&& ds) = delete;
67
68 /**
69 * Return the number of elements in this disjoint set (not the number of pairs)
70 */
71 inline std::size_t size() {
72 auto sz = a_blocks.size();
73 return sz;
74 };
75
76 /**
77 * Yield reference to the node by its node index
78 * @param node node to be searched
79 * @return the parent block of the specified node
80 */
81 inline std::atomic<block_t>& get(parent_t node) const {
82 auto& ret = a_blocks.get(node);
83 return ret;
84 };
85
86 /**
87 * Equivalent to the find() function in union/find
88 * Find the highest ancestor of the provided node - flattening as we go
89 * @param x the node to find the parent of, whilst flattening its set-tree
90 * @return The parent of x
91 */
92 parent_t findNode(parent_t x) {
93 // while x's parent is not itself
94 while (x != b2p(get(x))) {
95 block_t xState = get(x);
96 // yield x's parent's parent
97 parent_t newParent = b2p(get(b2p(xState)));
98 // construct block out of the original rank and the new parent
99 block_t newState = pr2b(newParent, b2r(xState));
100
101 this->get(x).compare_exchange_strong(xState, newState);
102
103 x = newParent;
104 }
105 return x;
106 }
107
108private:
109 /**
110 * Update the root of the tree of which x is, to have y as the base instead
111 * @param x : old root
112 * @param oldrank : old root rank
113 * @param y : new root
114 * @param newrank : new root rank
115 * @return Whether the update succeeded (fails if another root update/union has been perfomed in the
116 * interim)
117 */
118 bool updateRoot(const parent_t x, const rank_t oldrank, const parent_t y, const rank_t newrank) {
119 block_t oldState = get(x);
120 parent_t nextN = b2p(oldState);
121 rank_t rankN = b2r(oldState);
122
123 if (nextN != x || rankN != oldrank) return false;
124 // set the parent and rank of the new record
125 block_t newVal = pr2b(y, newrank);
126
127 return this->get(x).compare_exchange_strong(oldState, newVal);
128 }
129
130public:
131 /**
132 * Clears the DisjointSet of all nodes
133 * Invalidates all iterators
134 */
135 void clear() {
136 a_blocks.clear();
137 }
138
139 /**
140 * Check whether the two indices are in the same set
141 * @param x node to be checked
142 * @param y node to be checked
143 * @return where the two indices are in the same set
144 */
145 bool sameSet(parent_t x, parent_t y) {
146 while (true) {
147 x = findNode(x);
148 y = findNode(y);
149 if (x == y) return true;
150 // if x's parent is itself, they are not the same set
151 if (b2p(get(x)) == x) return false;
152 }
153 }
154
155 /**
156 * Union the two specified index nodes
157 * @param x node to be unioned
158 * @param y node to be unioned
159 */
160 void unionNodes(parent_t x, parent_t y) {
161 while (true) {
162 x = findNode(x);
163 y = findNode(y);
164
165 // no need to union if both already in same set
166 if (x == y) return;
167
168 rank_t xrank = b2r(get(x));
169 rank_t yrank = b2r(get(y));
170
171 // if x comes before y (better rank or earlier & equal node)
172 if (xrank > yrank || ((xrank == yrank) && x > y)) {
173 std::swap(x, y);
174 std::swap(xrank, yrank);
175 }
176 // join the trees together
177 // perhaps we can optimise the use of compare_exchange_strong here, as we're in a pessimistic loop
178 if (!updateRoot(x, xrank, y, yrank)) {
179 continue;
180 }
181 // make sure that the ranks are orderable
182 if (xrank == yrank) {
183 updateRoot(y, yrank, y, yrank + 1);
184 }
185 break;
186 }
187 }
188
189 /**
190 * Create a node with its parent as itself, rank 0
191 * @return the newly created block
192 */
193 inline block_t makeNode() {
194 // make node and find out where we've added it
195 std::size_t nodeDetails = a_blocks.createNode();
196
197 a_blocks.get(nodeDetails).store(pr2b(nodeDetails, 0));
198
199 return a_blocks.get(nodeDetails).load();
200 };
201
202 /**
203 * Extract parent from block
204 * @param inblock the block to be masked
205 * @return The parent_t contained in the upper half of block_t
206 */
207 static inline parent_t b2p(const block_t inblock) {
208 return (parent_t)(inblock >> split_size);
209 };
210
211 /**
212 * Extract rank from block
213 * @param inblock the block to be masked
214 * @return the rank_t contained in the lower half of block_t
215 */
216 static inline rank_t b2r(const block_t inblock) {
217 return (rank_t)(inblock & rank_mask);
218 };
219
220 /**
221 * Yield a block given parent and rank
222 * @param parent the top half bits
223 * @param rank the lower half bits
224 * @return the resultant block after merge
225 */
226 static inline block_t pr2b(const parent_t parent, const rank_t rank) {
227 return (((block_t)parent) << split_size) | rank;
228 };
229};
230
/**
 * Three-way comparator for (sparse value, dense value) pairs that orders them by the
 * first component only; the second component never takes part in the comparison.
 * All operations are const, so the comparator can also be used through a const object.
 */
template <typename StorePair>
struct EqrelMapComparator {
    /**
     * @return -1 if a.first < b.first, 1 if b.first < a.first, 0 otherwise
     */
    int operator()(const StorePair& a, const StorePair& b) const {
        if (a.first < b.first) return -1;
        if (b.first < a.first) return 1;
        return 0;
    }

    /** @return whether a is ordered strictly before b */
    bool less(const StorePair& a, const StorePair& b) const {
        return operator()(a, b) < 0;
    }

    /** @return whether neither a nor b is ordered before the other */
    bool equal(const StorePair& a, const StorePair& b) const {
        return operator()(a, b) == 0;
    }
};
251
252template <typename SparseDomain>
253class SparseDisjointSet {
254 DisjointSet ds;
255
256 template <typename TupleType>
257 friend class EquivalenceRelation;
258
259 using PairStore = std::pair<SparseDomain, parent_t>;
260 using SparseMap =
261 LambdaBTreeSet<PairStore, std::function<parent_t(PairStore&)>, EqrelMapComparator<PairStore>>;
262 using DenseMap = RandomInsertPiggyList<SparseDomain>;
263
264 typename SparseMap::operation_hints last_ins;
265
266 SparseMap sparseToDenseMap;
267 // mapping from union-find val to souffle, union-find encoded as index
268 DenseMap denseToSparseMap;
269
270public:
271 /**
272 * Retrieve dense encoding, adding it in if non-existent
273 * @param in the sparse value
274 * @return the corresponding dense value
275 */
276 parent_t toDense(const SparseDomain in) {
277 // insert into the mapping - if the key doesn't exist (in), the function will be called
278 // and a dense value will be created for it
279 PairStore p = {in, -1};
280 return sparseToDenseMap.insert(p, [&](PairStore& p) {
281 parent_t c2 = DisjointSet::b2p(this->ds.makeNode());
282 this->denseToSparseMap.insertAt(c2, p.first);
283 p.second = c2;
284 return c2;
285 });
286 }
287
288public:
289 SparseDisjointSet() = default;
290
291 // copy ctor
292 SparseDisjointSet(SparseDisjointSet& other) = delete;
293
294 // move ctor
295 SparseDisjointSet(SparseDisjointSet&& other) = delete;
296
297 // copy assign ctor
298 SparseDisjointSet& operator=(SparseDisjointSet& other) = delete;
299
300 // move assign ctor
301 SparseDisjointSet& operator=(SparseDisjointSet&& other) = delete;
302
303 /**
304 * For the given dense value, return the associated sparse value
305 * Undefined behaviour if dense value not in set
306 * @param in the supplied dense value
307 * @return the sparse value from the denseToSparseMap
308 */
309 inline const SparseDomain toSparse(const parent_t in) const {
310 return denseToSparseMap.get(in);
311 };
312
313 /* a wrapper to enable checking in the sparse set - however also adds them if not already existing */
314 inline bool sameSet(SparseDomain x, SparseDomain y) {
315 return ds.sameSet(toDense(x), toDense(y));
316 };
317 /* finds the node in the underlying disjoint set, adding the node if non-existent */
318 inline SparseDomain findNode(SparseDomain x) {
319 return toSparse(ds.findNode(toDense(x)));
320 };
321 /* union the nodes, add if not existing */
322 inline void unionNodes(SparseDomain x, SparseDomain y) {
323 ds.unionNodes(toDense(x), toDense(y));
324 };
325
326 inline std::size_t size() {
327 return ds.size();
328 };
329
330 /**
331 * Remove all elements from this disjoint set
332 */
333 void clear() {
334 ds.clear();
335 sparseToDenseMap.clear();
336 denseToSparseMap.clear();
337 }
338
339 /* wrapper for node creation */
340 inline void makeNode(SparseDomain val) {
341 // dense has the behaviour of creating if not exists.
342 toDense(val);
343 };
344
345 /* whether the supplied node exists */
346 inline bool nodeExists(const SparseDomain val) const {
347 return sparseToDenseMap.contains({val, -1});
348 };
349
350 inline bool contains(SparseDomain v1, SparseDomain v2) {
351 if (nodeExists(v1) && nodeExists(v2)) {
352 return sameSet(v1, v2);
353 }
354 return false;
355 }
356};
357} // namespace souffle