package org.broadinstitute.hellbender.tools.walkers.vqsr;

import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import Jama.Matrix;

/**
 * A variational-Bayes Gaussian mixture model over variant annotation vectors, used by VQSR to
 * model the distribution of variants in annotation space.
 *
 * Lifecycle: construct, {@link #initializeRandomModel}, then alternate
 * {@link #expectationStep}/{@link #maximizationStep} until converged, then
 * {@link #evaluateFinalModelParameters} and {@link #precomputeDenominatorForEvaluation}
 * before calling any of the {@code evaluateDatum*} methods.
 */
class GaussianMixtureModel {

    protected final static Logger logger = LogManager.getLogger(GaussianMixtureModel.class);

    // The mixture components; the list size is fixed after construction.
    private final List<MultivariateGaussian> gaussians;
    // Shared hyperparameters of the variational-Bayes priors.
    private final double shrinkage;
    private final double dirichletParameter;
    private final double priorCounts;
    // Empirical prior over component means (all zeros) and covariances ((1/200) * I),
    // consumed by MultivariateGaussian.maximizeGaussian during the M step.
    private final double[] empiricalMu;
    private final Matrix empiricalSigma;
    // Set true by precomputeDenominatorForEvaluation once the model may be used for scoring.
    public boolean isModelReadyForEvaluation;
    public boolean failedToConverge = false;

    /**
     * Creates a model with {@code numGaussians} freshly-constructed components.
     *
     * @param numGaussians       number of mixture components
     * @param numVariantData     number of training data points (sizing hint passed to each component)
     * @param numAnnotations     dimensionality of the annotation space
     * @param shrinkage          shrinkage hyperparameter for the mean prior
     * @param dirichletParameter Dirichlet hyperparameter for the mixture weights
     * @param priorCounts        prior pseudo-counts for the covariance prior
     */
    public GaussianMixtureModel( final int numGaussians, final int numVariantData, final int numAnnotations,
                                 final double shrinkage, final double dirichletParameter, final double priorCounts ) {

        gaussians = new ArrayList<>( numGaussians );
        for( int k = 0; k < numGaussians; k++ ) {
            gaussians.add( new MultivariateGaussian( numVariantData, numAnnotations ) );
        }
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        empiricalMu = new double[numAnnotations];
        empiricalSigma = new Matrix(numAnnotations, numAnnotations);
        isModelReadyForEvaluation = false;
        initializeEmpiricalPrior();
    }

    /**
     * Creates a model wrapping pre-built components; used by the model-output unit test.
     * The annotation dimensionality is inferred from the first component's mean vector.
     */
    protected GaussianMixtureModel( final List<MultivariateGaussian> gaussians, final double shrinkage,
                                    final double dirichletParameter, final double priorCounts ) {
        this.gaussians = gaussians;
        final int numAnnotations = gaussians.get(0).mu.length;
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        empiricalMu = new double[numAnnotations];
        empiricalSigma = new Matrix(numAnnotations, numAnnotations);
        isModelReadyForEvaluation = false;
        initializeEmpiricalPrior();
    }

