1 /* 2 * MIBFQuerySupport.hpp 3 * 4 * Purpose: To provide support for complex classification 5 * 6 * Functions for most accurate classification and faster heuristic classification 7 * are split into sections. 8 * 9 * Contains support objects intended to be private per thread (copied) 10 * 11 * 12 * Created on: Jun 6, 2018 13 * Author: justin 14 */ 15 16 #ifndef MIBFQUERYSUPPORT_HPP_ 17 #define MIBFQUERYSUPPORT_HPP_ 18 19 #include "MIBloomFilter.hpp" 20 //#include <set> 21 #include "vendor/ntHashIterator.hpp" 22 #include "vendor/stHashIterator.hpp" 23 #include <boost/math/distributions/binomial.hpp> 24 25 using namespace std; 26 using boost::math::binomial; 27 28 // T = T type, H = rolling hash itr 29 template<typename T> 30 class MIBFQuerySupport 31 { 32 public: MIBFQuerySupport(const MIBloomFilter<T> & miBF,const vector<double> & perFrameProb,unsigned extraCount,unsigned extraFrameLimit,unsigned maxMiss,unsigned minCount,bool bestHitAgree)33 MIBFQuerySupport( 34 const MIBloomFilter<T>& miBF, 35 const vector<double>& perFrameProb, 36 unsigned extraCount, 37 unsigned extraFrameLimit, 38 unsigned maxMiss, 39 unsigned minCount, 40 bool bestHitAgree) 41 : m_miBF(miBF) 42 , m_perFrameProb(perFrameProb) 43 , m_extraCount(extraCount) 44 , m_extraFrameLimit(extraFrameLimit) 45 , m_maxMiss(maxMiss) 46 , m_minCount(minCount) 47 , m_bestHitAgree(bestHitAgree) 48 , m_satCount(0) 49 , m_evalCount(0) 50 , m_bestCounts({ 0, 0, 0, 0, 0, 0, 0 }) 51 , m_secondBestNonSatFrameCount(0) 52 , m_rankPos(miBF.getHashNum()) 53 , m_hits(miBF.getHashNum(), true) 54 , m_counts(vector<CountResult>(perFrameProb.size(), { 0, 0, 0, 0, 0, 0, 0 })) 55 , m_totalReads(0) 56 { 57 // this should always be a small array 58 m_seenSet.reserve(miBF.getHashNum()); 59 } 60 61 struct QueryResult 62 { 63 T id; 64 uint16_t count; 65 uint16_t nonSatCount; 66 uint16_t totalCount; 67 uint16_t totalNonSatCount; 68 uint16_t nonSatFrameCount; 69 uint16_t solidCount; 70 double frameProb; 71 }; 72 73 struct CountResult 74 { 75 uint16_t count; 76 uint16_t nonSatCount; 77 uint16_t totalCount; 78 uint16_t totalNonSatCount; 79 uint16_t nonSatFrameCount; 80 uint16_t solidCount; 81 size_t readCount; // determines if count should be reset 82 }; 83 84 // For returning an empty result emptyResult()85 const vector<QueryResult>& emptyResult() 86 { 87 init(); 88 return m_signifResults; 89 } 90 91 /* 92 * totalTrials = number of possible trials that can be checked 93 */ 94 template<typename ITR> query(ITR & itr,const vector<unsigned> & minCount)95 const vector<QueryResult>& query(ITR& itr, const vector<unsigned>& minCount) 96 { 97 init(); 98 99 unsigned extraFrame = 0; 100 bool candidateFound = false; 101 102 while (itr != itr.end() && !candidateFound) { 103 candidateFound = updateCounts(itr, minCount, extraFrame); 104 ++itr; 105 } 106 summarizeCandiates(); 107 108 return m_signifResults; 109 } 110 111 template<typename ITR> query(ITR & itr1,ITR & itr2,const vector<unsigned> & minCount)112 const vector<QueryResult>& query(ITR& itr1, ITR& itr2, const vector<unsigned>& minCount) 113 { 114 init(); 115 116 unsigned extraFrame = 0; 117 unsigned frameCount = 0; 118 bool candidateFound = false; 119 120 while ((itr1 != itr1.end() || itr2 != itr2.end()) && !candidateFound) { 121 auto& itr = 122 frameCount % 2 == 0 && itr1 != itr1.end() ? itr1 : itr2 != itr2.end() ? itr2 : itr1; 123 candidateFound = updateCounts(itr, minCount, extraFrame); 124 ++itr; 125 ++frameCount; 126 } 127 summarizeCandiates(); 128 129 return m_signifResults; 130 } 131 getSatCount() const132 unsigned getSatCount() const { return m_satCount; } 133 getEvalCount() const134 unsigned getEvalCount() const { return m_evalCount; } 135 136 // debugging functions: 137 printAllCounts(const vector<string> & ids)138 void printAllCounts(const vector<string>& ids) 139 { 140 for (size_t i = 0; i < m_counts.size(); ++i) { 141 if (m_counts[i].totalCount > 0) { 142 cout << i << "\t" << ids[i] << "\t" << m_counts[i].nonSatFrameCount << "\t" 143 << m_counts[i].count << "\t" << m_counts[i].solidCount << "\t" 144 << m_counts[i].nonSatCount << "\t" << m_counts[i].totalNonSatCount << "\t" 145 << m_counts[i].totalCount << "\n"; 146 } 147 } 148 } 149 150 /* 151 * Debugging 152 * Computes criteria used for judging a read consisting of: 153 * Position of matches 154 * Number of actually evaluated k-mers 155 * Return count of matching k-mers to set 156 */ 157 // TODO saturation not handle correctly getMatchSignature(const string & seq,unsigned & evaluatedSeeds,vector<vector<pair<T,bool>>> & hitsPattern)158 inline vector<unsigned> getMatchSignature( 159 const string& seq, 160 unsigned& evaluatedSeeds, 161 vector<vector<pair<T, bool>>>& hitsPattern) 162 { 163 vector<unsigned> matchPos; 164 matchPos.reserve(seq.size() - m_miBF.getKmerSize()); 165 166 if (m_miBF.getSeedValues().size() > 0) { 167 stHashIterator itr( 168 seq, m_miBF.getSeedValues(), m_miBF.getHashNum(), m_miBF.getKmerSize()); 169 while (itr != itr.end()) { 170 if (m_maxMiss >= m_miBF.atRank(*itr, m_rankPos, m_hits, m_maxMiss)) { 171 vector<T> results = m_miBF.getData(m_rankPos); 172 vector<pair<T, bool>> processedResults(results.size(), pair<T, bool>(0, false)); 173 for (unsigned i = 0; i < m_miBF.getHashNum(); ++i) { 174 if (m_hits[i]) { 175 T tempResult = results[i]; 176 if (tempResult > MIBloomFilter<T>::s_mask) { 177 processedResults[i] = 178 pair<T, bool>(tempResult & MIBloomFilter<T>::s_antiMask, true); 179 } else { 180 processedResults[i] = 181 pair<T, bool>(tempResult & MIBloomFilter<T>::s_antiMask, false); 182 } 183 } 184 } 185 matchPos.push_back(itr.pos()); 186 hitsPattern.push_back(processedResults); 187 } 188 ++itr; 189 ++evaluatedSeeds; 190 } 191 } else { 192 ntHashIterator itr(seq, m_miBF.getHashNum(), m_miBF.getKmerSize()); 193 while (itr != itr.end()) { 194 if (m_miBF.atRank(*itr, m_rankPos)) { 195 vector<T> results = m_miBF.getData(m_rankPos); 196 vector<pair<T, bool>> processedResults(results.size(), pair<T, bool>(0, false)); 197 if (results.size() > 0) { 198 for (unsigned i = 0; i < m_miBF.getHashNum(); ++i) { 199 T tempResult = results[i]; 200 if (tempResult > MIBloomFilter<T>::s_mask) { 201 processedResults[i] = 202 pair<T, bool>(tempResult & MIBloomFilter<T>::s_antiMask, true); 203 } else { 204 processedResults[i] = 205 pair<T, bool>(tempResult & MIBloomFilter<T>::s_antiMask, false); 206 } 207 } 208 matchPos.push_back(itr.pos()); 209 hitsPattern.push_back(processedResults); 210 } 211 } 212 ++itr; 213 ++evaluatedSeeds; 214 } 215 } 216 return matchPos; 217 } 218 219 private: 220 /* 221 * Sort in order of 222 * nonSatFrameCount 223 * count 224 * solidCount 225 * nonSatCount 226 * totalNonSatCount 227 * totalCount 228 * frameProb 229 */ sortCandidates(const QueryResult & a,const QueryResult & b)230 static inline bool sortCandidates(const QueryResult& a, const QueryResult& b) 231 { 232 return ( 233 b.nonSatFrameCount == a.nonSatFrameCount 234 ? (b.count == a.count 235 ? (b.solidCount == a.solidCount 236 ? (b.nonSatCount == a.nonSatCount 237 ? (b.totalNonSatCount == a.totalNonSatCount 238 ? (b.totalCount == a.totalCount 239 ? (a.frameProb > b.frameProb) 240 : a.totalCount > b.totalCount) 241 : a.totalNonSatCount > b.totalNonSatCount) 242 : a.nonSatCount > b.nonSatCount) 243 : a.solidCount > b.solidCount) 244 : a.count > b.count) 245 : a.nonSatFrameCount > b.nonSatFrameCount); 246 } 247 248 // static inline bool sortCandidates(const QueryResult &a, 249 // const QueryResult &b, unsigned extraCount ) { 250 // return (isRoughlyEqual(b.count, a.count, extraCount) ? 251 // (isRoughlyEqual(b.totalNonSatCount, a.totalNonSatCount, extraCount) ? 252 // (isRoughlyEqual(b.nonSatFrameCount, a.nonSatFrameCount, extraCount) ? 253 // (isRoughlyEqual(b.solidCount, a.solidCount, extraCount) ? 254 // (isRoughlyEqual(b.nonSatCount, a.nonSatCount, extraCount) ? 255 // (isRoughlyEqual(b.totalCount, a.totalCount, extraCount) ? 256 // (a.frameProb > b.frameProb) : 257 // a.totalCount > b.totalCount) : 258 // a.nonSatCount > b.nonSatCount) : 259 // a.solidCount > b.solidCount) : 260 // a.nonSatFrameCount > b.nonSatFrameCount) : 261 // a.totalNonSatCount > b.totalNonSatCount) : 262 // a.count > b.count); 263 // } 264 265 // static inline bool sortCandidates(const QueryResult &a, 266 // const QueryResult &b ) { 267 // return (compareStdErr(b.count, a.count) ? 268 // (compareStdErr(b.totalNonSatCount, a.totalNonSatCount) ? 269 // (compareStdErr(b.nonSatFrameCount, a.nonSatFrameCount) ? 270 // (compareStdErr(b.solidCount, a.solidCount) ? 271 // (compareStdErr(b.nonSatCount, a.nonSatCount) ? 272 // (compareStdErr(b.totalCount, a.totalCount) ? 273 // (a.frameProb > b.frameProb) : 274 // a.totalCount > b.totalCount) : 275 // a.nonSatCount > b.nonSatCount) : 276 // a.solidCount > b.solidCount) : 277 // a.nonSatFrameCount > b.nonSatFrameCount) : 278 // a.totalNonSatCount > b.totalNonSatCount) : 279 // a.count > b.count); 280 // } 281 282 /* 283 * Returns true if considered roughly equal 284 */ isRoughlyEqual(unsigned a,unsigned b,unsigned extraCount)285 static inline bool isRoughlyEqual(unsigned a, unsigned b, unsigned extraCount) 286 { 287 if (a > b) { 288 return a <= b + extraCount; 289 } 290 return b <= a + extraCount; 291 } 292 293 /* 294 * Returns true if considered roughly equal 295 */ compareStdErr(unsigned a,unsigned b)296 static inline bool compareStdErr(unsigned a, unsigned b) 297 { 298 double stderrA = sqrt(a); 299 double stderrB = sqrt(b); 300 if (a > b) { 301 return (double(a) - stderrA) <= (double(b) + stderrB); 302 } 303 return (double(b) - stderrB) <= (double(a) + stderrA); 304 } 305 306 /* 307 * Returns true if considered roughly equal or b is larger 308 */ compareStdErrLarger(unsigned a,unsigned b) const309 inline bool compareStdErrLarger(unsigned a, unsigned b) const 310 { 311 double stderrA = sqrt(a) * m_extraCount; 312 double stderrB = sqrt(b) * m_extraCount; 313 return (double(a) - stderrA) <= (double(b) + stderrB); 314 } 315 316 /* 317 * Returns true if considered roughly equal 318 */ isRoughlyEqual(const CountResult & a,const CountResult & b,unsigned extraCount) const319 bool isRoughlyEqual(const CountResult& a, const CountResult& b, unsigned extraCount) const 320 { 321 return ( 322 isRoughlyEqual(b.count, a.count, extraCount) && 323 isRoughlyEqual(b.totalNonSatCount, a.totalNonSatCount, extraCount) && 324 isRoughlyEqual(b.nonSatFrameCount, a.nonSatFrameCount, extraCount) && 325 isRoughlyEqual(b.solidCount, a.solidCount, extraCount) && 326 isRoughlyEqual(b.nonSatCount, a.nonSatCount, extraCount) && 327 isRoughlyEqual(b.totalCount, a.totalCount, extraCount)); 328 } 329 330 /* 331 * Returns true if considered roughly equal 332 */ isValid(const CountResult & a,const CountResult & b) const333 bool isValid(const CountResult& a, const CountResult& b) const 334 { 335 return ( 336 compareStdErr(b.count, a.count) || 337 compareStdErr(b.totalNonSatCount, a.totalNonSatCount) || 338 compareStdErr(b.nonSatFrameCount, a.nonSatFrameCount) || 339 compareStdErr(b.solidCount, a.solidCount) || 340 compareStdErr(b.nonSatCount, a.nonSatCount) || 341 compareStdErr(b.totalCount, a.totalCount)); 342 } 343 344 /* 345 * Returns true if considered roughly equal 346 */ isRoughlyEqualOrLarger(const QueryResult & a,const QueryResult & b) const347 bool isRoughlyEqualOrLarger(const QueryResult& a, const QueryResult& b) const 348 { 349 return ( 350 compareStdErrLarger(a.count, b.count) && 351 compareStdErrLarger(a.totalNonSatCount, b.totalNonSatCount) && 352 compareStdErrLarger(a.nonSatFrameCount, b.nonSatFrameCount) && 353 compareStdErrLarger(a.solidCount, b.solidCount) && 354 compareStdErrLarger(a.nonSatCount, b.nonSatCount) && 355 compareStdErrLarger(a.totalCount, b.totalCount)); 356 } 357 checkCountAgreement(QueryResult b,QueryResult a)358 bool checkCountAgreement(QueryResult b, QueryResult a) 359 { 360 return ( 361 b.nonSatFrameCount >= a.nonSatFrameCount && b.count >= a.count && 362 b.solidCount >= a.solidCount && b.nonSatCount >= a.nonSatCount && 363 b.totalNonSatCount >= a.totalNonSatCount && b.totalCount >= a.totalCount); 364 } 365 366 // contains reference to parent 367 const MIBloomFilter<T>& m_miBF; 368 const vector<double>& m_perFrameProb; 369 370 // not references, but shared other objects or static variables 371 const double m_extraCount; 372 const unsigned m_extraFrameLimit; 373 const unsigned m_maxMiss; 374 const unsigned m_minCount; 375 const bool m_bestHitAgree; 376 // const double m_rateSaturated; 377 378 // resusable variables 379 unsigned m_satCount; 380 unsigned m_evalCount; 381 382 // current bestCounts 383 CountResult m_bestCounts; 384 uint16_t m_secondBestNonSatFrameCount; 385 386 // resusable objects 387 vector<uint64_t> m_rankPos; 388 vector<bool> m_hits; 389 vector<QueryResult> m_signifResults; 390 vector<CountResult> m_counts; 391 vector<T> m_candidateMatches; 392 vector<T> m_seenSet; 393 394 // Number of reads processed by object 395 size_t m_totalReads; 396 397 bool updateCounts(const stHashIterator & itr,const vector<unsigned> & minCount,unsigned & extraFrame)398 updateCounts(const stHashIterator& itr, const vector<unsigned>& minCount, unsigned& extraFrame) 399 { 400 bool candidateFound = false; 401 unsigned misses = m_miBF.atRank(*itr, m_rankPos, m_hits, m_maxMiss); 402 if (misses <= m_maxMiss) { 403 candidateFound = updatesCounts(minCount, extraFrame, misses); 404 } 405 return candidateFound; 406 } 407 408 bool updateCounts(const ntHashIterator & itr,const vector<unsigned> & minCount,unsigned & extraFrame)409 updateCounts(const ntHashIterator& itr, const vector<unsigned>& minCount, unsigned& extraFrame) 410 { 411 bool candidateFound = false; 412 if (m_miBF.atRank(*itr, m_rankPos)) { 413 candidateFound = updatesCounts(minCount, extraFrame); 414 } 415 ++m_evalCount; 416 return candidateFound; 417 } 418 init()419 void init() 420 { 421 m_candidateMatches.clear(); 422 m_signifResults.clear(); 423 m_satCount = 0; 424 m_evalCount = 0; 425 m_bestCounts = { 0, 0, 0, 0, 0, 0, 0 }; 426 m_secondBestNonSatFrameCount = 0; 427 ++m_totalReads; 428 } 429 updatesCounts(const vector<unsigned> & minCount,unsigned & extraFrame,unsigned misses=0)430 bool updatesCounts(const vector<unsigned>& minCount, unsigned& extraFrame, unsigned misses = 0) 431 { 432 m_seenSet.clear(); 433 unsigned satCount = 0; 434 for (unsigned i = 0; i < m_miBF.getHashNum(); ++i) { 435 if (m_hits[i]) { 436 T resultRaw = m_miBF.getData(m_rankPos[i]); 437 ++m_evalCount; 438 bool saturated = false; 439 T result = resultRaw; 440 441 // check for saturation 442 if (result > m_miBF.s_mask) { 443 result &= m_miBF.s_antiMask; 444 saturated = true; 445 satCount++; 446 // detemines if count should be reset 447 if (m_totalReads != m_counts[result].readCount) { 448 m_counts[result] = { 0, 0, 0, 0, 0, 0, m_totalReads }; 449 } 450 } else { 451 if (m_totalReads != m_counts[result].readCount) { 452 m_counts[result] = { 0, 0, 0, 0, 0, 0, m_totalReads }; 453 } 454 ++m_counts[result].totalNonSatCount; 455 } 456 ++m_counts[result].totalCount; 457 if (find(m_seenSet.begin(), m_seenSet.end(), resultRaw) == m_seenSet.end()) { 458 // check for saturation 459 if (saturated) { 460 // if the non-saturated version has not been seen before 461 if (find(m_seenSet.begin(), m_seenSet.end(), result) == m_seenSet.end()) { 462 // check is count is exceeded 463 ++m_counts[result].count; 464 } 465 } else { 466 ++m_counts[result].nonSatCount; 467 // check is count is exceeded 468 ++m_counts[result].count; 469 } 470 m_seenSet.push_back(resultRaw); 471 } 472 } 473 } 474 if (satCount == 0) { 475 for (typename vector<T>::iterator itr = m_seenSet.begin(); itr != m_seenSet.end(); 476 ++itr) { 477 ++m_counts[*itr].nonSatFrameCount; 478 if (misses == 0) { 479 ++m_counts[*itr].solidCount; 480 } 481 } 482 } else { 483 ++m_satCount; 484 } 485 for (typename vector<T>::iterator itr = m_seenSet.begin(); itr != m_seenSet.end(); ++itr) { 486 T result = *itr; 487 if (result > m_miBF.s_mask) { 488 // if non-saturated version already exists 489 if (find(m_seenSet.begin(), m_seenSet.end(), result & m_miBF.s_antiMask) != 490 m_seenSet.end()) { 491 continue; 492 } 493 result &= m_miBF.s_antiMask; 494 } 495 if (m_counts[result].count >= minCount[result]) { 496 if (find(m_candidateMatches.begin(), m_candidateMatches.end(), result) == 497 m_candidateMatches.end()) { 498 m_candidateMatches.push_back(result); 499 } 500 updateMaxCounts(m_counts[result]); 501 } else if (m_candidateMatches.size() && m_counts[result].count >= m_bestCounts.count) { 502 if (find(m_candidateMatches.begin(), m_candidateMatches.end(), result) == 503 m_candidateMatches.end()) { 504 m_candidateMatches.push_back(result); 505 } 506 updateMaxCounts(m_counts[result]); 507 } 508 } 509 if (compareStdErr(m_bestCounts.totalNonSatCount, m_secondBestNonSatFrameCount)) { 510 extraFrame = 0; 511 } 512 if (m_bestCounts.nonSatFrameCount > m_secondBestNonSatFrameCount) { 513 if (m_extraFrameLimit < extraFrame++) { 514 return true; 515 } 516 } 517 return false; 518 } 519 updateMaxCounts(const CountResult & count)520 void updateMaxCounts(const CountResult& count) 521 { 522 if (count.nonSatFrameCount > m_bestCounts.nonSatFrameCount) { 523 m_bestCounts.nonSatFrameCount = count.nonSatFrameCount; 524 } else if (count.nonSatFrameCount > m_secondBestNonSatFrameCount) { 525 m_secondBestNonSatFrameCount = count.nonSatFrameCount; 526 } 527 if (count.count > m_bestCounts.count) { 528 m_bestCounts.count = count.count; 529 } 530 if (count.nonSatCount > m_bestCounts.nonSatCount) { 531 m_bestCounts.nonSatCount = count.nonSatCount; 532 } 533 if (count.solidCount > m_bestCounts.solidCount) { 534 m_bestCounts.solidCount = count.solidCount; 535 } 536 if (count.totalCount > m_bestCounts.totalCount) { 537 m_bestCounts.totalCount = count.totalCount; 538 } 539 if (count.totalNonSatCount > m_bestCounts.totalNonSatCount) { 540 m_bestCounts.totalNonSatCount = count.totalNonSatCount; 541 } 542 } 543 544 double calcSat(unsigned evaluatedValues,double singleEventProbSaturted,unsigned saturatedCount)545 calcSat(unsigned evaluatedValues, double singleEventProbSaturted, unsigned saturatedCount) 546 { 547 double probSaturated = 0; 548 if (saturatedCount) { 549 binomial bin(evaluatedValues, singleEventProbSaturted); 550 probSaturated = cdf(bin, saturatedCount - 1); 551 } 552 return probSaturated; 553 } 554 summarizeCandiates()555 void summarizeCandiates() 556 { 557 if (m_candidateMatches.size() && m_minCount <= m_bestCounts.nonSatFrameCount) { 558 vector<QueryResult> signifResults; 559 for (typename vector<T>::const_iterator candidate = m_candidateMatches.begin(); 560 candidate != m_candidateMatches.end(); 561 candidate++) { 562 const CountResult& resultCount = m_counts[*candidate]; 563 if (isValid(resultCount, m_bestCounts)) { 564 QueryResult result; 565 result.id = *candidate; 566 result.count = resultCount.count; 567 result.nonSatCount = resultCount.nonSatCount; 568 result.totalCount = resultCount.totalCount; 569 result.totalNonSatCount = resultCount.totalNonSatCount; 570 result.nonSatFrameCount = resultCount.nonSatFrameCount; 571 result.solidCount = resultCount.solidCount; 572 result.frameProb = m_perFrameProb.at(*candidate); 573 signifResults.push_back(result); 574 } 575 } 576 if (signifResults.size() > 1) { 577 sort(signifResults.begin(), signifResults.end(), sortCandidates); 578 // sort(signifResults.begin(), signifResults.end(), 579 // bind(sortCandidates, placeholders::_1, placeholders::_2, 580 // m_extraCount)); 581 for (typename vector<QueryResult>::iterator candidate = signifResults.begin(); 582 candidate != signifResults.end(); 583 candidate++) { 584 if (isRoughlyEqualOrLarger(signifResults[0], *candidate)) { 585 m_signifResults.push_back(*candidate); 586 } 587 } 588 if (m_bestHitAgree && m_signifResults.size() >= 2 && 589 !checkCountAgreement(m_signifResults[0], m_signifResults[1])) { 590 m_signifResults.clear(); 591 } 592 } else { 593 m_signifResults.push_back(signifResults[0]); 594 } 595 } 596 } 597 }; 598 599 #endif /* MIBFQUERYSUPPORT_HPP_ */ 600