1 #ifndef BTLLIB_COUNTING_BLOOM_FILTER_HPP
2 #define BTLLIB_COUNTING_BLOOM_FILTER_HPP
3
4 #include "bloom_filter.hpp"
5 #include "nthash.hpp"
6 #include "status.hpp"
7
8 #include "external/cpptoml.hpp"
9
#include <atomic>
#include <climits>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <string>
#include <type_traits>
#include <vector>
17
18 namespace btllib {
19
// Name of the TOML header section that write() emits and the loading
// constructor expects for a serialized CountingBloomFilter.
static const char* const COUNTING_BLOOM_FILTER_MAGIC_HEADER{
  "BTLCountingBloomFilter_v2"
};

// Name of the TOML header section that write() emits and the loading
// constructor expects for a serialized KmerCountingBloomFilter.
static const char* const KMER_COUNTING_BLOOM_FILTER_MAGIC_HEADER{
  "BTLKmerCountingBloomFilter_v2"
};
24
// Forward declaration so CountingBloomFilter can befriend its k-mer wrapper.
template<typename T>
class KmerCountingBloomFilter;

/**
 * Counting Bloom filter data structure. Provides CountingBloomFilter8,
 * CountingBloomFilter16, and CountingBloomFilter32 classes with corresponding
 * bit-size counters.
 *
 * Counters are stored as std::atomic<T> and insert() updates them with
 * compare-and-swap, presumably so that concurrent insert()/contains() calls
 * are safe. The filter owns its counter array and is neither copyable nor
 * movable.
 */
template<typename T>
class CountingBloomFilter
{
  // insert() detects a saturated counter by checking whether (min + 1)
  // wrapped around below min, which is only well-defined for unsigned types.
  static_assert(std::is_unsigned<T>::value,
                "CountingBloomFilter counters must be an unsigned integer type");

public:
  /**
   * Construct a dummy Counting Bloom filter (e.g. as a default argument).
   * It has no counters, so it must not be inserted into or queried.
   */
  CountingBloomFilter() {}

  /**
   * Construct an empty Counting Bloom filter of given size.
   *
   * @param bytes Filter size in bytes (rounded to a multiple of
   * sizeof(uint64_t)).
   * @param hash_num Number of hash values per element.
   */
  CountingBloomFilter(size_t bytes, unsigned hash_num);

  /**
   * Load a Counting Bloom filter from a file.
   *
   * @param path Filepath to load from.
   */
  explicit CountingBloomFilter(const std::string& path);

  /** Release the owned counter array. */
  ~CountingBloomFilter() { delete[] array; }

  // The filter owns a raw counter array, so copying and moving are disabled.
  CountingBloomFilter(const CountingBloomFilter&) = delete;
  CountingBloomFilter(CountingBloomFilter&&) = delete;

  CountingBloomFilter& operator=(const CountingBloomFilter&) = delete;
  CountingBloomFilter& operator=(CountingBloomFilter&&) = delete;

  /**
   * Insert an element's hash values.
   *
   * @param hashes Integer array of hash values. Array size should equal the
   * hash_num argument used when the Bloom filter was constructed.
   */
  void insert(const uint64_t* hashes);

  /**
   * Insert an element's hash values.
   *
   * @param hashes Integer vector of hash values. Its size should equal
   * get_hash_num().
   */
  void insert(const std::vector<uint64_t>& hashes) { insert(hashes.data()); }

  /**
   * Check for the presence of an element's hash values.
   *
   * @param hashes Integer array of hash values. Array size should equal the
   * hash_num argument used when the Bloom filter was constructed.
   *
   * @return The count of the queried element (the minimum of its counters).
   */
  T contains(const uint64_t* hashes) const;

  /**
   * Check for the presence of an element's hash values.
   *
   * @param hashes Integer vector of hash values. Its size should equal
   * get_hash_num().
   *
   * @return The count of the queried element (the minimum of its counters).
   */
  T contains(const std::vector<uint64_t>& hashes) const
  {
    return contains(hashes.data());
  }

  /** Get filter size in bytes. */
  size_t get_bytes() const { return bytes; }
  /** Get population count, i.e. the number of counters >0 in the filter. */
  uint64_t get_pop_cnt() const;
  /** Get the fraction of the filter occupied by >0 counters. */
  double get_occupancy() const;
  /** Get the number of hash values per element. */
  unsigned get_hash_num() const { return hash_num; }
  /** Get the query false positive rate (occupancy ^ hash_num). */
  double get_fpr() const;

  /**
   * Write the Bloom filter to a file that can be loaded in the future.
   *
   * @param path Filepath to store filter at.
   */
  void write(const std::string& path);

private:
  friend class KmerCountingBloomFilter<T>;

  std::atomic<T>* array = nullptr; // owned; freed in the destructor
  size_t bytes = 0;
  size_t array_size = 0; // number of counters in array
  unsigned hash_num = 0;
};
127
128 /**
129 * Counting Bloom filter data structure stores k-mers. Provides
130 * KmerCountingBloomFilter8, KmerCountingBloomFilter16, and
131 * KmerCountingBloomFilter32 classes with corresponding bit-size counters.
132 */
133 template<typename T>
134 class KmerCountingBloomFilter
135 {
136
137 public:
138 /** Construct a dummy Kmer Bloom filter (e.g. as a default argument). */
KmerCountingBloomFilter()139 KmerCountingBloomFilter() {}
140
141 /**
142 * Construct an empty Kmer Counting Bloom filter of given size.
143 *
144 * @param bytes Filter size in bytes.
145 * @param hash_num Number of hash values per element.
146 * @param k K-mer size.
147 */
148 KmerCountingBloomFilter(size_t bytes, unsigned hash_num, unsigned k);
149
150 /**
151 * Load a Kmer Counting Bloom filter from a file.
152 *
153 * @param path Filepath to load from.
154 */
155 explicit KmerCountingBloomFilter(const std::string& path);
156
157 KmerCountingBloomFilter(const KmerCountingBloomFilter&) = delete;
158 KmerCountingBloomFilter(KmerCountingBloomFilter&&) = delete;
159
160 KmerCountingBloomFilter& operator=(const KmerCountingBloomFilter&) = delete;
161 KmerCountingBloomFilter& operator=(KmerCountingBloomFilter&&) = delete;
162
163 /**
164 * Insert a sequence's k-mers into the filter.
165 *
166 * @param seq Sequence to k-merize.
167 * @param seq_len Length of seq.
168 */
169 void insert(const char* seq, size_t seq_len);
170
171 /**
172 * Insert a sequence's k-mers into the filter.
173 *
174 * @param seq Sequence to k-merize.
175 */
insert(const std::string & seq)176 void insert(const std::string& seq) { insert(seq.c_str(), seq.size()); }
177
178 /**
179 * Query the presence of k-mers of a sequence.
180 *
181 * @param seq Sequence to k-merize.
182 * @param seq_len Length of seq.
183 *
184 * @return The sum of counters of seq's k-mers found in the filter.
185 */
186 uint64_t contains(const char* seq, size_t seq_len) const;
187
188 /**
189 * Query the presence of k-mers of a sequence.
190 *
191 * @param seq Sequence to k-merize.
192 *
193 * @return The sum of counters of seq's k-mers found in the filter.
194 */
contains(const std::string & seq) const195 uint64_t contains(const std::string& seq) const
196 {
197 return contains(seq.c_str(), seq.size());
198 }
199
200 /**
201 * Check for the presence of an element's hash values.
202 *
203 * @param hashes Integer array of hash values. Array size should equal the
204 * hash_num argument used when the Bloom filter was constructed.
205 *
206 * @return The count of the queried element.
207 */
contains(const uint64_t * hashes) const208 T contains(const uint64_t* hashes) const
209 {
210 counting_bloom_filter.contains(hashes);
211 }
212
213 /**
214 * Check for the presence of an element's hash values.
215 *
216 * @param hashes Integer vector of hash values.
217 *
218 * @return The count of the queried element.
219 */
contains(const std::vector<uint64_t> & hashes) const220 T contains(const std::vector<uint64_t>& hashes) const
221 {
222 counting_bloom_filter.contains(hashes);
223 }
224
225 /** Get filter size in bytes. */
get_bytes() const226 size_t get_bytes() const { return counting_bloom_filter.get_bytes(); }
227 /** Get population count, i.e. the number of counters >0 in the filter. */
get_pop_cnt() const228 uint64_t get_pop_cnt() const { return counting_bloom_filter.get_pop_cnt(); }
229 /** Get the fraction of the filter occupied by >0 counters. */
get_occupancy() const230 double get_occupancy() const { return counting_bloom_filter.get_occupancy(); }
231 /** Get the number of hash values per element. */
get_hash_num() const232 unsigned get_hash_num() const { return counting_bloom_filter.get_hash_num(); }
233 /** Get the query false positive rate. */
get_fpr() const234 double get_fpr() const { return counting_bloom_filter.get_fpr(); }
235 /** Get the k-mer size used. */
get_k() const236 unsigned get_k() const { return k; }
237 /** Get a reference to the underlying vanilla Counting Bloom filter. */
get_counting_bloom_filter()238 CountingBloomFilter<T>& get_counting_bloom_filter()
239 {
240 return counting_bloom_filter;
241 }
242
243 /**
244 * Write the Bloom filter to a file that can be loaded in the future.
245 *
246 * @param path Filepath to store filter at.
247 */
248 void write(const std::string& path);
249
250 private:
251 CountingBloomFilter<T> counting_bloom_filter;
252 unsigned k = 0;
253 };
254
255 using CountingBloomFilter8 = CountingBloomFilter<uint8_t>;
256 using CountingBloomFilter16 = CountingBloomFilter<uint16_t>;
257 using CountingBloomFilter32 = CountingBloomFilter<uint32_t>;
258
259 using KmerCountingBloomFilter8 = KmerCountingBloomFilter<uint8_t>;
260 using KmerCountingBloomFilter16 = KmerCountingBloomFilter<uint16_t>;
261 using KmerCountingBloomFilter32 = KmerCountingBloomFilter<uint32_t>;
262
263 template<typename T>
CountingBloomFilter(size_t bytes,unsigned hash_num)264 inline CountingBloomFilter<T>::CountingBloomFilter(size_t bytes,
265 unsigned hash_num)
266 : bytes(std::ceil(bytes / sizeof(uint64_t)) * sizeof(uint64_t))
267 , array_size(get_bytes() / sizeof(array[0]))
268 , hash_num(hash_num)
269 {
270 check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
271 "Atomic primitives take extra memory. CountingBloomFilter will "
272 "have less than " +
273 std::to_string(bytes) + " for bit array.");
274 array = new std::atomic<T>[array_size];
275 std::memset((void*)array, 0, array_size * sizeof(array[0]));
276 }
277
278 template<typename T>
279 inline void
insert(const uint64_t * hashes)280 CountingBloomFilter<T>::insert(const uint64_t* hashes)
281 {
282 // Update flag to track if increment is done on at least one counter
283 bool update_done = false;
284 T new_val;
285 T min_val = contains(hashes);
286 while (!update_done) {
287 // Simple check to deal with overflow
288 new_val = min_val + 1;
289 if (min_val > new_val) {
290 return;
291 }
292 for (size_t i = 0; i < hash_num; ++i) {
293 decltype(min_val) temp_min_val = min_val;
294 if (array[hashes[i] % array_size].compare_exchange_strong(temp_min_val,
295 new_val)) {
296 update_done = true;
297 }
298 }
299 // Recalculate minval because if increment fails, it needs a new minval to
300 // use and if it doesnt hava a new one, the while loop runs forever.
301 if (!update_done) {
302 min_val = contains(hashes);
303 }
304 }
305 }
306
307 template<typename T>
308 inline T
contains(const uint64_t * hashes) const309 CountingBloomFilter<T>::contains(const uint64_t* hashes) const
310 {
311 T min = array[hashes[0] % array_size];
312 for (size_t i = 1; i < hash_num; ++i) {
313 const size_t idx = hashes[i] % array_size;
314 if (array[idx] < min) {
315 min = array[idx];
316 }
317 }
318 return min;
319 }
320
321 template<typename T>
322 inline uint64_t
get_pop_cnt() const323 CountingBloomFilter<T>::get_pop_cnt() const
324 {
325 uint64_t pop_cnt = 0;
326 #pragma omp parallel for default(none) reduction(+ : pop_cnt)
327 for (size_t i = 0; i < array_size; ++i) {
328 if (array[i] > 0) {
329 ++pop_cnt;
330 }
331 }
332 return pop_cnt;
333 }
334
335 template<typename T>
336 inline double
get_occupancy() const337 CountingBloomFilter<T>::get_occupancy() const
338 {
339 return double(get_pop_cnt()) / double(array_size);
340 }
341
342 template<typename T>
343 inline double
get_fpr() const344 CountingBloomFilter<T>::get_fpr() const
345 {
346 return std::pow(get_occupancy(), double(hash_num));
347 }
348
349 template<typename T>
CountingBloomFilter(const std::string & path)350 inline CountingBloomFilter<T>::CountingBloomFilter(const std::string& path)
351 {
352 std::ifstream file(path);
353
354 auto table =
355 BloomFilter::parse_header(file, COUNTING_BLOOM_FILTER_MAGIC_HEADER);
356 bytes = *table->get_as<decltype(bytes)>("bytes");
357 check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
358 "Atomic primitives take extra memory. CountingBloomFilter will "
359 "have less than " +
360 std::to_string(bytes) + " for bit array.");
361 array_size = bytes / sizeof(array[0]);
362 hash_num = *table->get_as<decltype(hash_num)>("hash_num");
363 check_error(
364 sizeof(array[0]) * CHAR_BIT != *table->get_as<size_t>("counter_bits"),
365 "CountingBloomFilter" + std::to_string(sizeof(array[0]) * CHAR_BIT) +
366 " tried to load a file of CountingBloomFilter" +
367 std::to_string(*table->get_as<size_t>("counter_bits")));
368
369 array = new std::atomic<T>[array_size];
370 file.read((char*)array, array_size * sizeof(array[0]));
371 }
372
373 template<typename T>
374 inline void
write(const std::string & path)375 CountingBloomFilter<T>::write(const std::string& path)
376 {
377 std::ofstream file(path.c_str(), std::ios::out | std::ios::binary);
378
379 /* Initialize cpptoml root table
380 Note: Tables and fields are unordered
381 Ordering of table is maintained by directing the table
382 to the output stream immediately after completion */
383 auto root = cpptoml::make_table();
384
385 /* Initialize bloom filter section and insert fields
386 and output to ostream */
387 auto header = cpptoml::make_table();
388 header->insert("bytes", get_bytes());
389 header->insert("hash_num", get_hash_num());
390 header->insert("counter_bits", size_t(sizeof(array[0]) * CHAR_BIT));
391 root->insert(COUNTING_BLOOM_FILTER_MAGIC_HEADER, header);
392 file << *root << "[HeaderEnd]\n";
393
394 file.write((char*)array, array_size * sizeof(array[0]));
395 }
396
397 template<typename T>
KmerCountingBloomFilter(size_t bytes,unsigned hash_num,unsigned k)398 inline KmerCountingBloomFilter<T>::KmerCountingBloomFilter(size_t bytes,
399 unsigned hash_num,
400 unsigned k)
401 : counting_bloom_filter(bytes, hash_num)
402 , k(k)
403 {}
404
405 template<typename T>
406 inline void
insert(const char * seq,size_t seq_len)407 KmerCountingBloomFilter<T>::insert(const char* seq, size_t seq_len)
408 {
409 NtHash nthash(seq, seq_len, get_k(), get_hash_num());
410 while (nthash.roll()) {
411 counting_bloom_filter.insert(nthash.hashes());
412 }
413 }
414
415 template<typename T>
416 inline uint64_t
contains(const char * seq,size_t seq_len) const417 KmerCountingBloomFilter<T>::contains(const char* seq, size_t seq_len) const
418 {
419 uint64_t count = 0;
420 NtHash nthash(seq, seq_len, get_k(), get_hash_num());
421 while (nthash.roll()) {
422 count += counting_bloom_filter.contains(nthash.hashes());
423 }
424 return count;
425 }
426
427 template<typename T>
KmerCountingBloomFilter(const std::string & path)428 inline KmerCountingBloomFilter<T>::KmerCountingBloomFilter(
429 const std::string& path)
430 {
431 std::ifstream file(path);
432
433 auto table =
434 BloomFilter::parse_header(file, KMER_COUNTING_BLOOM_FILTER_MAGIC_HEADER);
435 counting_bloom_filter.bytes =
436 *table->get_as<decltype(counting_bloom_filter.bytes)>("bytes");
437 check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
438 "Atomic primitives take extra memory. CountingBloomFilter will "
439 "have less than " +
440 std::to_string(get_bytes()) + " for bit array.");
441 counting_bloom_filter.array_size =
442 get_bytes() / sizeof(counting_bloom_filter.array[0]);
443 counting_bloom_filter.hash_num =
444 *table->get_as<decltype(counting_bloom_filter.hash_num)>("hash_num");
445 k = *table->get_as<decltype(k)>("k");
446 check_error(sizeof(T) * CHAR_BIT != *table->get_as<size_t>("counter_bits"),
447 "CountingBloomFilter" + std::to_string(sizeof(T) * CHAR_BIT) +
448 " tried to load a file of CountingBloomFilter" +
449 std::to_string(*table->get_as<size_t>("counter_bits")));
450
451 counting_bloom_filter.array =
452 new std::atomic<T>[counting_bloom_filter.array_size];
453 file.read((char*)counting_bloom_filter.array,
454 counting_bloom_filter.array_size *
455 sizeof(counting_bloom_filter.array[0]));
456 }
457
458 template<typename T>
459 inline void
write(const std::string & path)460 KmerCountingBloomFilter<T>::write(const std::string& path)
461 {
462 std::ofstream file(path.c_str(), std::ios::out | std::ios::binary);
463
464 /* Initialize cpptoml root table
465 Note: Tables and fields are unordered
466 Ordering of table is maintained by directing the table
467 to the output stream immediately after completion */
468 auto root = cpptoml::make_table();
469
470 /* Initialize bloom filter section and insert fields
471 and output to ostream */
472 auto header = cpptoml::make_table();
473 header->insert("bytes", get_bytes());
474 header->insert("hash_num", get_hash_num());
475 header->insert("counter_bits",
476 size_t(sizeof(counting_bloom_filter.array[0]) * CHAR_BIT));
477 header->insert("k", k);
478 root->insert(KMER_COUNTING_BLOOM_FILTER_MAGIC_HEADER, header);
479 file << *root << "[HeaderEnd]\n";
480
481 file.write((char*)counting_bloom_filter.array,
482 counting_bloom_filter.array_size *
483 sizeof(counting_bloom_filter.array[0]));
484 }
485
486 } // namespace btllib
487
488 #endif
489