    // Shared by both constructors: zero-mean empirical prior with covariance
    // (200 * I)^-1 = (1/200) * I over the annotation dimensions.
    private void initializeEmpiricalPrior() {
        Arrays.fill(empiricalMu, 0.0);
        empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1,
                Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse());
    }

    /**
     * Randomly seeds every component's mean, refines the means with k-means, then sets uniform
     * mixture weights, random covariances, and the initial hyperparameters on each component.
     *
     * @param data               the training data to cluster
     * @param numKMeansIterations number of k-means refinement iterations
     */
    public void initializeRandomModel( final List<VariantDatum> data, final int numKMeansIterations ) {

        // Initialize random Gaussian means.
        // BUGBUG: broken up this way to match the order of calls to rand.nextDouble() in the old code.
        for( final MultivariateGaussian gaussian : gaussians ) {
            gaussian.initializeRandomMu( Utils.getRandomGenerator() );
        }

        // Refine the means with the k-means algorithm.
        logger.info( "Initializing model with " + numKMeansIterations + " k-means iterations..." );
        initializeMeansUsingKMeans( data, numKMeansIterations );

        // Uniform mixture coefficients, random covariance matrices, and initial hyperparameters.
        final double uniformWeight = 1.0 / ((double) gaussians.size());
        for( final MultivariateGaussian gaussian : gaussians ) {
            gaussian.pMixtureLog10 = Math.log10( uniformWeight );
            gaussian.sumProb = uniformWeight;
            gaussian.initializeRandomSigma( Utils.getRandomGenerator() );
            gaussian.hyperParameter_a = priorCounts;
            gaussian.hyperParameter_b = shrinkage;
            gaussian.hyperParameter_lambda = dirichletParameter;
        }
    }

    /**
     * Standard k-means: repeatedly assign each datum to the nearest component (by squared
     * distance from the mean), then recompute each component's mean from its assignees.
     * A component that attracts no data is re-seeded with a random mean.
     */
    private void initializeMeansUsingKMeans( final List<VariantDatum> data, final int numIterations ) {

        for( int iteration = 0; iteration < numIterations; iteration++ ) {
            // E step: assign each variant to the nearest cluster.
            for( final VariantDatum datum : data ) {
                double minDistance = Double.MAX_VALUE;
                MultivariateGaussian minGaussian = null;
                for( final MultivariateGaussian gaussian : gaussians ) {
                    final double dist = gaussian.calculateDistanceFromMeanSquared( datum );
                    if( dist < minDistance ) {
                        minDistance = dist;
                        minGaussian = gaussian;
                    }
                }
                datum.assignment = minGaussian;
            }

            // M step: update each Gaussian's mean from the variants assigned to it.
            for( final MultivariateGaussian gaussian : gaussians ) {
                gaussian.zeroOutMu();
                int numAssigned = 0;

                for( final VariantDatum datum : data ) {
                    if( datum.assignment.equals(gaussian) ) {
                        numAssigned++;
                        gaussian.incrementMu( datum );
                    }
                }
                if( numAssigned != 0 ) {
                    gaussian.divideEqualsMu( ((double) numAssigned) );
                } else {
                    // Empty cluster: re-seed with a random mean rather than leaving it at zero.
                    gaussian.initializeRandomMu( Utils.getRandomGenerator() );
                }
            }
        }
    }

    /**
     * Variational-Bayes E step: computes each datum's normalized responsibility under every
     * component and hands the per-datum responsibilities to the components for accumulation.
     */
    public void expectationStep( final List<VariantDatum> data ) {

        for( final MultivariateGaussian gaussian : gaussians ) {
            gaussian.precomputeDenominatorForVariationalBayes( getSumHyperParameterLambda() );
        }

        for( final VariantDatum datum : data ) {
            final double[] pVarInGaussianLog10 = gaussians.stream().mapToDouble(g -> g.evaluateDatumLog10(datum)).toArray();
            final double[] pVarInGaussianNormalized = MathUtils.normalizeLog10DeleteMePlease( pVarInGaussianLog10, false);
            int gaussianIndex = 0;
            for( final MultivariateGaussian gaussian : gaussians ) {
                gaussian.assignPVarInGaussian( pVarInGaussianNormalized[gaussianIndex++] );
            }
        }
    }

    /**
     * Variational-Bayes M step: each component re-estimates its parameters from the data,
     * the empirical prior, and the shared hyperparameters.
     */
    public void maximizationStep( final List<VariantDatum> data ) {
        gaussians.forEach(g -> g.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, priorCounts));
    }

    /** Sum of the Dirichlet hyperparameter lambda over all components. */
    private double getSumHyperParameterLambda() {
        return gaussians.stream().mapToDouble(g -> g.hyperParameter_lambda).sum();
    }

    /**
     * Finalizes each component's parameters after convergence and renormalizes the
     * log10 mixture weights.
     */
    public void evaluateFinalModelParameters( final List<VariantDatum> data ) {
        gaussians.forEach(g -> g.evaluateFinalModelParameters(data));
        normalizePMixtureLog10();
    }

    /**
     * Renormalizes the components' log10 mixture weights from their accumulated probabilities.
     *
     * @return the summed absolute change in log10 mixture weight across all components,
     *         usable as a convergence measure
     */
    public double normalizePMixtureLog10() {
        double sumDiff = 0.0;
        final double sumPK = gaussians.stream().mapToDouble(g -> g.sumProb).sum();

        final double log10SumPK = Math.log10(sumPK);
        final double[] pGaussianLog10 = gaussians.stream().mapToDouble(g -> Math.log10(g.sumProb) - log10SumPK).toArray();
        MathUtils.normalizeLog10DeleteMePlease( pGaussianLog10, true);

        int gaussianIndex = 0;
        for( final MultivariateGaussian gaussian : gaussians ) {
            sumDiff += Math.abs( pGaussianLog10[gaussianIndex] - gaussian.pMixtureLog10 );
            gaussian.pMixtureLog10 = pGaussianLog10[gaussianIndex++];
        }
        return sumDiff;
    }

    /**
     * Precomputes each component's evaluation denominator and marks the model ready for
     * the {@code evaluateDatum*} methods.
     */
    public void precomputeDenominatorForEvaluation() {
        for( final MultivariateGaussian gaussian : gaussians ) {
            gaussian.precomputeDenominatorForEvaluation();
        }

        isModelReadyForEvaluation = true;
    }

    /**
     * A version of Log10SumLog10 that tolerates NaN values in the array
     *
     * In the case where one or more of the values are NaN, this function returns NaN
     *
     * @param values a non-null vector of doubles
     * @return log10 of the sum of the log10 values, or NaN
     */
    private double nanTolerantLog10SumLog10(final double[] values) {
        for ( final double value : values ) {
            if ( Double.isNaN(value) ) {
                return Double.NaN;
            }
        }
        return MathUtils.log10sumLog10(values);
    }

    /**
     * Evaluates the log10 mixture likelihood of a datum. If any annotation is missing the
     * evaluation is delegated to {@link #evaluateDatumMarginalized}.
     *
     * @return log10( Sum_k pi_k * p(v|k) ), or NaN if any component evaluates to NaN
     */
    public double evaluateDatum( final VariantDatum datum ) {
        for( final boolean isNull : datum.isNull ) {
            if( isNull ) {
                return evaluateDatumMarginalized( datum );
            }
        }
        // Fill an array with the log10 probability coming from each Gaussian and then use MathUtils to sum them up correctly
        final double[] pVarInGaussianLog10 = new double[gaussians.size()];
        int gaussianIndex = 0;
        for( final MultivariateGaussian gaussian : gaussians ) {
            pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum );
        }
        return nanTolerantLog10SumLog10(pVarInGaussianLog10); // Sum(pi_k * p(v|n,k))
    }

    /**
     * Evaluates the datum's log10 mixture likelihood in a single annotation dimension.
     * Used only to decide which covariate dimension is most divergent in order to report
     * in the culprit info field annotation.
     *
     * @return the one-dimensional log10 likelihood, or null if the annotation is missing
     */
    public Double evaluateDatumInOneDimension( final VariantDatum datum, final int dimension ) {
        if(datum.isNull[dimension]) { return null; }

        final double[] pVarInGaussianLog10 = new double[gaussians.size()];
        int gaussianIndex = 0;
        for( final MultivariateGaussian gaussian : gaussians ) {
            pVarInGaussianLog10[gaussianIndex] = gaussian.pMixtureLog10;
            // Skip the density term for zero-weight components to avoid -Inf + Inf = NaN.
            if (gaussian.pMixtureLog10 != Double.NEGATIVE_INFINITY) {
                pVarInGaussianLog10[gaussianIndex] += MathUtils.normalDistributionLog10(gaussian.mu[dimension], gaussian.sigma.get(dimension, dimension), datum.annotations[dimension]);
            }
            gaussianIndex++;
        }
        return nanTolerantLog10SumLog10(pVarInGaussianLog10); // Sum(pi_k * p(v|n,k))
    }

    /**
     * Evaluates a datum with missing annotations by Monte Carlo marginalization: each missing
     * dimension is filled with draws from the standard normal and the resulting likelihoods
     * are averaged in probability space.
     *
     * NOTE(review): this overwrites {@code datum.annotations} in the missing dimensions with
     * the last random draw — the datum is mutated as a side effect; confirm callers tolerate this.
     * NOTE(review): if no annotation is actually missing, numRandomDraws stays 0 and the
     * method returns log10(0/0) = NaN.
     */
    public double evaluateDatumMarginalized( final VariantDatum datum ) {
        int numRandomDraws = 0;
        double sumPVarInGaussian = 0.0;
        final int numIterPerMissingAnnotation = 20; // Trade off here between speed of computation and accuracy of the marginalization
        final double[] pVarInGaussianLog10 = new double[gaussians.size()];
        // for each dimension
        for( int dim = 0; dim < datum.annotations.length; dim++ ) {
            // if it is missing, marginalize over the missing dimension by drawing X random values
            // for the missing annotation and averaging the lod
            if( datum.isNull[dim] ) {
                for( int draw = 0; draw < numIterPerMissingAnnotation; draw++ ) {
                    datum.annotations[dim] = Utils.getRandomGenerator().nextGaussian(); // draw a random sample from the standard normal distribution

                    // evaluate this random data point
                    int gaussianIndex = 0;
                    for( final MultivariateGaussian gaussian : gaussians ) {
                        pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum );
                    }

                    // add this sample's probability to the pile in order to take an average in the end
                    sumPVarInGaussian += Math.pow(10.0, nanTolerantLog10SumLog10(pVarInGaussianLog10)); // p = 10 ^ Sum(pi_k * p(v|n,k))
                    numRandomDraws++;
                }
            }
        }
        return Math.log10( sumPVarInGaussian / ((double) numRandomDraws) );
    }

    /** @return an unmodifiable view of the mixture components */
    protected List<MultivariateGaussian> getModelGaussians() {return Collections.unmodifiableList(gaussians);}

    /** @return the dimensionality of the annotation space */
    protected int getNumAnnotations() {return empiricalMu.length;}
}