1 #ifndef BTLLIB_COUNTING_BLOOM_FILTER_HPP
2 #define BTLLIB_COUNTING_BLOOM_FILTER_HPP
3 
4 #include "bloom_filter.hpp"
5 #include "nthash.hpp"
6 #include "status.hpp"
7 
8 #include "external/cpptoml.hpp"
9 
10 #include <atomic>
11 #include <climits>
12 #include <cmath>
13 #include <cstdint>
14 #include <fstream>
15 #include <string>
16 #include <vector>
17 
18 namespace btllib {
19 
/// Header table name identifying a serialized CountingBloomFilter file.
static constexpr const char* COUNTING_BLOOM_FILTER_MAGIC_HEADER =
  "BTLCountingBloomFilter_v2";
/// Header table name identifying a serialized KmerCountingBloomFilter file.
static constexpr const char* KMER_COUNTING_BLOOM_FILTER_MAGIC_HEADER =
  "BTLKmerCountingBloomFilter_v2";
24 
template<typename T>
class KmerCountingBloomFilter;

/**
 * Counting Bloom filter data structure. Provides CountingBloomFilter8,
 * CountingBloomFilter16, and CountingBloomFilter32 classes with corresponding
 * bit-size counters.
 *
 * The filter owns a heap-allocated array of atomic counters of type T, which
 * is released by the destructor. Copying and moving are disabled.
 */
template<typename T>
class CountingBloomFilter
{

public:
  /** Construct a dummy Counting Bloom filter (e.g. as a default argument). */
  CountingBloomFilter() {}

  /**
   * Construct an empty Counting Bloom filter of given size.
   *
   * @param bytes Filter size in bytes.
   * @param hash_num Number of hash values per element.
   */
  CountingBloomFilter(size_t bytes, unsigned hash_num);

  /**
   * Load a Counting Bloom filter from a file.
   *
   * @param path Filepath to load from.
   */
  explicit CountingBloomFilter(const std::string& path);

  /** Release the counter array (a no-op for a dummy filter). */
  ~CountingBloomFilter() { delete[] array; }

  CountingBloomFilter(const CountingBloomFilter&) = delete;
  CountingBloomFilter(CountingBloomFilter&&) = delete;

  CountingBloomFilter& operator=(const CountingBloomFilter&) = delete;
  CountingBloomFilter& operator=(CountingBloomFilter&&) = delete;

  /**
   * Insert an element's hash values.
   *
   * @param hashes Integer array of hash values. Array size should equal the
   * hash_num argument used when the Bloom filter was constructed.
   */
  void insert(const uint64_t* hashes);

  /**
   * Insert an element's hash values.
   *
   * @param hashes Integer vector of hash values. Vector size should equal the
   * hash_num argument used when the Bloom filter was constructed.
   */
  void insert(const std::vector<uint64_t>& hashes) { insert(hashes.data()); }

  /**
   * Check for the presence of an element's hash values.
   *
   * @param hashes Integer array of hash values. Array size should equal the
   * hash_num argument used when the Bloom filter was constructed.
   *
   * @return The count of the queried element.
   */
  T contains(const uint64_t* hashes) const;

  /**
   * Check for the presence of an element's hash values.
   *
   * @param hashes Integer vector of hash values. Vector size should equal the
   * hash_num argument used when the Bloom filter was constructed.
   *
   * @return The count of the queried element.
   */
  T contains(const std::vector<uint64_t>& hashes) const
  {
    return contains(hashes.data());
  }

  /** Get filter size in bytes. */
  size_t get_bytes() const { return bytes; }
  /** Get population count, i.e. the number of counters >0 in the filter. */
  uint64_t get_pop_cnt() const;
  /** Get the fraction of the filter occupied by >0 counters. */
  double get_occupancy() const;
  /** Get the number of hash values per element. */
  unsigned get_hash_num() const { return hash_num; }
  /** Get the query false positive rate. */
  double get_fpr() const;

  /**
   * Write the Bloom filter to a file that can be loaded in the future.
   *
   * @param path Filepath to store filter at.
   */
  void write(const std::string& path);

private:
  // The k-mer wrapper fills these fields directly when loading from a file.
  friend class KmerCountingBloomFilter<T>;

  /** Owned counter array; nullptr for a dummy filter. */
  std::atomic<T>* array = nullptr;
  /** Filter size in bytes. */
  size_t bytes = 0;
  /** Number of counters in array. */
  size_t array_size = 0;
  /** Number of hash values per element. */
  unsigned hash_num = 0;
};
127 
128 /**
129  * Counting Bloom filter data structure stores k-mers. Provides
130  * KmerCountingBloomFilter8, KmerCountingBloomFilter16, and
131  * KmerCountingBloomFilter32 classes with corresponding bit-size counters.
132  */
133 template<typename T>
134 class KmerCountingBloomFilter
135 {
136 
137 public:
138   /** Construct a dummy Kmer Bloom filter (e.g. as a default argument). */
KmerCountingBloomFilter()139   KmerCountingBloomFilter() {}
140 
141   /**
142    * Construct an empty Kmer Counting Bloom filter of given size.
143    *
144    * @param bytes Filter size in bytes.
145    * @param hash_num Number of hash values per element.
146    * @param k K-mer size.
147    */
148   KmerCountingBloomFilter(size_t bytes, unsigned hash_num, unsigned k);
149 
150   /**
151    * Load a Kmer Counting Bloom filter from a file.
152    *
153    * @param path Filepath to load from.
154    */
155   explicit KmerCountingBloomFilter(const std::string& path);
156 
157   KmerCountingBloomFilter(const KmerCountingBloomFilter&) = delete;
158   KmerCountingBloomFilter(KmerCountingBloomFilter&&) = delete;
159 
160   KmerCountingBloomFilter& operator=(const KmerCountingBloomFilter&) = delete;
161   KmerCountingBloomFilter& operator=(KmerCountingBloomFilter&&) = delete;
162 
163   /**
164    * Insert a sequence's k-mers into the filter.
165    *
166    * @param seq Sequence to k-merize.
167    * @param seq_len Length of seq.
168    */
169   void insert(const char* seq, size_t seq_len);
170 
171   /**
172    * Insert a sequence's k-mers into the filter.
173    *
174    * @param seq Sequence to k-merize.
175    */
insert(const std::string & seq)176   void insert(const std::string& seq) { insert(seq.c_str(), seq.size()); }
177 
178   /**
179    * Query the presence of k-mers of a sequence.
180    *
181    * @param seq Sequence to k-merize.
182    * @param seq_len Length of seq.
183    *
184    * @return The sum of counters of seq's k-mers found in the filter.
185    */
186   uint64_t contains(const char* seq, size_t seq_len) const;
187 
188   /**
189    * Query the presence of k-mers of a sequence.
190    *
191    * @param seq Sequence to k-merize.
192    *
193    * @return The sum of counters of seq's k-mers found in the filter.
194    */
contains(const std::string & seq) const195   uint64_t contains(const std::string& seq) const
196   {
197     return contains(seq.c_str(), seq.size());
198   }
199 
200   /**
201    * Check for the presence of an element's hash values.
202    *
203    * @param hashes Integer array of hash values. Array size should equal the
204    * hash_num argument used when the Bloom filter was constructed.
205    *
206    * @return The count of the queried element.
207    */
contains(const uint64_t * hashes) const208   T contains(const uint64_t* hashes) const
209   {
210     counting_bloom_filter.contains(hashes);
211   }
212 
213   /**
214    * Check for the presence of an element's hash values.
215    *
216    * @param hashes Integer vector of hash values.
217    *
218    * @return The count of the queried element.
219    */
contains(const std::vector<uint64_t> & hashes) const220   T contains(const std::vector<uint64_t>& hashes) const
221   {
222     counting_bloom_filter.contains(hashes);
223   }
224 
225   /** Get filter size in bytes. */
get_bytes() const226   size_t get_bytes() const { return counting_bloom_filter.get_bytes(); }
227   /** Get population count, i.e. the number of counters >0 in the filter. */
get_pop_cnt() const228   uint64_t get_pop_cnt() const { return counting_bloom_filter.get_pop_cnt(); }
229   /** Get the fraction of the filter occupied by >0 counters. */
get_occupancy() const230   double get_occupancy() const { return counting_bloom_filter.get_occupancy(); }
231   /** Get the number of hash values per element. */
get_hash_num() const232   unsigned get_hash_num() const { return counting_bloom_filter.get_hash_num(); }
233   /** Get the query false positive rate. */
get_fpr() const234   double get_fpr() const { return counting_bloom_filter.get_fpr(); }
235   /** Get the k-mer size used. */
get_k() const236   unsigned get_k() const { return k; }
237   /** Get a reference to the underlying vanilla Counting Bloom filter. */
get_counting_bloom_filter()238   CountingBloomFilter<T>& get_counting_bloom_filter()
239   {
240     return counting_bloom_filter;
241   }
242 
243   /**
244    * Write the Bloom filter to a file that can be loaded in the future.
245    *
246    * @param path Filepath to store filter at.
247    */
248   void write(const std::string& path);
249 
250 private:
251   CountingBloomFilter<T> counting_bloom_filter;
252   unsigned k = 0;
253 };
254 
255 using CountingBloomFilter8 = CountingBloomFilter<uint8_t>;
256 using CountingBloomFilter16 = CountingBloomFilter<uint16_t>;
257 using CountingBloomFilter32 = CountingBloomFilter<uint32_t>;
258 
259 using KmerCountingBloomFilter8 = KmerCountingBloomFilter<uint8_t>;
260 using KmerCountingBloomFilter16 = KmerCountingBloomFilter<uint16_t>;
261 using KmerCountingBloomFilter32 = KmerCountingBloomFilter<uint32_t>;
262 
263 template<typename T>
CountingBloomFilter(size_t bytes,unsigned hash_num)264 inline CountingBloomFilter<T>::CountingBloomFilter(size_t bytes,
265                                                    unsigned hash_num)
266   : bytes(std::ceil(bytes / sizeof(uint64_t)) * sizeof(uint64_t))
267   , array_size(get_bytes() / sizeof(array[0]))
268   , hash_num(hash_num)
269 {
270   check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
271                 "Atomic primitives take extra memory. CountingBloomFilter will "
272                 "have less than " +
273                   std::to_string(bytes) + " for bit array.");
274   array = new std::atomic<T>[array_size];
275   std::memset((void*)array, 0, array_size * sizeof(array[0]));
276 }
277 
278 template<typename T>
279 inline void
insert(const uint64_t * hashes)280 CountingBloomFilter<T>::insert(const uint64_t* hashes)
281 {
282   // Update flag to track if increment is done on at least one counter
283   bool update_done = false;
284   T new_val;
285   T min_val = contains(hashes);
286   while (!update_done) {
287     // Simple check to deal with overflow
288     new_val = min_val + 1;
289     if (min_val > new_val) {
290       return;
291     }
292     for (size_t i = 0; i < hash_num; ++i) {
293       decltype(min_val) temp_min_val = min_val;
294       if (array[hashes[i] % array_size].compare_exchange_strong(temp_min_val,
295                                                                 new_val)) {
296         update_done = true;
297       }
298     }
299     // Recalculate minval because if increment fails, it needs a new minval to
300     // use and if it doesnt hava a new one, the while loop runs forever.
301     if (!update_done) {
302       min_val = contains(hashes);
303     }
304   }
305 }
306 
307 template<typename T>
308 inline T
contains(const uint64_t * hashes) const309 CountingBloomFilter<T>::contains(const uint64_t* hashes) const
310 {
311   T min = array[hashes[0] % array_size];
312   for (size_t i = 1; i < hash_num; ++i) {
313     const size_t idx = hashes[i] % array_size;
314     if (array[idx] < min) {
315       min = array[idx];
316     }
317   }
318   return min;
319 }
320 
321 template<typename T>
322 inline uint64_t
get_pop_cnt() const323 CountingBloomFilter<T>::get_pop_cnt() const
324 {
325   uint64_t pop_cnt = 0;
326 #pragma omp parallel for default(none) reduction(+ : pop_cnt)
327   for (size_t i = 0; i < array_size; ++i) {
328     if (array[i] > 0) {
329       ++pop_cnt;
330     }
331   }
332   return pop_cnt;
333 }
334 
335 template<typename T>
336 inline double
get_occupancy() const337 CountingBloomFilter<T>::get_occupancy() const
338 {
339   return double(get_pop_cnt()) / double(array_size);
340 }
341 
342 template<typename T>
343 inline double
get_fpr() const344 CountingBloomFilter<T>::get_fpr() const
345 {
346   return std::pow(get_occupancy(), double(hash_num));
347 }
348 
349 template<typename T>
CountingBloomFilter(const std::string & path)350 inline CountingBloomFilter<T>::CountingBloomFilter(const std::string& path)
351 {
352   std::ifstream file(path);
353 
354   auto table =
355     BloomFilter::parse_header(file, COUNTING_BLOOM_FILTER_MAGIC_HEADER);
356   bytes = *table->get_as<decltype(bytes)>("bytes");
357   check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
358                 "Atomic primitives take extra memory. CountingBloomFilter will "
359                 "have less than " +
360                   std::to_string(bytes) + " for bit array.");
361   array_size = bytes / sizeof(array[0]);
362   hash_num = *table->get_as<decltype(hash_num)>("hash_num");
363   check_error(
364     sizeof(array[0]) * CHAR_BIT != *table->get_as<size_t>("counter_bits"),
365     "CountingBloomFilter" + std::to_string(sizeof(array[0]) * CHAR_BIT) +
366       " tried to load a file of CountingBloomFilter" +
367       std::to_string(*table->get_as<size_t>("counter_bits")));
368 
369   array = new std::atomic<T>[array_size];
370   file.read((char*)array, array_size * sizeof(array[0]));
371 }
372 
373 template<typename T>
374 inline void
write(const std::string & path)375 CountingBloomFilter<T>::write(const std::string& path)
376 {
377   std::ofstream file(path.c_str(), std::ios::out | std::ios::binary);
378 
379   /* Initialize cpptoml root table
380     Note: Tables and fields are unordered
381     Ordering of table is maintained by directing the table
382     to the output stream immediately after completion  */
383   auto root = cpptoml::make_table();
384 
385   /* Initialize bloom filter section and insert fields
386       and output to ostream */
387   auto header = cpptoml::make_table();
388   header->insert("bytes", get_bytes());
389   header->insert("hash_num", get_hash_num());
390   header->insert("counter_bits", size_t(sizeof(array[0]) * CHAR_BIT));
391   root->insert(COUNTING_BLOOM_FILTER_MAGIC_HEADER, header);
392   file << *root << "[HeaderEnd]\n";
393 
394   file.write((char*)array, array_size * sizeof(array[0]));
395 }
396 
397 template<typename T>
KmerCountingBloomFilter(size_t bytes,unsigned hash_num,unsigned k)398 inline KmerCountingBloomFilter<T>::KmerCountingBloomFilter(size_t bytes,
399                                                            unsigned hash_num,
400                                                            unsigned k)
401   : counting_bloom_filter(bytes, hash_num)
402   , k(k)
403 {}
404 
405 template<typename T>
406 inline void
insert(const char * seq,size_t seq_len)407 KmerCountingBloomFilter<T>::insert(const char* seq, size_t seq_len)
408 {
409   NtHash nthash(seq, seq_len, get_k(), get_hash_num());
410   while (nthash.roll()) {
411     counting_bloom_filter.insert(nthash.hashes());
412   }
413 }
414 
415 template<typename T>
416 inline uint64_t
contains(const char * seq,size_t seq_len) const417 KmerCountingBloomFilter<T>::contains(const char* seq, size_t seq_len) const
418 {
419   uint64_t count = 0;
420   NtHash nthash(seq, seq_len, get_k(), get_hash_num());
421   while (nthash.roll()) {
422     count += counting_bloom_filter.contains(nthash.hashes());
423   }
424   return count;
425 }
426 
427 template<typename T>
KmerCountingBloomFilter(const std::string & path)428 inline KmerCountingBloomFilter<T>::KmerCountingBloomFilter(
429   const std::string& path)
430 {
431   std::ifstream file(path);
432 
433   auto table =
434     BloomFilter::parse_header(file, KMER_COUNTING_BLOOM_FILTER_MAGIC_HEADER);
435   counting_bloom_filter.bytes =
436     *table->get_as<decltype(counting_bloom_filter.bytes)>("bytes");
437   check_warning(sizeof(uint8_t) != sizeof(std::atomic<uint8_t>),
438                 "Atomic primitives take extra memory. CountingBloomFilter will "
439                 "have less than " +
440                   std::to_string(get_bytes()) + " for bit array.");
441   counting_bloom_filter.array_size =
442     get_bytes() / sizeof(counting_bloom_filter.array[0]);
443   counting_bloom_filter.hash_num =
444     *table->get_as<decltype(counting_bloom_filter.hash_num)>("hash_num");
445   k = *table->get_as<decltype(k)>("k");
446   check_error(sizeof(T) * CHAR_BIT != *table->get_as<size_t>("counter_bits"),
447               "CountingBloomFilter" + std::to_string(sizeof(T) * CHAR_BIT) +
448                 " tried to load a file of CountingBloomFilter" +
449                 std::to_string(*table->get_as<size_t>("counter_bits")));
450 
451   counting_bloom_filter.array =
452     new std::atomic<T>[counting_bloom_filter.array_size];
453   file.read((char*)counting_bloom_filter.array,
454             counting_bloom_filter.array_size *
455               sizeof(counting_bloom_filter.array[0]));
456 }
457 
458 template<typename T>
459 inline void
write(const std::string & path)460 KmerCountingBloomFilter<T>::write(const std::string& path)
461 {
462   std::ofstream file(path.c_str(), std::ios::out | std::ios::binary);
463 
464   /* Initialize cpptoml root table
465     Note: Tables and fields are unordered
466     Ordering of table is maintained by directing the table
467     to the output stream immediately after completion  */
468   auto root = cpptoml::make_table();
469 
470   /* Initialize bloom filter section and insert fields
471       and output to ostream */
472   auto header = cpptoml::make_table();
473   header->insert("bytes", get_bytes());
474   header->insert("hash_num", get_hash_num());
475   header->insert("counter_bits",
476                  size_t(sizeof(counting_bloom_filter.array[0]) * CHAR_BIT));
477   header->insert("k", k);
478   root->insert(KMER_COUNTING_BLOOM_FILTER_MAGIC_HEADER, header);
479   file << *root << "[HeaderEnd]\n";
480 
481   file.write((char*)counting_bloom_filter.array,
482              counting_bloom_filter.array_size *
483                sizeof(counting_bloom_filter.array[0]));
484 }
485 
486 } // namespace btllib
487 
488 #endif
489