package org.broadinstitute.hellbender.tools.walkers.vqsr;

import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import Jama.Matrix;

/**
 * A variational-Bayes Gaussian mixture model over variant annotation vectors, used by VQSR.
 *
 * <p>Lifecycle: construct, {@link #initializeRandomModel}, then alternate
 * {@link #expectationStep} / {@link #maximizationStep} until converged, then
 * {@link #evaluateFinalModelParameters} and {@link #precomputeDenominatorForEvaluation}
 * before scoring data with {@link #evaluateDatum}.</p>
 */
class GaussianMixtureModel {

    protected final static Logger logger = LogManager.getLogger(GaussianMixtureModel.class);

    /** The mixture components; size is fixed at construction. */
    private final List<MultivariateGaussian> gaussians;
    /** Hyperparameter: shrinkage on the Gaussian means (beta_0 in VB notation). */
    private final double shrinkage;
    /** Hyperparameter: Dirichlet prior on the mixture weights. */
    private final double dirichletParameter;
    /** Hyperparameter: prior pseudo-counts per component. */
    private final double priorCounts;
    /** Empirical prior mean over annotations (all zeros; annotations are normalized upstream). */
    private final double[] empiricalMu;
    /** Empirical prior covariance over annotations (scaled identity; see constructor). */
    private final Matrix empiricalSigma;
    /** Set true once {@link #precomputeDenominatorForEvaluation} has run. */
    public boolean isModelReadyForEvaluation;
    /** Flagged externally when EM fails to converge. */
    public boolean failedToConverge = false;

