1 #include <boost/lexical_cast.hpp> 2 #include <boost/math/special_functions/binomial.hpp> 3 4 #include <cmath> 5 #include <cstdlib> 6 #include <fstream> 7 #include <iostream> 8 #include <limits> 9 #include <string> 10 11 12 #include "MathUtil.h" 13 #include "StrIdMap.h" 14 #include "DataStructures.h" 15 16 #ifndef HMM_H_ 17 #define HMM_H_ 18 19 class PossibleState; 20 class Hmm; 21 22 /** A transition between two consecutive Hmm nodes. */ 23 class Transition { 24 public: 25 unsigned long obs; // observed value transitioning to: this is 1 or 0 for the PHASE_CONCORDANCE method 26 // it is the reference allele count for the READ_COUNTS method 27 28 // extra observations needed for the READ_COUNTS method 29 unsigned int n1,n2,r1,r2; // this is coverage at a site and ref_coverage: used for the READ_COUNTS method 30 bool ref1,ref2; // true if emitting reference allele, otherwise false 31 32 PossibleState* from; // source node 33 PossibleState* to; // destination node 34 35 double logDiGamma; // used for parameter re-estimation setLogDiGamma(double logDiGamma)36 void setLogDiGamma(double logDiGamma) { this->logDiGamma = logDiGamma; } getLogDiGamma()37 double getLogDiGamma() const { return this->logDiGamma;} 38 Transition(PossibleState* from, PossibleState* to, unsigned long obs, unsigned int n1, unsigned int n2, unsigned int r1, unsigned int r2, bool ref1, bool ref2); 39 ~Transition(); 40 }; 41 42 /** A node in an Hmm object. */ 43 class PossibleState { 44 int time; // The time slot for this node. 45 Hmm* hmm; // The hmm that this node belongs to 46 unsigned long state; // the state 47 std::vector<Transition*> inTrans; // incoming transitions 48 std::vector<Transition*> outTrans; // outgoing transitions 49 double logAlpha; // alpha_t(s) = P(e_1:t, x_t=s); 50 double logBeta; // beta_t(s) = P(e_t+1:T | x_t=s); 51 double logPosterior; // posterior probability that node is in this state given observations 52 double logGamma; // used for parameter re-estimation 53 Transition* psi; // the last transition of the most probable path that reaches this node 54 public: getTime()55 int getTime() { return this->time; } getState()56 unsigned long getState() const { return this->state; } setLogAlpha(double logAlpha)57 void setLogAlpha(double logAlpha) { this->logAlpha = logAlpha; } getLogAlpha()58 double getLogAlpha() const { return this->logAlpha; } setLogBeta(double logBeta)59 void setLogBeta(double logBeta) { this->logBeta = logBeta; } getLogBeta()60 double getLogBeta() const { return this->logBeta; } setLogPosterior(double logPosterior)61 void setLogPosterior(double logPosterior) { this->logPosterior = logPosterior; } getLogPosterior()62 double getLogPosterior() const { return this->logPosterior; } setLogGamma(double logGamma)63 void setLogGamma(double logGamma) { this->logGamma = logGamma; } getLogGamma()64 double getLogGamma() const { return this->logGamma;} 65 getPsi()66 Transition* getPsi() { return this->psi; } setPsi(Transition * psi)67 void setPsi(Transition* psi) { this->psi = psi; } getInTrans()68 std::vector<Transition*>& getInTrans() { return this->inTrans; } getOutTrans()69 std::vector<Transition*>& getOutTrans() { return this->outTrans; } 70 71 void print(); 72 73 PossibleState(int time, unsigned long state, Hmm* hmm); 74 ~PossibleState(); 75 }; 76 77 /** The possible states at a particular time. */ 78 // This kind of notation allows you to use an iterator 79 // on LatentStateZ to loop through possible states. 80 class LatentStateZ : public std::vector<PossibleState*> { 81 double scaleFactor; 82 public: setScaleFactor(double scaleFactor)83 void setScaleFactor(double scaleFactor) { this->scaleFactor = scaleFactor; } getScaleFactor()84 double getScaleFactor() const { return this->scaleFactor; } 85 86 ~LatentStateZ(); 87 }; 88 89 /** Pseudo Counts */ 90 class PseudoCounts { 91 HmmMap stateCount; 92 HmmKeyedMaps transCount; 93 HmmKeyedMaps emitCount; 94 public: getStateCount()95 HmmMap& getStateCount() { return this->stateCount;} getTransCount()96 HmmKeyedMaps& getTransCount() { return this->transCount;} getEmitCount()97 HmmKeyedMaps& getEmitCount() { return this->emitCount;} 98 void print(StrIdMap& strIdMap); 99 }; 100 101 /** An Hmm object implements the Hidden Markov Model. */ 102 class Hmm { 103 unsigned long initState; // the initial state 104 std::vector<unsigned long> states; // states (not including the init state) 105 HmmKeyedMaps transProbs; // transition probabilities 106 HmmKeyedMaps emitProbs; // emission probabilities 107 StrIdMap strIdMap; // mapping between strings and integers 108 std::vector<LatentStateZ*> latentStatesZ; // the time steps 109 std::vector<double> likelihoodProfile; // likelihood profile (for diagnostics) 110 double minLogProb; // log probabilities lower than this are set to 0 111 double logProbObs; 112 bool forwardHasBeenRun; 113 bool backwardHasBeenRun; 114 bool forwardBackwardHasBeenRun; 115 unsigned int numStates; // number of possible HMM states 116 std::string obsType; 117 double epsilon; // phase error (currently treated as a constant) 118 119 public: 120 // possible observed types for the HMM 121 static const std::string PHASE_CONCORDANCE, RAF_DEVIATION, BAF_DEVIATION; 122 setInitState(std::string initState)123 void setInitState(std::string initState) { this->initState = this->strIdMap.getId(initState);} setObsType(std::string obsType)124 void setObsType(std::string obsType) { this->obsType = obsType; } getObsType()125 std::string& getObsType() { return this->obsType; } getStates()126 std::vector<unsigned long>& getStates() { return this->states; } getInitState()127 unsigned long getInitState() { return this->initState; } getEmitProbs()128 HmmKeyedMaps& getEmitProbs() { return this->emitProbs; } getTransProbs()129 HmmKeyedMaps& getTransProbs() { return this->transProbs; } getStrIdMap()130 StrIdMap& getStrIdMap() { return this->strIdMap; } setLogProbObs(double logProbObs)131 void setLogProbObs(double logProbObs) { this->logProbObs = logProbObs; } getLogProbObs()132 double getLogProbObs() { return this->logProbObs; } getLatentStatesZ()133 std::vector<LatentStateZ*>& getLatentStatesZ() { return this->latentStatesZ; } getLikelihoodProfile()134 std::vector<double>& getLikelihoodProfile() { return this->likelihoodProfile; } 135 setEpsilon(double epsilon)136 void setEpsilon(double epsilon) { this->epsilon = epsilon; } getEpsilon()137 double getEpsilon() { return this->epsilon; } 138 139 double getTransProb(Transition* trans); 140 double getEmitProb(Transition* trans); 141 hasForwardBeenRun()142 bool hasForwardBeenRun() { return this->forwardHasBeenRun; } hasBackwardBeenRun()143 bool hasBackwardBeenRun() { return this->backwardHasBeenRun; } hasForwardBackwardBeenRun()144 bool hasForwardBackwardBeenRun() { return this->forwardBackwardHasBeenRun; } 145 146 void forward(bool force=false); // compute the forward probabilities P(e_1:t, X_t=s) 147 void backward(bool force=false); // compute the backward probabilities P(e_t+1:T | X_t=s) 148 void forwardBackward(bool force=false); // compute forward/backward probabilities 149 150 /** Retrieves posterior probabilities in a vector */ 151 std::vector<double> extractPosteriors(unsigned long state); 152 153 /** Re-compute the transition and emission probabilities according 154 to the pseudo counts. */ 155 void updateProbs(PseudoCounts& counts); 156 157 /** Accumulate pseudo counts using the BaumWelch algorithm. The 158 return value is the probability of the observations according to 159 the current model. */ 160 double getPseudoCounts(PseudoCounts& counts); 161 162 void estimateProbs(bool estimateTrans, bool estimateAberrantEmit, bool estimateNormalEmit, std::string normalState); 163 164 /** Add an observation into the Hmm after the current last time 165 slot. The states that have non-zero probability of generating 166 the observation will be created. The transition between the new 167 states and the states in the previous time slots will also be 168 created.*/ 169 void addObservation(unsigned long obs, unsigned long n1, unsigned long n2, unsigned long r1, unsigned long r2, bool ref1, bool ref2); 170 171 /** Same as void addObservation(unsigned long obs) above, except 172 with a different form of parameter. */ 173 // TODO: move implementation to Hmm.cpp addObservation(std::string obs,unsigned long n1,unsigned long n2,unsigned long r1,unsigned long r2,bool ref1,bool ref2)174 void addObservation(std::string obs, unsigned long n1, unsigned long n2, unsigned long r1, unsigned long r2, bool ref1, bool ref2) { 175 addObservation(atoi(obs.c_str()), n1, n2, r1, r2, ref1, ref2); 176 } 177 178 // Dummy values are inserted for the PHASE_CONCORDANCE method addObservation(std::string obs)179 void addObservation(std::string obs) { 180 addObservation(strIdMap.getId(obs), 0, 0, 0, 0, false, false); 181 } 182 183 std::string getObservation(unsigned long time); 184 std::vector<std::string> getObservations(); 185 186 /** Read the transition and emission probability tables from the 187 files NAME.trans and NAME.emit, where NAME is the value of the 188 variable name.*/ 189 void loadProbs(std::string initProbFilename, std::string transProbFilename, std::string emitProbFilename); 190 void loadProbs(double& lambda_0, double& lambda_1, double& alpha_0, double& alpha_1, std::string normalState); 191 // void loadProbs(double& eventPrevalence, unsigned int& eventLengthMarkers, double& initialParamNormal, double& initialParamEvent, std::string normalState); 192 193 /** Save the transition and emission probability tables into the 194 files NAME.trans and NAME.emit, where NAME is the value of the 195 variable name. If name is "", both tables are printed on the 196 standard output. */ 197 void saveProbs(std::string name=""); 198 void writeTrans(std::ostream& file); 199 void writeEmit(std::ostream& file); 200 201 /** Read the training data from the input stream. Each line in the 202 input stream is an observation sequence. */ 203 void readSeqs( 204 std::string obsSeqFilename, std::vector<std::vector<unsigned long>*>& sequences 205 ); 206 207 208 /** Find the state sequence (a path) that has the maximum 209 probability given the sequence of observations: 210 max_{x_1:T} P(x_1:T | e_1:T); 211 The return value is the logarithm of the joint probability 212 of the state sequence and the observation sequence: 213 log P(x_1:T, e_1:T) 214 */ 215 double viterbi(std::vector<Transition*>& path); 216 217 /** return the logarithm of the observation sequence: log P(e_1:T) */ 218 double calcLogProbObs(); 219 220 /** Train the model with the given observation sequences using the 221 Baum-Welch algorithm. */ 222 void baumWelch(std::vector<std::string>& observedSequence, const std::vector<unsigned int>& covSequence, std::vector<bool>& refSequence, unsigned int maxIterations, bool estimateTrans, bool estimateAberrantEmit, bool estimateNormalEmit, std::string normalState); 223 void baumWelch(std::vector<std::string>& observedSequence, const std::vector<double>& logRRSequence, std::vector<bool>& bAlleleSequence, unsigned int iterations, bool estimateTrans, bool estimateAberrantEmit, bool estimateNormalEmit, std::string normalState); 224 225 /** Conversion between the integer id and string form of states and 226 observations. */ getStr(unsigned long id)227 std::string getStr(unsigned long id) { return this->strIdMap.getStr(id);} getId(std::string str)228 unsigned long getId(std::string str) { return this->strIdMap.getId(str);} 229 230 /** Clear all time slots to get ready to deal with another 231 sequence. */ 232 void reset(); 233 234 /** Print the states at all time slots and the alpha/beta values at 235 these states. */ 236 void print(); 237 238 /** Generate seqs observation sequences according to the model. */ 239 void genSeqs(std::ostream& ostrm, int seqs); 240 241 /** Generate an observation sequence with up to maxlen elements 242 according to the model. */ 243 void genSeq(std::vector<unsigned long>& seq); 244 245 Hmm(unsigned int numStates, std::string& obsType); 246 ~Hmm(); 247 }; 248 249 #endif /* HMM_H_ */ 250