1 #include <boost/lexical_cast.hpp>
2 #include <boost/math/special_functions/binomial.hpp>
3 
4 #include <cmath>
5 #include <cstdlib>
6 #include <fstream>
7 #include <iostream>
8 #include <limits>
9 #include <string>
10 
11 
12 #include "MathUtil.h"
13 #include "StrIdMap.h"
14 #include "DataStructures.h"
15 
16 #ifndef HMM_H_
17 #define HMM_H_
18 
19 class PossibleState;
20 class Hmm;
21 
22 /** A transition between two consecutive Hmm nodes. */
23 class Transition {
24 public:
25 	unsigned long obs;						// observed value transitioning to: this is 1 or 0 for the PHASE_CONCORDANCE method
26 											// it is the reference allele count for the READ_COUNTS method
27 
28 	// extra observations needed for the READ_COUNTS method
29 	unsigned int n1,n2,r1,r2;						// this is coverage at a site and ref_coverage: used for the READ_COUNTS method
30 	bool ref1,ref2;								// true if emitting reference allele, otherwise false
31 
32 	PossibleState* from;					// source node
33 	PossibleState* to;						// destination node
34 
35 	double logDiGamma;						// used for parameter re-estimation
setLogDiGamma(double logDiGamma)36 	void setLogDiGamma(double logDiGamma) 	{ this->logDiGamma = logDiGamma;	}
getLogDiGamma()37 	double getLogDiGamma() const 			{ return this->logDiGamma;}
38 	Transition(PossibleState* from, PossibleState* to, unsigned long obs, unsigned int n1, unsigned int n2, unsigned int r1, unsigned int r2, bool ref1, bool ref2);
39 	~Transition();
40 };
41 
42 /** A node in an Hmm object. */
43 class PossibleState {
44 	int time; 								// The time slot for this node.
45 	Hmm* hmm; 								// The hmm that this node belongs to
46 	unsigned long state; 					// the state
47 	std::vector<Transition*> inTrans; 		// incoming transitions
48 	std::vector<Transition*> outTrans;		// outgoing transitions
49 	double logAlpha; 						// alpha_t(s) = P(e_1:t, x_t=s);
50 	double logBeta;  						// beta_t(s)  = P(e_t+1:T | x_t=s);
51 	double logPosterior;					// posterior probability that node is in this state given observations
52 	double logGamma;						// used for parameter re-estimation
53 	Transition* psi; 						// the last transition of the most probable path that reaches this node
54 public:
getTime()55 	int getTime() 							{ return this->time; 			}
getState()56 	unsigned long getState() const 			{ return this->state;			}
setLogAlpha(double logAlpha)57 	void setLogAlpha(double logAlpha) 		{ this->logAlpha = logAlpha;	}
getLogAlpha()58 	double getLogAlpha() const 				{ return this->logAlpha;		}
setLogBeta(double logBeta)59 	void setLogBeta(double logBeta) 		{ this->logBeta = logBeta;		}
getLogBeta()60 	double getLogBeta() const 				{ return this->logBeta;			}
setLogPosterior(double logPosterior)61 	void setLogPosterior(double logPosterior) 	{ this->logPosterior = logPosterior;	}
getLogPosterior()62 	double getLogPosterior() const 			{ return this->logPosterior;			}
setLogGamma(double logGamma)63 	void setLogGamma(double logGamma) 	{ this->logGamma = logGamma;	}
getLogGamma()64 	double getLogGamma() const 			{ return this->logGamma;}
65 
getPsi()66 	Transition* getPsi() 					{ return this->psi;				}
setPsi(Transition * psi)67 	void setPsi(Transition* psi) 			{ this->psi = psi;				}
getInTrans()68 	std::vector<Transition*>& getInTrans()	{ return this->inTrans;			}
getOutTrans()69 	std::vector<Transition*>& getOutTrans()	{ return this->outTrans;		}
70 
71 	void print();
72 
73 	PossibleState(int time, unsigned long state, Hmm* hmm);
74 	~PossibleState();
75 };
76 
77 /** The possible states at a particular time. */
78 // This kind of notation allows you to use an iterator
79 // on LatentStateZ to loop through possible states.
80 class LatentStateZ : public std::vector<PossibleState*> {
81 	double scaleFactor;
82 public:
setScaleFactor(double scaleFactor)83 	void setScaleFactor(double scaleFactor) { this->scaleFactor = scaleFactor;	}
getScaleFactor()84 	double getScaleFactor() const 			{ return this->scaleFactor;			}
85 
86 	~LatentStateZ();
87 };
88 
89 /** Pseudo Counts */
90 class PseudoCounts {
91 	HmmMap stateCount;
92 	HmmKeyedMaps transCount;
93 	HmmKeyedMaps emitCount;
94 public:
getStateCount()95 	HmmMap& getStateCount() 			{ return this->stateCount;}
getTransCount()96 	HmmKeyedMaps& getTransCount() 		{ return this->transCount;}
getEmitCount()97 	HmmKeyedMaps& getEmitCount()  		{ return this->emitCount;}
98 	void print(StrIdMap& strIdMap);
99 };
100 
101 /** An Hmm object implements the Hidden Markov Model. */
102 class Hmm {
103 	unsigned long initState; 					// the initial state
104 	std::vector<unsigned long> states;			// states (not including the init state)
105 	HmmKeyedMaps transProbs;    				// transition probabilities
106 	HmmKeyedMaps emitProbs;      				// emission probabilities
107 	StrIdMap strIdMap;        					// mapping between strings and integers
108 	std::vector<LatentStateZ*> latentStatesZ; 	// the time steps
109 	std::vector<double> likelihoodProfile;		// likelihood profile (for diagnostics)
110 	double minLogProb;       					// log probabilities lower than this are set to 0
111 	double logProbObs;
112 	bool forwardHasBeenRun;
113 	bool backwardHasBeenRun;
114 	bool forwardBackwardHasBeenRun;
115 	unsigned int numStates;						// number of possible HMM states
116 	std::string obsType;
117 	double epsilon;								// phase error (currently treated as a constant)
118 
119 public:
120 	// possible observed types for the HMM
121 	static const std::string PHASE_CONCORDANCE, RAF_DEVIATION, BAF_DEVIATION;
122 
setInitState(std::string initState)123 	void setInitState(std::string initState)	{ this->initState = this->strIdMap.getId(initState);}
setObsType(std::string obsType)124 	void setObsType(std::string obsType)		{ this->obsType = obsType;	}
getObsType()125 	std::string& getObsType()					{ return this->obsType;		}
getStates()126 	std::vector<unsigned long>& getStates()		{ return this->states;		}
getInitState()127 	unsigned long getInitState()				{ return this->initState;	}
getEmitProbs()128 	HmmKeyedMaps& getEmitProbs()				{ return this->emitProbs;	}
getTransProbs()129 	HmmKeyedMaps& getTransProbs()				{ return this->transProbs;	}
getStrIdMap()130 	StrIdMap& getStrIdMap()						{ return this->strIdMap;	}
setLogProbObs(double logProbObs)131 	void setLogProbObs(double logProbObs)		{ this->logProbObs = logProbObs;	}
getLogProbObs()132 	double getLogProbObs()						{ return this->logProbObs;	}
getLatentStatesZ()133 	std::vector<LatentStateZ*>& getLatentStatesZ()	{ return this->latentStatesZ;	}
getLikelihoodProfile()134 	std::vector<double>& getLikelihoodProfile()	{ return this->likelihoodProfile;	}
135 
setEpsilon(double epsilon)136 	void setEpsilon(double epsilon)		{ this->epsilon = epsilon;	}
getEpsilon()137 	double getEpsilon()						{ return this->epsilon;	}
138 
139 	double getTransProb(Transition* trans);
140 	double getEmitProb(Transition* trans);
141 
hasForwardBeenRun()142 	bool hasForwardBeenRun()					{ return this->forwardHasBeenRun;	}
hasBackwardBeenRun()143 	bool hasBackwardBeenRun()					{ return this->backwardHasBeenRun;	}
hasForwardBackwardBeenRun()144 	bool hasForwardBackwardBeenRun()			{ return this->forwardBackwardHasBeenRun;	}
145 
146 	void forward(bool force=false);  		// compute the forward probabilities P(e_1:t, X_t=s)
147 	void backward(bool force=false); 		// compute the backward probabilities P(e_t+1:T | X_t=s)
148 	void forwardBackward(bool force=false);	// compute forward/backward probabilities
149 
150 	/** Retrieves posterior probabilities in a vector */
151 	std::vector<double> extractPosteriors(unsigned long state);
152 
153 	/** Re-compute the transition and emission probabilities according
154       to the pseudo counts. */
155 	void updateProbs(PseudoCounts& counts);
156 
157 	/** Accumulate pseudo counts using the BaumWelch algorithm.  The
158       return value is the probability of the observations according to
159       the current model.  */
160 	double getPseudoCounts(PseudoCounts& counts);
161 
162 	void estimateProbs(bool estimateTrans, bool estimateAberrantEmit, bool estimateNormalEmit, std::string normalState);
163 
164 	/** Add an observation into the Hmm after the current last time
165       slot. The states that have non-zero probability of generating
166       the observation will be created. The transition between the new
167       states and the states in the previous time slots will also be
168       created.*/
169 	void addObservation(unsigned long obs, unsigned long n1, unsigned long n2, unsigned long r1, unsigned long r2, bool ref1, bool ref2);
170 
171 	/** Same as void addObservation(unsigned long obs) above, except
172       with a different form of parameter. */
173 	// TODO: move implementation to Hmm.cpp
addObservation(std::string obs,unsigned long n1,unsigned long n2,unsigned long r1,unsigned long r2,bool ref1,bool ref2)174 	void addObservation(std::string obs, unsigned long n1, unsigned long n2, unsigned long r1, unsigned long r2, bool ref1, bool ref2) {
175 		addObservation(atoi(obs.c_str()), n1, n2, r1, r2, ref1, ref2);
176 	}
177 
178 	// Dummy values are inserted for the PHASE_CONCORDANCE method
addObservation(std::string obs)179 	void addObservation(std::string obs) {
180 		addObservation(strIdMap.getId(obs), 0, 0, 0, 0, false, false);
181 	}
182 
183 	std::string getObservation(unsigned long time);
184 	std::vector<std::string> getObservations();
185 
186 	/** Read the transition and emission probability tables from the
187       files NAME.trans and NAME.emit, where NAME is the value of the
188       variable name.*/
189 	void loadProbs(std::string initProbFilename, std::string transProbFilename, std::string emitProbFilename);
190 	void loadProbs(double& lambda_0, double& lambda_1, double& alpha_0, double& alpha_1, std::string normalState);
191 //	void loadProbs(double& eventPrevalence, unsigned int& eventLengthMarkers, double& initialParamNormal, double& initialParamEvent, std::string normalState);
192 
193 	/** Save the transition and emission probability tables into the
194       files NAME.trans and NAME.emit, where NAME is the value of the
195       variable name. If name is "", both tables are printed on the
196       standard output. */
197 	void saveProbs(std::string name="");
198 	void writeTrans(std::ostream& file);
199 	void writeEmit(std::ostream& file);
200 
201 	/** Read the training data from the input stream. Each line in the
202       input stream is an observation sequence. */
203 	void readSeqs(
204 			std::string obsSeqFilename, std::vector<std::vector<unsigned long>*>& sequences
205 	);
206 
207 
208 	/** Find the state sequence (a path) that has the maximum
209       probability given the sequence of observations:
210          max_{x_1:T} P(x_1:T | e_1:T);
211       The return value is the logarithm of the joint probability
212       of the state sequence and the observation sequence:
213         log P(x_1:T, e_1:T)
214 	 */
215 	double viterbi(std::vector<Transition*>& path);
216 
217 	/** return the logarithm of the observation sequence: log P(e_1:T) */
218 	double calcLogProbObs();
219 
220 	/** Train the model with the given observation sequences using the
221       Baum-Welch algorithm. */
222 	void baumWelch(std::vector<std::string>& observedSequence, const std::vector<unsigned int>& covSequence, std::vector<bool>& refSequence, unsigned int maxIterations, bool estimateTrans, bool estimateAberrantEmit, bool estimateNormalEmit, std::string normalState);
223 	void baumWelch(std::vector<std::string>& observedSequence, const std::vector<double>& logRRSequence, std::vector<bool>& bAlleleSequence, unsigned int iterations, bool estimateTrans, bool estimateAberrantEmit, bool estimateNormalEmit, std::string normalState);
224 
225 	/** Conversion between the integer id and string form of states and
226       observations. */
getStr(unsigned long id)227 	std::string getStr(unsigned long id) { return this->strIdMap.getStr(id);}
getId(std::string str)228 	unsigned long getId(std::string str) { return this->strIdMap.getId(str);}
229 
230 	/** Clear all time slots to get ready to deal with another
231       sequence. */
232 	void reset();
233 
234 	/** Print the states at all time slots and the alpha/beta values at
235       these states. */
236 	void print();
237 
238 	/** Generate seqs observation sequences according to the model. */
239 	void genSeqs(std::ostream& ostrm, int seqs);
240 
241 	/** Generate an observation sequence with up to maxlen elements
242       according to the model. */
243 	void genSeq(std::vector<unsigned long>& seq);
244 
245 	Hmm(unsigned int numStates, std::string& obsType);
246 	~Hmm();
247 };
248 
249 #endif /* HMM_H_ */
250