    /**
     * Builds a model with {@code numGaussians} freshly-allocated components.
     *
     * @param numGaussians       number of mixture components
     * @param numVariantData     number of data points each component must track responsibilities for
     * @param numAnnotations     dimensionality of the annotation space
     * @param shrinkage          mean-shrinkage hyperparameter
     * @param dirichletParameter Dirichlet hyperparameter for mixture weights
     * @param priorCounts        prior pseudo-counts hyperparameter
     */
    public GaussianMixtureModel( final int numGaussians, final int numVariantData, final int numAnnotations,
                                 final double shrinkage, final double dirichletParameter, final double priorCounts ) {

        gaussians = new ArrayList<>( numGaussians );
        for( int iii = 0; iii < numGaussians; iii++ ) {
            final MultivariateGaussian gaussian = new MultivariateGaussian( numVariantData, numAnnotations );
            gaussians.add( gaussian );
        }
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        empiricalMu = new double[numAnnotations]; // Java zero-initializes, so mu prior is the origin
        empiricalSigma = new Matrix(numAnnotations, numAnnotations);
        isModelReadyForEvaluation = false;
        // Prior covariance is (200 * I)^-1, i.e. a diagonal matrix with 1/200 on the diagonal.
        empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1,
                Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse());
    }

    /**
     * Builds a model around pre-constructed components; used when reading a serialized
     * model back in for the model-output unit test.
     *
     * @param gaussians pre-built components; must be non-empty (dimensionality is taken from the first)
     */
    protected GaussianMixtureModel(final List<MultivariateGaussian> gaussians, final double shrinkage,
                                   final double dirichletParameter, final double priorCounts ) {
        this.gaussians = gaussians;
        final int numAnnotations = gaussians.get(0).mu.length;
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        empiricalMu = new double[numAnnotations]; // zero vector prior mean
        empiricalSigma = new Matrix(numAnnotations, numAnnotations);
        isModelReadyForEvaluation = false;
        // Same (200 * I)^-1 prior covariance as the primary constructor.
        empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1,
                Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse());
    }

    /**
     * Randomly initializes the component means, refines them with k-means, then sets uniform
     * mixture weights, random covariances, and the prior hyperparameters on every component.
     *
     * @param data               the training data
     * @param numKMeansIterations how many k-means refinement passes to run
     */
    public void initializeRandomModel( final List<VariantDatum> data, final int numKMeansIterations ) {

        // initialize random Gaussian means
        // BUGBUG: this is broken up this way to match the order of calls to rand.nextDouble() in the old code
        for( final MultivariateGaussian gaussian : gaussians ) {
            gaussian.initializeRandomMu( Utils.getRandomGenerator() );
        }

        // initialize means using K-means algorithm
        logger.info( "Initializing model with " + numKMeansIterations + " k-means iterations..." );
        initializeMeansUsingKMeans( data, numKMeansIterations );

        // initialize uniform mixture coefficients, random covariance matrices, and initial hyperparameters
        for( final MultivariateGaussian gaussian : gaussians ) {
            gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double) gaussians.size()) );
            gaussian.sumProb = 1.0 / ((double) gaussians.size());
            gaussian.initializeRandomSigma( Utils.getRandomGenerator() );
            gaussian.hyperParameter_a = priorCounts;
            gaussian.hyperParameter_b = shrinkage;
            gaussian.hyperParameter_lambda = dirichletParameter;
        }
    }

    /**
     * Classic hard-assignment k-means over the component means: E-step assigns each datum to its
     * nearest component, M-step recomputes each mean from its assigned data (re-randomizing any
     * component that attracted no data).
     */
    private void initializeMeansUsingKMeans( final List<VariantDatum> data, final int numIterations ) {

        int ttt = 0;
        while( ttt++ < numIterations ) {
            // E step: assign each variant to the nearest cluster
            for( final VariantDatum datum : data ) {
                double minDistance = Double.MAX_VALUE;
                MultivariateGaussian minGaussian = null;
                for( final MultivariateGaussian gaussian : gaussians ) {
                    final double dist = gaussian.calculateDistanceFromMeanSquared( datum );
                    if( dist < minDistance ) {
                        minDistance = dist;
                        minGaussian = gaussian;
                    }
                }
                datum.assignment = minGaussian;
            }

            // M step: update gaussian means based on assigned variants
            for( final MultivariateGaussian gaussian : gaussians ) {
                gaussian.zeroOutMu();
                int numAssigned = 0;

                for( final VariantDatum datum : data ) {
                    if( datum.assignment.equals(gaussian) ) {
                        numAssigned++;
                        gaussian.incrementMu( datum );
                    }
                }
                if( numAssigned != 0 ) {
                    gaussian.divideEqualsMu( ((double) numAssigned) );
                } else {
                    // empty cluster: restart it at a random location rather than leaving a zero mean
                    gaussian.initializeRandomMu( Utils.getRandomGenerator() );
                }
            }
        }
    }

    /**
     * Variational-Bayes E-step: computes each datum's normalized responsibility under every
     * component and accumulates it into the components.
     */
    public void expectationStep( final List<VariantDatum> data ) {

        for( final MultivariateGaussian gaussian : gaussians ) {
            gaussian.precomputeDenominatorForVariationalBayes( getSumHyperParameterLambda() );
        }

        for( final VariantDatum datum : data ) {
            final double[] pVarInGaussianLog10 = gaussians.stream().mapToDouble(g -> g.evaluateDatumLog10(datum)).toArray();
            final double[] pVarInGaussianNormalized = MathUtils.normalizeLog10DeleteMePlease( pVarInGaussianLog10, false);
            int gaussianIndex = 0;
            for( final MultivariateGaussian gaussian : gaussians ) {
                gaussian.assignPVarInGaussian( pVarInGaussianNormalized[gaussianIndex++] );
            }
        }
    }

    /**
     * Variational-Bayes M-step: re-estimates every component's parameters from the
     * responsibilities accumulated by {@link #expectationStep}, regularized by the priors.
     */
    public void maximizationStep( final List<VariantDatum> data ) {
        gaussians.forEach(g -> g.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, priorCounts));
    }

    /** @return the sum of the per-component Dirichlet hyperparameters. */
    private double getSumHyperParameterLambda() {
        return gaussians.stream().mapToDouble(g -> g.hyperParameter_lambda).sum();
    }

    /**
     * Finalizes every component's parameters after EM has converged and renormalizes
     * the mixture weights.
     */
    public void evaluateFinalModelParameters( final List<VariantDatum> data ) {
        gaussians.forEach(g -> g.evaluateFinalModelParameters(data));
        normalizePMixtureLog10();
    }

    /**
     * Renormalizes the components' log10 mixture weights to sum to one.
     *
     * @return the total absolute change in log10 mixture weights; used as a convergence measure
     */
    public double normalizePMixtureLog10() {
        double sumDiff = 0.0;
        final double sumPK = gaussians.stream().mapToDouble(g -> g.sumProb).sum();

        final double log10SumPK = Math.log10(sumPK);
        final double[] pGaussianLog10 = gaussians.stream().mapToDouble(g -> Math.log10(g.sumProb) - log10SumPK).toArray();
        MathUtils.normalizeLog10DeleteMePlease( pGaussianLog10, true);

        int gaussianIndex = 0;
        for( final MultivariateGaussian gaussian : gaussians ) {
            sumDiff += Math.abs( pGaussianLog10[gaussianIndex] - gaussian.pMixtureLog10 );
            gaussian.pMixtureLog10 = pGaussianLog10[gaussianIndex++];
        }
        return sumDiff;
    }

    /**
     * Precomputes each component's evaluation denominator and marks the model ready
     * for {@link #evaluateDatum} calls.
     */
    public void precomputeDenominatorForEvaluation() {
        for( final MultivariateGaussian gaussian : gaussians ) {
            gaussian.precomputeDenominatorForEvaluation();
        }

        isModelReadyForEvaluation = true;
    }

    /**
     * A version of Log10SumLog10 that tolerates NaN values in the array
     *
     * In the case where one or more of the values are NaN, this function returns NaN
     *
     * @param values a non-null vector of doubles
     * @return log10 of the sum of the log10 values, or NaN
     */
    private double nanTolerantLog10SumLog10(final double[] values) {
        for ( final double value : values ) {
            if ( Double.isNaN(value) ) {
                return Double.NaN;
            }
        }
        return MathUtils.log10sumLog10(values);
    }

    /**
     * Scores a datum under the full mixture: log10 of Sum_k(pi_k * p(v|k)).
     * Data with any missing annotation are routed to {@link #evaluateDatumMarginalized}.
     *
     * @return the log10 mixture likelihood, or NaN if any component evaluates to NaN
     */
    public double evaluateDatum( final VariantDatum datum ) {
        for( final boolean isNull : datum.isNull ) {
            if( isNull ) {
                return evaluateDatumMarginalized( datum );
            }
        }
        // Fill an array with the log10 probability coming from each Gaussian and then use MathUtils to sum them up correctly
        final double[] pVarInGaussianLog10 = new double[gaussians.size()];
        int gaussianIndex = 0;
        for( final MultivariateGaussian gaussian : gaussians ) {
            pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum );
        }
        return nanTolerantLog10SumLog10(pVarInGaussianLog10); // Sum(pi_k * p(v|n,k))
    }

    /**
     * Scores the datum along a single annotation dimension using each component's marginal
     * 1-D normal. Used only to decide which covariate dimension is most divergent in order
     * to report in the culprit info field annotation.
     *
     * @return the log10 marginal likelihood, or null when the annotation is missing
     */
    public Double evaluateDatumInOneDimension( final VariantDatum datum, final int iii ) {
        if(datum.isNull[iii]) { return null; }

        final double[] pVarInGaussianLog10 = new double[gaussians.size()];
        int gaussianIndex = 0;
        for( final MultivariateGaussian gaussian : gaussians ) {
            pVarInGaussianLog10[gaussianIndex] = gaussian.pMixtureLog10;
            // Skip the density term for zero-weight components to avoid -Inf + finite arithmetic.
            if (gaussian.pMixtureLog10 != Double.NEGATIVE_INFINITY) {
                pVarInGaussianLog10[gaussianIndex] += MathUtils.normalDistributionLog10(gaussian.mu[iii], gaussian.sigma.get(iii, iii), datum.annotations[iii]);
            }
            gaussianIndex++;
        }
        return nanTolerantLog10SumLog10(pVarInGaussianLog10); // Sum(pi_k * p(v|n,k))
    }

    /**
     * Scores a datum with missing annotations by Monte-Carlo marginalization: each missing
     * dimension is filled with standard-normal draws and the resulting likelihoods averaged.
     * NOTE(review): this overwrites {@code datum.annotations[iii]} for missing dimensions and
     * does not restore the prior value — callers appear to rely only on {@code isNull}, but
     * confirm before reusing the datum elsewhere.
     *
     * @return the log10 of the averaged likelihood; NaN if the datum has no missing annotations
     *         (division by a zero draw count)
     */
    public double evaluateDatumMarginalized( final VariantDatum datum ) {
        int numRandomDraws = 0;
        double sumPVarInGaussian = 0.0;
        final int numIterPerMissingAnnotation = 20; // Trade off here between speed of computation and accuracy of the marginalization
        final double[] pVarInGaussianLog10 = new double[gaussians.size()];
        // for each dimension
        for( int iii = 0; iii < datum.annotations.length; iii++ ) {
            // if it is missing marginalize over the missing dimension by drawing X random values for the missing annotation and averaging the lod
            if( datum.isNull[iii] ) {
                for( int ttt = 0; ttt < numIterPerMissingAnnotation; ttt++ ) {
                    datum.annotations[iii] = Utils.getRandomGenerator().nextGaussian(); // draw a random sample from the standard normal distribution

                    // evaluate this random data point
                    int gaussianIndex = 0;
                    for( final MultivariateGaussian gaussian : gaussians ) {
                        pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum );
                    }

                    // add this sample's probability to the pile in order to take an average in the end
                    sumPVarInGaussian += Math.pow(10.0, nanTolerantLog10SumLog10(pVarInGaussianLog10)); // p = 10 ^ Sum(pi_k * p(v|n,k))
                    numRandomDraws++;
                }
            }
        }
        return Math.log10( sumPVarInGaussian / ((double) numRandomDraws) );
    }

    /** @return an unmodifiable view of the mixture components. */
    protected List<MultivariateGaussian> getModelGaussians() {return Collections.unmodifiableList(gaussians);}

    /** @return the dimensionality of the annotation space this model was built for. */
    protected int getNumAnnotations() {return empiricalMu.length;}
}