package org.broadinstitute.hellbender.tools.spark.sv.evidence;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.hellbender.GATKBaseTest;
import org.broadinstitute.hellbender.engine.spark.SparkContextFactory;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.spark.sv.StructuralVariationDiscoveryArgumentCollection.FindBreakpointEvidenceSparkArgumentCollection;
import org.broadinstitute.hellbender.tools.spark.sv.discovery.TestUtilsForAssemblyBasedSVDiscovery;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVInterval;
import org.broadinstitute.hellbender.tools.spark.sv.utils.StrandedInterval;
import org.broadinstitute.hellbender.tools.spark.utils.IntHistogram;
import org.broadinstitute.hellbender.utils.Utils;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.*;
import java.util.*;
import java.util.stream.Collectors;

import static org.broadinstitute.hellbender.utils.Utils.validateArg;

public class XGBoostEvidenceFilterUnitTest extends GATKBaseTest {
    private static final String SV_EVIDENCE_TEST_DIR = toolsTestDir + "spark/sv/evidence/FindBreakpointEvidenceSpark/";
    private static final String testAccuracyDataJsonFile = SV_EVIDENCE_TEST_DIR + "sv_classifier_test_data.json";
    private static final String classifierModelFile = "/large/sv_evidence_classifier.bin";
    private static final String localClassifierModelFile
            = new File(publicMainResourcesDir, classifierModelFile).getAbsolutePath();
    private static final String testFeaturesJsonFile = SV_EVIDENCE_TEST_DIR + "sv_features_test_data.json";
    private static final double probabilityTol = 2.0e-3;
    private static final double featuresTol = 1.0e-5;
    private static final String SV_GENOME_UMAP_S100_FILE = SV_EVIDENCE_TEST_DIR + "hg38_umap_s100.bed.gz";
    private static final String SV_GENOME_GAPS_FILE = SV_EVIDENCE_TEST_DIR + "hg38_gaps.bed.gz";

    private static final String PANDAS_TABLE_NODE = "pandas.DataFrame";
    private static final String PANDAS_COLUMN_NODE = "pandas.Series";
    private static final String NUMPY_NODE = "numpy.array";
    private static final String FEATURES_NODE = "features";
    private static final String STRING_REPS_NODE = "string_reps";
    private static final String PROBABILITY_NODE = "proba";
    private static final String MEAN_GENOME_COVERAGE_NODE = "coverage";
    private static final String TEMPLATE_SIZE_CUMULATIVE_COUNTS_NODE = "template_size_cumulative_counts";
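
    // For reference: the test-data JSON files consumed below are assumed to look roughly like
    //   { "features": {"__class__": "pandas.DataFrame", "data": {...}},
    //     "string_reps": [...], "proba": [...], "coverage": ...,
    //     "template_size_cumulative_counts": [...] }
    // as implied by the node names above and the JsonMatrixLoader parsing logic.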

    private static final ClassifierAccuracyData classifierAccuracyData = new ClassifierAccuracyData(testAccuracyDataJsonFile);
    private static final Predictor testPredictor = XGBoostEvidenceFilter.loadPredictor(localClassifierModelFile);
    private static final double[] predictedProbabilitySerial = predictProbability(
            testPredictor, classifierAccuracyData.features
    );
    private static final FeaturesTestData featuresTestData = new FeaturesTestData(testFeaturesJsonFile);

    private static final FindBreakpointEvidenceSparkArgumentCollection params = initParams();

    private static final SAMFileHeader artificialSamHeader = initSAMFileHeader();
    private static final String readGroupName = "Pond-Testing";
    private static final String DEFAULT_SAMPLE_NAME = "SampleX";
    private static final ReadMetadata readMetadata = initMetadata();
    private static final PartitionCrossingChecker emptyCrossingChecker = new PartitionCrossingChecker();
    private static final BreakpointEvidenceFactory breakpointEvidenceFactory = new BreakpointEvidenceFactory(readMetadata);
    private static final List<BreakpointEvidence> evidenceList = Arrays.stream(featuresTestData.stringReps)
            .map(breakpointEvidenceFactory::fromStringRep).collect(Collectors.toList());

    private static FindBreakpointEvidenceSparkArgumentCollection initParams() {
        final FindBreakpointEvidenceSparkArgumentCollection params = new FindBreakpointEvidenceSparkArgumentCollection();
        params.svGenomeUmapS100File = SV_GENOME_UMAP_S100_FILE;
        params.svGenomeGapsFile = SV_GENOME_GAPS_FILE;
        return params;
    }

    private static SAMFileHeader initSAMFileHeader() {
        final SAMFileHeader samHeader = createArtificialSamHeader();
        final SAMReadGroupRecord readGroup = new SAMReadGroupRecord(readGroupName);
        readGroup.setSample(DEFAULT_SAMPLE_NAME);
        samHeader.addReadGroup(readGroup);
        return samHeader;
    }

    /**
     * Create a synthetic SAM header compatible with the genome tracts (e.g. it has all the primary contigs)
     */
    public static SAMFileHeader createArtificialSamHeader() {
        final SAMFileHeader header = new SAMFileHeader();
        header.setSortOrder(SAMFileHeader.SortOrder.coordinate);
        header.setSequenceDictionary(TestUtilsForAssemblyBasedSVDiscovery.bareBoneHg38SAMSeqDict);
        return header;
    }

    private static ReadMetadata initMetadata() {
        final ReadMetadata.PartitionBounds[] partitionBounds = new ReadMetadata.PartitionBounds[3];
        partitionBounds[0] = new ReadMetadata.PartitionBounds(0, 1, 0, 10000, 9999);
        partitionBounds[1] = new ReadMetadata.PartitionBounds(0, 10001, 0, 20000, 9999);
        partitionBounds[2] = new ReadMetadata.PartitionBounds(0, 20001, 0, 30000, 9999);
        return new ReadMetadata(Collections.emptySet(), artificialSamHeader,
                new LibraryStatistics(cumulativeCountsToCDF(featuresTestData.template_size_cumulative_counts),
                        60000000000L, 600000000L, 1200000000000L, 3000000000L),
                partitionBounds, 100, 10, featuresTestData.coverage);
    }

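    /**
     * Convert cumulative template-size counts into an IntHistogram.CDF by normalizing each count by the
     * total number of observations (the last cumulative count). For example, cumulative counts {2, 3, 6}
     * yield CDF fractions {0.333..., 0.5, 1.0}.
     */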
    private static IntHistogram.CDF cumulativeCountsToCDF(final long[] cumulativeCounts) {
        final long totalObservations = cumulativeCounts[cumulativeCounts.length - 1];
        final float[] cdfFractions = new float[cumulativeCounts.length];
        for(int index = 0; index < cdfFractions.length; ++index) {
            cdfFractions[index] = cumulativeCounts[index] / (float)totalObservations;
        }
        return new IntHistogram.CDF(cdfFractions, totalObservations);
    }

    @Test(groups = "sv")
    protected void testLocalXGBoostClassifierAccuracy() {
        // check accuracy: predictions are the same as classifierAccuracyData up to tolerance
        assertArrayEquals(predictedProbabilitySerial, classifierAccuracyData.probability, probabilityTol,
                "Probabilities predicted by classifier do not match saved correct answers");
    }

    @Test(groups = "sv")
    protected void testLocalXGBoostClassifierSpark() {
        final Predictor localPredictor = XGBoostEvidenceFilter.loadPredictor(localClassifierModelFile);
        // get the test Spark context
        final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
        // parallelize classifierAccuracyData to an RDD
        final JavaRDD<FVec> testFeaturesRdd = ctx.parallelize(Arrays.asList(classifierAccuracyData.features));
        // predict in parallel
        final JavaDoubleRDD predictedProbabilityRdd
                = testFeaturesRdd.mapToDouble(f -> localPredictor.predictSingle(f, false, 0));
        // pull back to a local array
        final double[] predictedProbabilitySpark = predictedProbabilityRdd.collect()
                .stream().mapToDouble(Double::doubleValue).toArray();
        // check that probabilities from Spark are identical to the serial results
        assertArrayEquals(predictedProbabilitySpark, predictedProbabilitySerial, 0.0,
                "Probabilities predicted in spark context differ from serial");
    }

    @Test(groups = "sv")
    protected void testResourceXGBoostClassifier() {
        // load classifier from packaged resource
        final Predictor resourcePredictor = XGBoostEvidenceFilter.loadPredictor(null);
        final double[] predictedProbabilityResource = predictProbability(resourcePredictor, classifierAccuracyData.features);
        // check that predictions from the resource predictor are identical to those from the local file
        assertArrayEquals(predictedProbabilityResource, predictedProbabilitySerial, 0.0,
                "Predictions from predictor loaded from resource are not identical to those from local file");
    }

    @Test(groups = "sv")
    protected void testFeatureConstruction() {
        final XGBoostEvidenceFilter evidenceFilter = new XGBoostEvidenceFilter(
                evidenceList.iterator(), readMetadata, params, emptyCrossingChecker
        );
        for(int ind = 0; ind < featuresTestData.stringReps.length; ++ind) {
            final BreakpointEvidence evidence = evidenceList.get(ind);
            final String stringRep = featuresTestData.stringReps[ind];
            final EvidenceFeatures fVec = featuresTestData.features[ind];
            final double probability = featuresTestData.probability[ind];

            final BreakpointEvidence convertedEvidence = breakpointEvidenceFactory.fromStringRep(stringRep);
            final String convertedRep = convertedEvidence.stringRep(readMetadata, params.minEvidenceMapQ);
            Assert.assertEquals(convertedRep.trim(), stringRep.trim(),
                    "BreakpointEvidenceFactory.fromStringRep does not invert BreakpointEvidence.stringRep");
            final EvidenceFeatures calcFVec = evidenceFilter.getFeatures(evidence);
            assertArrayEquals(calcFVec.getValues(), fVec.getValues(), featuresTol,
                    "Features calculated by XGBoostEvidenceFilter don't match expected features");
            final double calcProbability = evidenceFilter.predictProbability(evidence);
            Assert.assertEquals(calcProbability, probability, probabilityTol,
                    "Probability calculated by XGBoostEvidenceFilter doesn't match expected probability");
        }
    }

    @Test(groups = "sv")
    protected void testFilter() {
        final XGBoostEvidenceFilter evidenceFilter = new XGBoostEvidenceFilter(
                evidenceList.iterator(), readMetadata, params, emptyCrossingChecker
        );

        // Construct the list of BreakpointEvidence that is expected to pass the filter. Note that the filter
        // passes all evidence at a given location if any evidence at that location exceeds the threshold.
        final List<BreakpointEvidence> expectedPassed = new ArrayList<>();
        final List<BreakpointEvidence> sameLocationEvidence = new ArrayList<>();
        boolean locationPassed = false;
        SVInterval previous = null;
        for(final BreakpointEvidence evidence : evidenceList) {
            // Use the classifier to calculate probability, to ensure that minor fluctuations that happen to cross the
            // decision threshold don't cause test failure. Here we only test whether the filtering mechanism works
            // correctly; accuracy of the probability calculation is tested in testFeatureConstruction.
            final double probability = evidenceFilter.predictProbability(evidence);
            final boolean matchesPrevious = evidence.getLocation().equals(previous);
            locationPassed = matchesPrevious ?
                    locationPassed || probability > params.svEvidenceFilterThresholdProbability
                    : probability > params.svEvidenceFilterThresholdProbability;
            if(locationPassed) {
                if(matchesPrevious) {
                    expectedPassed.addAll(sameLocationEvidence);
                } else {
                    previous = evidence.getLocation();
                }
                sameLocationEvidence.clear();
                expectedPassed.add(evidence);
            } else if(matchesPrevious) {
                sameLocationEvidence.add(evidence);
            } else {
                sameLocationEvidence.clear();
                previous = evidence.getLocation();
            }
        }
        sameLocationEvidence.clear();

        // use evidenceFilter to populate a list with the passed evidence
        final List<BreakpointEvidence> passedEvidence = new ArrayList<>();
        evidenceFilter.forEachRemaining(passedEvidence::add);

        Assert.assertEquals(passedEvidence, expectedPassed,
                "Evidence passed by XGBoostEvidenceFilter not the same as expected");
    }

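    /** Assert that two double arrays are equal element-wise within tolerance, tagging any failure with the offending index. */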
    private static void assertArrayEquals(final double[] actuals, final double[] expecteds, final double tol,
                                          final String message) {
        Assert.assertEquals(actuals.length, expecteds.length, "Lengths not equal: " + message);
        for(int index = 0; index < expecteds.length; ++index) {
            Assert.assertEquals(actuals[index], expecteds[index], tol, "at index=" + index + ": " + message);
        }
    }

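    /** Serially apply the predictor to each feature vector, returning one predicted probability per row. */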
    private static double[] predictProbability(final Predictor predictor, final FVec[] testFeatures) {
        return Arrays.stream(testFeatures).mapToDouble(
                features -> predictor.predictSingle(features, false, 0)
        ).toArray();
    }

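    /**
     * Helper for loading test data serialized from Python. Handles JSON nodes tagged with
     * "__class__" = "numpy.array" (rows stored as nested arrays under "data") or "pandas.DataFrame"
     * (a "data" object mapping column names to pandas.Series holding "values" or categorical "codes").
     */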
    static class JsonMatrixLoader {

        private static final String CLASS_NODE = "__class__";
        private static final String DATA_NODE = "data";
        private static final String VALUES_NODE = "values";
        private static final String CODES_NODE = "codes";

        static EvidenceFeatures[] getFVecArrayFromJsonNode(final JsonNode matrixNode) {
            if(!matrixNode.has(CLASS_NODE)) {
                throw new IllegalArgumentException("JSON node does not store python matrix data");
            }
            final String matrixClass = matrixNode.get(CLASS_NODE).asText();
            switch(matrixClass) {
                case PANDAS_TABLE_NODE:
                    return getFVecArrayFromPandasJsonNode(matrixNode.get(DATA_NODE));
                case NUMPY_NODE:
                    return getFVecArrayFromNumpyJsonNode(matrixNode.get(DATA_NODE));
                default:
                    throw new IllegalArgumentException("JSON node has " + CLASS_NODE + " = " + matrixClass
                            + " which is not a supported matrix type");
            }
        }

        private static EvidenceFeatures[] getFVecArrayFromNumpyJsonNode(final JsonNode dataNode) {
            Utils.validateArg(dataNode.isArray(), "dataNode does not encode a valid numpy array");
            final int numRows = dataNode.size();
            final EvidenceFeatures[] matrix = new EvidenceFeatures[numRows];
            if (numRows == 0) {
                return matrix;
            }
            matrix[0] = new EvidenceFeatures(getDoubleArrayFromJsonArrayNode(dataNode.get(0)));
            final int numColumns = matrix[0].length();
            for (int row = 1; row < numRows; ++row) {
                matrix[row] = new EvidenceFeatures(getDoubleArrayFromJsonArrayNode(dataNode.get(row)));
                final int numRowColumns = matrix[row].length();
                Utils.validateArg(numRowColumns == numColumns, "Rows in JSONArray have different lengths.");
            }
            return matrix;
        }

        private static EvidenceFeatures[] getFVecArrayFromPandasJsonNode(final JsonNode dataNode) {
            if(!dataNode.isObject()) {
                throw new IllegalArgumentException("dataNode does not encode a valid pandas DataFrame");
            }
            final int numColumns = dataNode.size();
            if(numColumns == 0) {
                return new EvidenceFeatures[0];
            }

            final String firstColumnName = dataNode.fieldNames().next();
            final int numRows = getColumnArrayNode(dataNode.get(firstColumnName)).size();
            final EvidenceFeatures[] matrix = new EvidenceFeatures[numRows];
            if (numRows == 0) {
                return matrix;
            }
            // allocate each EvidenceFeatures in matrix
            for(int rowIndex = 0; rowIndex < numRows; ++rowIndex) {
                matrix[rowIndex] = new EvidenceFeatures(numColumns);
            }
            int columnIndex = 0;
            for(final Iterator<Map.Entry<String, JsonNode>> fieldIter = dataNode.fields(); fieldIter.hasNext();) {
                // loop over columns
                final Map.Entry<String, JsonNode> columnEntry = fieldIter.next();
                final JsonNode columnArrayNode = getColumnArrayNode(columnEntry.getValue());
                Utils.validateArg(columnArrayNode.size() == numRows,
                        "field " + columnEntry.getKey() + " has " + columnArrayNode.size() + " rows (expected " + numRows + ")");
                // for each FVec in matrix, assign feature from this column
                int rowIndex = 0;
                for(final JsonNode valueNode: columnArrayNode) {
                    final EvidenceFeatures fVec = matrix[rowIndex];
                    fVec.setValue(columnIndex, valueNode.asDouble());
                    ++rowIndex;
                }
                ++columnIndex;
            }
            return matrix;
        }

        private static JsonNode getColumnArrayNode(final JsonNode columnNode) {
            return columnNode.has(VALUES_NODE) ? columnNode.get(VALUES_NODE) : columnNode.get(CODES_NODE);
        }

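        /** Extract a double[] from a plain JSON array, a pandas.Series node, or a numpy.array node. */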
        static double[] getDoubleArrayFromJsonNode(final JsonNode vectorNode) {
            if(!vectorNode.has(CLASS_NODE)) {
                return getDoubleArrayFromJsonArrayNode(vectorNode);
            }
            final String vectorClass = vectorNode.get(CLASS_NODE).asText();
            switch(vectorClass) {
                case PANDAS_COLUMN_NODE:
                    return getDoubleArrayFromJsonArrayNode(getColumnArrayNode(vectorNode));
                case NUMPY_NODE:
                    return getDoubleArrayFromJsonArrayNode(vectorNode.get(DATA_NODE));
                default:
                    throw new IllegalArgumentException("JSON node has " + CLASS_NODE + " = " + vectorClass
                            + " which is not a supported vector type");
            }
        }

        private static double[] getDoubleArrayFromJsonArrayNode(final JsonNode arrayNode) {
            if(!arrayNode.isArray()) {
                throw new IllegalArgumentException("JsonNode does not contain an Array");
            }
            final int numData = arrayNode.size();
            final double[] data = new double[numData];
            int ind = 0;
            for(final JsonNode valueNode : arrayNode) {
                data[ind] = valueNode.asDouble();
                ++ind;
            }
            return data;
        }

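        /** Extract a long[] from a plain JSON array, a pandas.Series node, or a numpy.array node. */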
        static long[] getLongArrayFromJsonNode(final JsonNode vectorNode) {
            if(!vectorNode.has(CLASS_NODE)) {
                return getLongArrayFromJsonArrayNode(vectorNode);
            }
            final String vectorClass = vectorNode.get(CLASS_NODE).asText();
            switch(vectorClass) {
                case PANDAS_COLUMN_NODE:
                    return getLongArrayFromJsonArrayNode(getColumnArrayNode(vectorNode));
                case NUMPY_NODE:
                    return getLongArrayFromJsonArrayNode(vectorNode.get(DATA_NODE));
                default:
                    throw new IllegalArgumentException("JSON node has " + CLASS_NODE + " = " + vectorClass
                            + " which is not a supported vector type");
            }
        }

        private static long[] getLongArrayFromJsonArrayNode(final JsonNode arrayNode) {
            if(!arrayNode.isArray()) {
                throw new IllegalArgumentException("JsonNode does not contain an Array");
            }
            final int numData = arrayNode.size();
            final long[] data = new long[numData];
            int ind = 0;
            for(final JsonNode valueNode : arrayNode) {
                // use asLong() so counts too large for an int are not truncated
                data[ind] = valueNode.asLong();
                ++ind;
            }
            return data;
        }

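        /** Extract a String[] from a plain JSON array node. */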
        static String[] getStringArrayFromJsonNode(final JsonNode arrayNode) {
            if(!arrayNode.isArray()) {
                throw new IllegalArgumentException("JsonNode does not contain an Array");
            }
            final int numStrings = arrayNode.size();
            final String[] stringArray = new String[numStrings];
            int ind = 0;
            for(final JsonNode stringNode : arrayNode) {
                stringArray[ind] = stringNode.asText();
                ++ind;
            }
            return stringArray;
        }
    }

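    /** Reconstructs BreakpointEvidence objects from the string representations stored in the test-data JSON. */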
    private static class BreakpointEvidenceFactory {
        private static final String DEFAULT_POND_NAME = "Pond-Testing";
        final ReadMetadata readMetadata;

        BreakpointEvidenceFactory(final ReadMetadata readMetadata) {
            this.readMetadata = readMetadata;
        }

        /**
         * Returns BreakpointEvidence constructed from its string representation; used to reconstruct
         * BreakpointEvidence for unit tests. stringRep() is intended to be an inverse of this function,
         * but not the other way around, i.e.
         *      fromStringRep(strRep).stringRep(readMetadata, minEvidenceMapQ) == strRep
         * but it may be the case that
         *      fromStringRep(evidence.stringRep(readMetadata, minEvidenceMapQ)) != evidence
         */
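        // A hypothetical example rep, matching the tab-separated fields parsed below
        // (location, weight, type, distal targets, template/fragment, strand, template size, cigar, mapq):
        //   "1[1000:1100]\t1\tSplitRead\t\ttemplateA/1\t1\t350\t100M51S\t60"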
        BreakpointEvidence fromStringRep(final String strRep) {
            final String[] words = strRep.split("\t");

            final SVInterval location = locationFromStringRep(words[0]);

            final int weight = Integer.parseInt(words[1]);

            final String evidenceType = words[2];
            if(evidenceType.equals("TemplateSizeAnomaly")) {
                final int readCount = Integer.parseInt(words[4]);
                return new BreakpointEvidence.TemplateSizeAnomaly(location, weight, readCount);
            } else {
                final List<StrandedInterval> distalTargets = words[3].isEmpty() ? new ArrayList<>()
                        : Arrays.stream(words[3].split(";")).map(BreakpointEvidenceFactory::strandedLocationFromStringRep)
                        .collect(Collectors.toList());
                validateArg(distalTargets.size() <= 1, "BreakpointEvidence must have 0 or 1 distal targets");
                final String[] templateParts = words[4].split("/");
                final String templateName = templateParts[0];
                final TemplateFragmentOrdinal fragmentOrdinal;
                if(templateParts.length <= 1) {
                    fragmentOrdinal = TemplateFragmentOrdinal.UNPAIRED;
                } else switch (templateParts[1]) {
                    case "0":
                        fragmentOrdinal = TemplateFragmentOrdinal.PAIRED_INTERIOR;
                        break;
                    case "1":
                        fragmentOrdinal = TemplateFragmentOrdinal.PAIRED_FIRST;
                        break;
                    case "2":
                        fragmentOrdinal = TemplateFragmentOrdinal.PAIRED_SECOND;
                        break;
                    case "?":
                        fragmentOrdinal = TemplateFragmentOrdinal.PAIRED_UNKNOWN;
                        break;
                    default:
                        throw new IllegalArgumentException("Unknown Template Fragment Ordinal: /" + templateParts[1]);
                }
                final boolean forwardStrand = words[5].equals("1");
                final int templateSize = Integer.parseInt(words[6]);
                final String cigarString = words[7];
                final int mappingQuality = Integer.parseInt(words[8]);
                final String readGroup = DEFAULT_POND_NAME; // the read group is not serialized, so fake it; only for testing
                final boolean validated = false;

                final SVInterval target;
                final boolean targetForwardStrand;
                final int targetQuality;
                switch(distalTargets.size()) {
                    case 0:
                        target = new SVInterval(0, 0, 0);
                        targetForwardStrand = false;
                        targetQuality = -1;
                        break;
                    case 1:
                        target = distalTargets.get(0).getInterval();
                        targetForwardStrand = distalTargets.get(0).getStrand();
                        targetQuality = Integer.MAX_VALUE;
                        break;
                    default:
                        throw new IllegalArgumentException("BreakpointEvidence must have <= 1 distal target");
                }

                switch(evidenceType) {
                    case "SplitRead":
                        // NOTE: can't identically reconstruct the original values, but can make self-consistent values
                        // that reproduce the known distal targets: plausible cigar strings, primaryAlignmentClippedAtStart
                        // and primaryAlignmentForwardStrand that are compatible with the dumped distal targets.
                        final String tagSA = distalTargets.isEmpty() ? null : distalTargets.stream().map(this::distalTargetToTagSA).collect(Collectors.joining());
                        return new BreakpointEvidence.SplitRead(location, weight, templateName, fragmentOrdinal, validated,
                                forwardStrand, cigarString, mappingQuality, templateSize, readGroup,
                                forwardStrand, forwardStrand, tagSA);

                    case "LargeIndel":
                        Utils.validateArg(distalTargets.isEmpty(), "LargeIndel should have no distal targets");
                        return new BreakpointEvidence.LargeIndel(location, weight, templateName, fragmentOrdinal, validated,
                                forwardStrand, cigarString, mappingQuality, templateSize, readGroup);

                    case "MateUnmapped":
                        Utils.validateArg(distalTargets.isEmpty(), "MateUnmapped should have no distal targets");
                        return new BreakpointEvidence.MateUnmapped(location, weight, templateName, fragmentOrdinal, validated,
                                forwardStrand, cigarString, mappingQuality, templateSize, readGroup);

                    case "InterContigPair":
                        return new BreakpointEvidence.InterContigPair(
                                location, weight, templateName, fragmentOrdinal, validated, forwardStrand, cigarString,
                                mappingQuality, templateSize, readGroup, target, targetForwardStrand, targetQuality
                        );

                    case "OutiesPair":
                        return new BreakpointEvidence.OutiesPair(
                                location, weight, templateName, fragmentOrdinal, validated, forwardStrand, cigarString,
                                mappingQuality, templateSize, readGroup, target, targetForwardStrand, targetQuality
                        );

                    case "SameStrandPair":
                        return new BreakpointEvidence.SameStrandPair(
                                location, weight, templateName, fragmentOrdinal, validated, forwardStrand, cigarString,
                                mappingQuality, templateSize, readGroup, target, targetForwardStrand, targetQuality
                        );

                    case "WeirdTemplateSize":
                        return new BreakpointEvidence.WeirdTemplateSize(
                                location, weight, templateName, fragmentOrdinal, validated, forwardStrand, cigarString,
                                mappingQuality, templateSize, readGroup, target, targetForwardStrand, targetQuality
                        );
                    default:
                        throw new IllegalArgumentException("Unknown BreakpointEvidence type: " + evidenceType);
                }
            }
        }

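        // Synthesize a minimal SA-tag entry ("contig,pos,strand,cigar,mapq,NM;") whose alignment is
        // consistent with the given distal target interval and the average read length.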
        private String distalTargetToTagSA(final StrandedInterval distalTarget) {
            final String contigName = readMetadata.getContigName(distalTarget.getInterval().getContig());
            final boolean isForwardStrand = distalTarget.getStrand();
            final int referenceLength = distalTarget.getInterval().getLength();
            final int pos = distalTarget.getInterval().getEnd() - 1 - BreakpointEvidence.SplitRead.UNCERTAINTY;
            final int start = isForwardStrand ? pos - referenceLength : pos;
            final int clipLength = readMetadata.getAvgReadLen() - referenceLength;
            final String cigar = referenceLength >= readMetadata.getAvgReadLen() ? referenceLength + "M"
                    : (isForwardStrand ? referenceLength + "M" + clipLength + "S"
                    : clipLength + "S" + referenceLength + "M");
            final int mapq = Integer.MAX_VALUE;
            final int mismatches = 0;
            final String[] tagParts = new String[] {contigName, String.valueOf(start), isForwardStrand ? "+" : "-",
                    cigar, String.valueOf(mapq), String.valueOf(mismatches)};
            return String.join(",", tagParts) + ";";
        }

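        // Intervals are assumed to be serialized as "contig[start:end]", and stranded intervals as
        // "contig[start:end]strand" with strand encoded as "1" or "0", matching the split patterns below.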
        private static SVInterval locationFromStringRep(final String locationStr) {
            final String[] locationParts = locationStr.split("[\\[\\]:]");
            validateArg(locationParts.length == 3, "Could not parse SVInterval from string");
            final int contig = Integer.parseInt(locationParts[0]);
            final int start = Integer.parseInt(locationParts[1]);
            final int end = Integer.parseInt(locationParts[2]);
            return new SVInterval(contig, start, end);
        }

        private static StrandedInterval strandedLocationFromStringRep(final String locationStr) {
            final String[] locationParts = locationStr.split("[\\[\\]:]");
            validateArg(locationParts.length == 4, "Could not parse StrandedInterval from string");
            final int contig = Integer.parseInt(locationParts[0]);
            final int start = Integer.parseInt(locationParts[1]);
            final int end = Integer.parseInt(locationParts[2]);
            final boolean strand = locationParts[3].equals("1");
            return new StrandedInterval(new SVInterval(contig, start, end), strand);
        }
    }

    private static class ClassifierAccuracyData extends JsonMatrixLoader {
        final EvidenceFeatures[] features;
        final double[] probability;

        ClassifierAccuracyData(final String jsonFileName) {
            try(final InputStream inputStream = new FileInputStream(jsonFileName)) {
                final JsonNode testDataNode = new ObjectMapper().readTree(inputStream);
                features = getFVecArrayFromJsonNode(testDataNode.get(FEATURES_NODE));
                probability = getDoubleArrayFromJsonNode(testDataNode.get(PROBABILITY_NODE));
            } catch(Exception e) {
                throw new GATKException(
                        "Unable to load classifier test data from " + jsonFileName + ": " + e.getMessage(), e
                );
            }
        }
    }

    private static class FeaturesTestData extends JsonMatrixLoader {
        final EvidenceFeatures[] features;
        final String[] stringReps;
        final double[] probability;
        final float coverage;
        final long[] template_size_cumulative_counts;

        FeaturesTestData(final String jsonFileName) {
            try(final InputStream inputStream = new FileInputStream(jsonFileName)) {
                final JsonNode testDataNode = new ObjectMapper().readTree(inputStream);
                features = getFVecArrayFromJsonNode(testDataNode.get(FEATURES_NODE));
                stringReps = getStringArrayFromJsonNode(testDataNode.get(STRING_REPS_NODE));
                probability = getDoubleArrayFromJsonNode(testDataNode.get(PROBABILITY_NODE));
                coverage = (float)testDataNode.get(MEAN_GENOME_COVERAGE_NODE).asDouble();
                template_size_cumulative_counts = getLongArrayFromJsonNode(
                        testDataNode.get(TEMPLATE_SIZE_CUMULATIVE_COUNTS_NODE)
                );
            } catch(Exception e) {
                throw new GATKException(
                        "Unable to load features test data from " + jsonFileName + ": " + e.getMessage(), e
                );
            }
        }
    }
}