1 package org.broadinstitute.hellbender.tools.spark.sv.evidence; 2 3 import biz.k11i.xgboost.Predictor; 4 import biz.k11i.xgboost.util.FVec; 5 import com.fasterxml.jackson.databind.JsonNode; 6 import com.fasterxml.jackson.databind.ObjectMapper; 7 import htsjdk.samtools.SAMFileHeader; 8 import htsjdk.samtools.SAMReadGroupRecord; 9 import org.apache.spark.api.java.JavaDoubleRDD; 10 import org.apache.spark.api.java.JavaRDD; 11 import org.apache.spark.api.java.JavaSparkContext; 12 import org.broadinstitute.hellbender.GATKBaseTest; 13 import org.broadinstitute.hellbender.engine.spark.SparkContextFactory; 14 import org.broadinstitute.hellbender.exceptions.GATKException; 15 import org.broadinstitute.hellbender.tools.spark.sv.discovery.TestUtilsForAssemblyBasedSVDiscovery; 16 import org.broadinstitute.hellbender.tools.spark.sv.utils.SVInterval; 17 import org.broadinstitute.hellbender.tools.spark.sv.utils.StrandedInterval; 18 import org.broadinstitute.hellbender.tools.spark.utils.IntHistogram; 19 import org.broadinstitute.hellbender.utils.Utils; 20 import org.testng.Assert; 21 import org.testng.annotations.Test; 22 import org.broadinstitute.hellbender.tools.spark.sv.StructuralVariationDiscoveryArgumentCollection.FindBreakpointEvidenceSparkArgumentCollection; 23 24 25 import java.io.*; 26 import java.util.*; 27 import java.util.stream.Collectors; 28 29 import static org.broadinstitute.hellbender.utils.Utils.validateArg; 30 31 public class XGBoostEvidenceFilterUnitTest extends GATKBaseTest { 32 private static final String SV_EVIDENCE_TEST_DIR = toolsTestDir + "spark/sv/evidence/FindBreakpointEvidenceSpark/"; 33 private static final String testAccuracyDataJsonFile = SV_EVIDENCE_TEST_DIR + "sv_classifier_test_data.json"; 34 private static final String classifierModelFile = "/large/sv_evidence_classifier.bin"; 35 private static final String localClassifierModelFile 36 = new File(publicMainResourcesDir, classifierModelFile).getAbsolutePath(); 37 private static final String testFeaturesJsonFile = SV_EVIDENCE_TEST_DIR + "sv_features_test_data.json"; 38 private static final double probabilityTol = 2.0e-3; 39 private static final double featuresTol = 1.0e-5; 40 private static final String SV_GENOME_UMAP_S100_FILE = SV_EVIDENCE_TEST_DIR + "hg38_umap_s100.bed.gz"; 41 private static final String SV_GENOME_GAPS_FILE = SV_EVIDENCE_TEST_DIR + "hg38_gaps.bed.gz"; 42 43 private static final String PANDAS_TABLE_NODE = "pandas.DataFrame"; 44 private static final String PANDAS_COLUMN_NODE = "pandas.Series"; 45 private static final String NUMPY_NODE = "numpy.array"; 46 private static final String FEATURES_NODE = "features"; 47 private static final String STRING_REPS_NODE = "string_reps"; 48 private static final String PROBABILITY_NODE = "proba"; 49 private static final String MEAN_GENOME_COVERAGE_NODE = "coverage"; 50 private static final String TEMPLATE_SIZE_CUMULATIVE_COUNTS_NODE = "template_size_cumulative_counts"; 51 52 private static final ClassifierAccuracyData classifierAccuracyData = new ClassifierAccuracyData(testAccuracyDataJsonFile); 53 private static final Predictor testPredictor = XGBoostEvidenceFilter.loadPredictor(localClassifierModelFile); 54 private static final double[] predictedProbabilitySerial = predictProbability( 55 testPredictor, classifierAccuracyData.features 56 ); 57 private static final FeaturesTestData featuresTestData = new FeaturesTestData(testFeaturesJsonFile); 58 59 private static final FindBreakpointEvidenceSparkArgumentCollection params = initParams(); 60 61 private static final SAMFileHeader artificialSamHeader = initSAMFileHeader(); 62 private static final String readGroupName = "Pond-Testing"; 63 private static final String DEFAULT_SAMPLE_NAME = "SampleX"; 64 private static final ReadMetadata readMetadata = initMetadata(); 65 private static final PartitionCrossingChecker emptyCrossingChecker = new PartitionCrossingChecker(); 66 private static final BreakpointEvidenceFactory breakpointEvidenceFactory = new BreakpointEvidenceFactory(readMetadata); 67 private static final List<BreakpointEvidence> evidenceList = Arrays.stream(featuresTestData.stringReps) 68 .map(breakpointEvidenceFactory::fromStringRep).collect(Collectors.toList()); 69 initParams()70 private static FindBreakpointEvidenceSparkArgumentCollection initParams() { 71 final FindBreakpointEvidenceSparkArgumentCollection params = new FindBreakpointEvidenceSparkArgumentCollection(); 72 params.svGenomeUmapS100File = SV_GENOME_UMAP_S100_FILE; 73 params.svGenomeGapsFile = SV_GENOME_GAPS_FILE; 74 return params; 75 } 76 initSAMFileHeader()77 private static SAMFileHeader initSAMFileHeader() { 78 final SAMFileHeader samHeader = createArtificialSamHeader(); 79 SAMReadGroupRecord readGroup = new SAMReadGroupRecord(readGroupName); 80 readGroup.setSample(DEFAULT_SAMPLE_NAME); 81 samHeader.addReadGroup(readGroup); 82 return samHeader; 83 } 84 85 /** 86 * Create synthetic SAM Header comptible with genome tracts (e.g. has all the primary contigs) 87 */ createArtificialSamHeader()88 public static SAMFileHeader createArtificialSamHeader() { 89 final SAMFileHeader header = new SAMFileHeader(); 90 header.setSortOrder(SAMFileHeader.SortOrder.coordinate); 91 header.setSequenceDictionary(TestUtilsForAssemblyBasedSVDiscovery.bareBoneHg38SAMSeqDict); 92 return header; 93 } 94 initMetadata()95 private static ReadMetadata initMetadata() { 96 final ReadMetadata.PartitionBounds[] partitionBounds = new ReadMetadata.PartitionBounds[3]; 97 partitionBounds[0] = new ReadMetadata.PartitionBounds(0, 1, 0, 10000, 9999); 98 partitionBounds[1] = new ReadMetadata.PartitionBounds(0, 10001, 0, 20000, 9999); 99 partitionBounds[2] = new ReadMetadata.PartitionBounds(0, 20001, 0, 30000, 9999); 100 return new ReadMetadata(Collections.emptySet(), artificialSamHeader, 101 new LibraryStatistics(cumulativeCountsToCDF(featuresTestData.template_size_cumulative_counts), 102 60000000000L, 600000000L, 1200000000000L, 3000000000L), 103 partitionBounds, 100, 10, featuresTestData.coverage); 104 } 105 cumulativeCountsToCDF(final long[] cumulativeCounts)106 private static IntHistogram.CDF cumulativeCountsToCDF(final long[] cumulativeCounts) { 107 final long totalObservations = cumulativeCounts[cumulativeCounts.length - 1]; 108 final float[] cdfFractions = new float[cumulativeCounts.length]; 109 for(int index = 0; index < cdfFractions.length; ++index) { 110 cdfFractions[index] = cumulativeCounts[index] / (float)totalObservations; 111 } 112 return new IntHistogram.CDF(cdfFractions, totalObservations); 113 } 114 115 @Test(groups = "sv") testLocalXGBoostClassifierAccuracy()116 protected void testLocalXGBoostClassifierAccuracy() { 117 // check accuracy: predictions are same as classifierAccuracyData up to tolerance 118 assertArrayEquals(predictedProbabilitySerial, classifierAccuracyData.probability, probabilityTol, "Probabilities predicted by classifier do not match saved correct answers" 119 ); 120 } 121 122 @Test(groups = "sv") testLocalXGBoostClassifierSpark()123 protected void testLocalXGBoostClassifierSpark() { 124 final Predictor localPredictor = XGBoostEvidenceFilter.loadPredictor(localClassifierModelFile); 125 // get spark ctx 126 final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); 127 // parallelize classifierAccuracyData to RDD 128 JavaRDD<FVec> testFeaturesRdd = ctx.parallelize(Arrays.asList(classifierAccuracyData.features)); 129 // predict in parallel 130 JavaDoubleRDD predictedProbabilityRdd 131 = testFeaturesRdd.mapToDouble(f -> localPredictor.predictSingle(f, false, 0)); 132 // pull back to local array 133 final double[] predictedProbabilitySpark = predictedProbabilityRdd.collect() 134 .stream().mapToDouble(Double::doubleValue).toArray(); 135 // check probabilities from spark are identical to serial 136 assertArrayEquals(predictedProbabilitySpark, predictedProbabilitySerial, 0.0, "Probabilities predicted in spark context differ from serial" 137 ); 138 } 139 140 @Test(groups = "sv") testResourceXGBoostClassifier()141 protected void testResourceXGBoostClassifier() { 142 // load classifier from resource 143 final Predictor resourcePredictor = XGBoostEvidenceFilter.loadPredictor(null); 144 final double[] predictedProbabilityResource = predictProbability(resourcePredictor, classifierAccuracyData.features); 145 // check that predictions from resource are identical to local 146 assertArrayEquals(predictedProbabilityResource, predictedProbabilitySerial, 0.0, "Predictions via loading predictor from resource is not identical to local file" 147 ); 148 } 149 150 @Test(groups = "sv") testFeatureConstruction()151 protected void testFeatureConstruction() { 152 final XGBoostEvidenceFilter evidenceFilter = new XGBoostEvidenceFilter( 153 evidenceList.iterator(), readMetadata, params, emptyCrossingChecker 154 ); 155 for(int ind = 0; ind < featuresTestData.stringReps.length; ++ind) { 156 final BreakpointEvidence evidence = evidenceList.get(ind); 157 final String stringRep = featuresTestData.stringReps[ind]; 158 final EvidenceFeatures fVec = featuresTestData.features[ind]; 159 final double probability = featuresTestData.probability[ind]; 160 161 final BreakpointEvidence convertedEvidence = breakpointEvidenceFactory.fromStringRep(stringRep); 162 final String convertedRep = convertedEvidence.stringRep(readMetadata, params.minEvidenceMapQ); 163 Assert.assertEquals(convertedRep.trim(), stringRep.trim(), 164 "BreakpointEvidenceFactory.fromStringRep does not invert BreakpointEvidence.stringRep"); 165 final EvidenceFeatures calcFVec = evidenceFilter.getFeatures(evidence); 166 assertArrayEquals(calcFVec.getValues(), fVec.getValues(), featuresTol, 167 "Features calculated by XGBoostEvidenceFilter don't match expected features"); 168 final double calcProbability = evidenceFilter.predictProbability(evidence); 169 Assert.assertEquals(calcProbability, probability, probabilityTol, 170 "Probability calculated by XGBoostEvidenceFilter doesn't match expected probability"); 171 } 172 } 173 174 @Test(groups = "sv") testFilter()175 protected void testFilter() { 176 final XGBoostEvidenceFilter evidenceFilter = new XGBoostEvidenceFilter( 177 evidenceList.iterator(), readMetadata, params, emptyCrossingChecker 178 ); 179 180 // construct list of BreakpointEvidence that is expected to pass the filter 181 final List<BreakpointEvidence> expectedPassed = new ArrayList<>(); 182 List<BreakpointEvidence> sameLocationEvidence = new ArrayList<>(); 183 boolean locationPassed = false; 184 SVInterval previous = null; 185 for(final BreakpointEvidence evidence : evidenceList) { 186 // Use the classifier to calculate probability, to ensure that minor fluctuations that happen to cross the 187 // decision threshold don't cause test failure. Here we only test if the filtering mechanism works correctly. 188 // Accuracy of probability calculation is tested in testFeatureConstruction. 189 final double probability = evidenceFilter.predictProbability(evidence); 190 final boolean matchesPrevious = evidence.getLocation().equals(previous); 191 locationPassed = matchesPrevious ? 192 locationPassed || probability > params.svEvidenceFilterThresholdProbability 193 : probability > params.svEvidenceFilterThresholdProbability; 194 if(locationPassed) { 195 if(matchesPrevious) { 196 expectedPassed.addAll(sameLocationEvidence); 197 } else { 198 previous = evidence.getLocation(); 199 } 200 sameLocationEvidence.clear(); 201 expectedPassed.add(evidence); 202 } else if(matchesPrevious) { 203 sameLocationEvidence.add(evidence); 204 } else { 205 sameLocationEvidence.clear(); 206 previous = evidence.getLocation(); 207 } 208 } 209 sameLocationEvidence.clear(); 210 211 // use evidenceFilter to populate array with passed evidence 212 final List<BreakpointEvidence> passedEvidence = new ArrayList<>(); 213 evidenceFilter.forEachRemaining(passedEvidence::add); 214 215 Assert.assertEquals(passedEvidence, expectedPassed, 216 "Evidence passed by XGBoostEvidenceFilter not the same as expected"); 217 } 218 assertArrayEquals(final double[] actuals, final double[] expecteds, final double tol, final String message)219 private static void assertArrayEquals(final double[] actuals, final double[] expecteds, final double tol, 220 final String message) { 221 Assert.assertEquals(actuals.length, expecteds.length, "Lengths not equal: " + message); 222 for(int index = 0; index < expecteds.length; ++index) { 223 Assert.assertEquals(actuals[index], expecteds[index], tol, "at index=" + index + ": " + message); 224 } 225 } 226 predictProbability(final Predictor predictor, final FVec[] testFeatures)227 private static double[] predictProbability(final Predictor predictor, final FVec[] testFeatures) { 228 229 return Arrays.stream(testFeatures).mapToDouble( 230 features -> predictor.predictSingle(features, false, 0) 231 ).toArray(); 232 } 233 234 static class JsonMatrixLoader { 235 236 private static final String CLASS_NODE = "__class__"; 237 private static final String DATA_NODE = "data"; 238 private static final String VALUES_NODE = "values"; 239 private static final String CODES_NODE = "codes"; 240 getFVecArrayFromJsonNode(final JsonNode matrixNode)241 static EvidenceFeatures[] getFVecArrayFromJsonNode(final JsonNode matrixNode) { 242 if(!matrixNode.has(CLASS_NODE)) { 243 throw new IllegalArgumentException("JSON node does not store python matrix data"); 244 } 245 String matrixClass = matrixNode.get(CLASS_NODE).asText(); 246 switch(matrixClass) { 247 case PANDAS_TABLE_NODE: 248 return getFVecArrayFromPandasJsonNode(matrixNode.get(DATA_NODE)); 249 case NUMPY_NODE: 250 return getFVecArrayFromNumpyJsonNode(matrixNode.get(DATA_NODE)); 251 default: 252 throw new IllegalArgumentException("JSON node has " + CLASS_NODE + " = " + matrixClass 253 + " which is not a supported matrix type"); 254 } 255 } 256 getFVecArrayFromNumpyJsonNode(final JsonNode dataNode)257 private static EvidenceFeatures[] getFVecArrayFromNumpyJsonNode(final JsonNode dataNode) { 258 Utils.validateArg(dataNode.isArray(), "dataNode does not encode a valid numpy array"); 259 final int numRows = dataNode.size(); 260 final EvidenceFeatures[] matrix = new EvidenceFeatures[numRows]; 261 if (numRows == 0) { 262 return matrix; 263 } 264 matrix[0] = new EvidenceFeatures(getDoubleArrayFromJsonArrayNode(dataNode.get(0))); 265 final int numColumns = matrix[0].length(); 266 for (int row = 1; row < numRows; ++row) { 267 matrix[row] = new EvidenceFeatures(getDoubleArrayFromJsonArrayNode(dataNode.get(row))); 268 final int numRowColumns = matrix[row].length(); 269 Utils.validateArg(numRowColumns == numColumns, "Rows in JSONArray have different lengths."); 270 } 271 return matrix; 272 } 273 getFVecArrayFromPandasJsonNode(final JsonNode dataNode)274 private static EvidenceFeatures[] getFVecArrayFromPandasJsonNode(final JsonNode dataNode) { 275 if(!dataNode.isObject()) { 276 throw new IllegalArgumentException("dataNode does not encode a valid pandas DataFrame"); 277 } 278 final int numColumns = dataNode.size(); 279 if(numColumns == 0) { 280 return new EvidenceFeatures[0]; 281 } 282 283 final String firstColumnName = dataNode.fieldNames().next(); 284 final int numRows = getColumnArrayNode(dataNode.get(firstColumnName)).size(); 285 final EvidenceFeatures[] matrix = new EvidenceFeatures[numRows]; 286 if (numRows == 0) { 287 return matrix; 288 } 289 // allocate each EvidenceFeatures in matrix 290 for(int rowIndex = 0; rowIndex < numRows; ++rowIndex) { 291 matrix[rowIndex] = new EvidenceFeatures(numColumns); 292 } 293 int columnIndex = 0; 294 for(final Iterator<Map.Entry<String, JsonNode>> fieldIter = dataNode.fields(); fieldIter.hasNext();) { 295 // loop over columns 296 final Map.Entry<String, JsonNode> columnEntry = fieldIter.next(); 297 final JsonNode columnArrayNode = getColumnArrayNode(columnEntry.getValue()); 298 Utils.validateArg(columnArrayNode.size() == numRows, 299 "field " + columnEntry.getKey() + " has " + columnArrayNode.size() + " rows (expected " + numRows + ")"); 300 // for each FVec in matrix, assign feature from this column 301 int rowIndex = 0; 302 for(final JsonNode valueNode: columnArrayNode) { 303 final EvidenceFeatures fVec = matrix[rowIndex]; 304 fVec.setValue(columnIndex, valueNode.asDouble()); 305 ++rowIndex; 306 } 307 ++columnIndex; 308 } 309 return matrix; 310 } 311 getColumnArrayNode(final JsonNode columnNode)312 private static JsonNode getColumnArrayNode(final JsonNode columnNode) { 313 return columnNode.has(VALUES_NODE) ? columnNode.get(VALUES_NODE) : columnNode.get(CODES_NODE); 314 } 315 getDoubleArrayFromJsonNode(final JsonNode vectorNode)316 static double[] getDoubleArrayFromJsonNode(final JsonNode vectorNode) { 317 if(!vectorNode.has(CLASS_NODE)) { 318 return getDoubleArrayFromJsonArrayNode(vectorNode); 319 } 320 final String vectorClass = vectorNode.get(CLASS_NODE).asText(); 321 switch(vectorClass) { 322 case PANDAS_COLUMN_NODE: 323 return getDoubleArrayFromJsonArrayNode(getColumnArrayNode(vectorNode)); 324 case NUMPY_NODE: 325 return getDoubleArrayFromJsonArrayNode(vectorNode.get(DATA_NODE)); 326 default: 327 throw new IllegalArgumentException("JSON node has " + CLASS_NODE + " = " + vectorClass 328 + "which is not a supported vector type"); 329 } 330 } 331 getDoubleArrayFromJsonArrayNode(final JsonNode arrayNode)332 private static double [] getDoubleArrayFromJsonArrayNode(final JsonNode arrayNode) { 333 if(!arrayNode.isArray()) { 334 throw new IllegalArgumentException("JsonNode does not contain an Array"); 335 } 336 final int numData = arrayNode.size(); 337 final double[] data = new double[numData]; 338 int ind = 0; 339 for(final JsonNode valueNode : arrayNode) { 340 data[ind] = valueNode.asDouble(); 341 ++ind; 342 } 343 return data; 344 } 345 getLongArrayFromJsonNode(final JsonNode vectorNode)346 static long[] getLongArrayFromJsonNode(final JsonNode vectorNode) { 347 if(!vectorNode.has(CLASS_NODE)) { 348 return getLongArrayFromJsonArrayNode(vectorNode); 349 } 350 final String vectorClass = vectorNode.get(CLASS_NODE).asText(); 351 switch(vectorClass) { 352 case PANDAS_COLUMN_NODE: 353 return getLongArrayFromJsonArrayNode(getColumnArrayNode(vectorNode)); 354 case NUMPY_NODE: 355 return getLongArrayFromJsonArrayNode(vectorNode.get(DATA_NODE)); 356 default: 357 throw new IllegalArgumentException("JSON node has " + CLASS_NODE + " = " + vectorClass 358 + "which is not a supported vector type"); 359 } 360 } 361 getLongArrayFromJsonArrayNode(final JsonNode arrayNode)362 private static long [] getLongArrayFromJsonArrayNode(final JsonNode arrayNode) { 363 if(!arrayNode.isArray()) { 364 throw new IllegalArgumentException("JsonNode does not contain an Array"); 365 } 366 final int numData = arrayNode.size(); 367 final long[] data = new long[numData]; 368 int ind = 0; 369 for(final JsonNode valueNode : arrayNode) { 370 data[ind] = valueNode.asInt(); 371 ++ind; 372 } 373 return data; 374 } 375 getStringArrayFromJsonNode(final JsonNode arrayNode)376 static String[] getStringArrayFromJsonNode(final JsonNode arrayNode) { 377 if(!arrayNode.isArray()) { 378 throw new IllegalArgumentException("JsonNode does not contain an Array"); 379 } 380 final int numStrings = arrayNode.size(); 381 final String[] stringArray = new String[numStrings]; 382 int ind = 0; 383 for(final JsonNode stringNode : arrayNode) { 384 stringArray[ind] = stringNode.asText(); 385 ++ind; 386 } 387 return stringArray; 388 } 389 } 390 391 private static class BreakpointEvidenceFactory { 392 private static final String DEFAULT_POND_NAME = "Pond-Testing"; 393 final ReadMetadata readMetadata; 394 BreakpointEvidenceFactory(final ReadMetadata readMetadata)395 BreakpointEvidenceFactory(final ReadMetadata readMetadata) { 396 this.readMetadata = readMetadata; 397 } 398 399 /** 400 * Returns BreakpointEvidence constructed from string representation. Used to reconstruct BreakpointEvidence for 401 * unit tests. It is intended for stringRep() to be an inverse of this function, but not the other way around. i.e. 402 * fromStringRep(strRep, readMetadata).stringRep(readMetadata, minEvidenceMapQ) == strRep 403 * but it may be the case that 404 * fromStringRep(evidence.stringRep(readMetadata, minEvidenceMapQ), readMetadata) != evidence 405 */ fromStringRep(final String strRep)406 BreakpointEvidence fromStringRep(final String strRep) { 407 final String[] words = strRep.split("\t"); 408 409 final SVInterval location = locationFromStringRep(words[0]); 410 411 final int weight = Integer.parseInt(words[1]); 412 413 final String evidenceType = words[2]; 414 if(evidenceType.equals("TemplateSizeAnomaly")) { 415 final int readCount = Integer.parseInt(words[4]); 416 return new BreakpointEvidence.TemplateSizeAnomaly(location, weight, readCount); 417 } else { 418 final List<StrandedInterval> distalTargets = words[3].isEmpty() ? new ArrayList<>() 419 : Arrays.stream(words[3].split(";")).map(BreakpointEvidenceFactory::strandedLocationFromStringRep) 420 .collect(Collectors.toList()); 421 validateArg(distalTargets.size() <= 1, "BreakpointEvidence must have 0 or 1 distal targets"); 422 final String[] templateParts = words[4].split("/"); 423 final String templateName = templateParts[0]; 424 final TemplateFragmentOrdinal fragmentOrdinal; 425 if(templateParts.length <= 1) { 426 fragmentOrdinal = TemplateFragmentOrdinal.UNPAIRED; 427 } else switch (templateParts[1]) { 428 case "0": 429 fragmentOrdinal = TemplateFragmentOrdinal.PAIRED_INTERIOR; 430 break; 431 case "1": 432 fragmentOrdinal = TemplateFragmentOrdinal.PAIRED_FIRST; 433 break; 434 case "2": 435 fragmentOrdinal = TemplateFragmentOrdinal.PAIRED_SECOND; 436 break; 437 case "?": 438 fragmentOrdinal = TemplateFragmentOrdinal.PAIRED_UNKNOWN; 439 break; 440 default: 441 throw new IllegalArgumentException("Unknown Template Fragment Ordinal: /" + templateParts[1]); 442 } 443 final boolean forwardStrand = words[5].equals("1"); 444 final int templateSize = Integer.parseInt(words[6]); 445 final String cigarString = words[7]; 446 final int mappingQuality = Integer.parseInt(words[8]); 447 final String readGroup = DEFAULT_POND_NAME; // for now, just fake this, only for testing. 448 final boolean validated = false; 449 450 451 final SVInterval target; 452 final boolean targetForwardStrand; 453 final int targetQuality; 454 switch(distalTargets.size()) { 455 case 0: 456 target = new SVInterval(0, 0, 0); 457 targetForwardStrand = false; 458 targetQuality = -1; 459 break; 460 case 1: 461 target = distalTargets.get(0).getInterval(); 462 targetForwardStrand = distalTargets.get(0).getStrand(); 463 targetQuality = Integer.MAX_VALUE; 464 break; 465 default: 466 throw new IllegalArgumentException("BreakpointEvidence must have <= 1 distal target"); 467 } 468 469 switch(evidenceType) { 470 case "SplitRead": 471 // NOTE: can't identically reconstruct original values, but can make self-consistent values that reproduce 472 // the known distal targets. Make plausible cigar strings, primaryAlignmentClippedAtStart and 473 // primaryAlignmentForwardStrand that are compatible with dumped distal targets. 474 final String tagSA = distalTargets.isEmpty() ? null : distalTargets.stream().map(this::distalTargetToTagSA).collect(Collectors.joining()); 475 return new BreakpointEvidence.SplitRead(location, weight, templateName, fragmentOrdinal, validated, 476 forwardStrand, cigarString, mappingQuality, templateSize, readGroup, 477 forwardStrand, forwardStrand, tagSA); 478 479 case "LargeIndel": 480 Utils.validateArg(distalTargets.isEmpty(), "LargeIndel should have no distal targets"); 481 return new BreakpointEvidence.LargeIndel(location, weight, templateName, fragmentOrdinal, validated, 482 forwardStrand, cigarString, mappingQuality, templateSize, readGroup); 483 484 case "MateUnmapped": 485 Utils.validateArg(distalTargets.isEmpty(), "MateUnmapped should have no distal targets"); 486 return new BreakpointEvidence.MateUnmapped(location, weight, templateName, fragmentOrdinal, validated, 487 forwardStrand, cigarString, mappingQuality, templateSize, readGroup); 488 489 case "InterContigPair": 490 return new BreakpointEvidence.InterContigPair( 491 location, weight, templateName, fragmentOrdinal, validated, forwardStrand, cigarString, 492 mappingQuality, templateSize, readGroup, target, targetForwardStrand, targetQuality 493 ); 494 495 case "OutiesPair": 496 return new BreakpointEvidence.OutiesPair( 497 location, weight, templateName, fragmentOrdinal, validated, forwardStrand, cigarString, 498 mappingQuality, templateSize, readGroup, target, targetForwardStrand, targetQuality 499 ); 500 501 case "SameStrandPair": 502 return new BreakpointEvidence.SameStrandPair( 503 location, weight, templateName, fragmentOrdinal, validated, forwardStrand, cigarString, 504 mappingQuality, templateSize, readGroup, target, targetForwardStrand, targetQuality 505 ); 506 507 case "WeirdTemplateSize": 508 return new BreakpointEvidence.WeirdTemplateSize( 509 location, weight, templateName, fragmentOrdinal, validated, forwardStrand, cigarString, 510 mappingQuality, templateSize, readGroup, target, targetForwardStrand, targetQuality 511 ); 512 default: 513 throw new IllegalArgumentException("Unknown BreakpointEvidence type: " + evidenceType); 514 } 515 } 516 } 517 distalTargetToTagSA(final StrandedInterval distalTarget)518 private String distalTargetToTagSA(final StrandedInterval distalTarget) { 519 final String contigName = readMetadata.getContigName(distalTarget.getInterval().getContig()); 520 final boolean isForwardStrand = distalTarget.getStrand(); 521 final int referenceLength = distalTarget.getInterval().getLength(); 522 final int pos = distalTarget.getInterval().getEnd() - 1 - BreakpointEvidence.SplitRead.UNCERTAINTY; 523 final int start = isForwardStrand ? pos - referenceLength: pos; 524 final int clipLength = readMetadata.getAvgReadLen() - referenceLength; 525 final String cigar = referenceLength >= readMetadata.getAvgReadLen() ? referenceLength + "M" 526 : (isForwardStrand ? referenceLength + "M" + clipLength + "S" 527 : clipLength + "S" + referenceLength + "M"); 528 final int mapq = Integer.MAX_VALUE; 529 final int mismatches = 0; 530 final String[] tagParts = new String[] {contigName, String.valueOf(start), isForwardStrand ? "+": "-", 531 cigar, String.valueOf(mapq), String.valueOf(mismatches)}; 532 return String.join(",", tagParts) + ";"; 533 } 534 locationFromStringRep(final String locationStr)535 private static SVInterval locationFromStringRep(final String locationStr) { 536 final String[] locationParts = locationStr.split("[\\[\\]:]"); 537 validateArg(locationParts.length >= 2, "Could not parse SVInterval from string"); 538 final int contig = Integer.parseInt(locationParts[0]); 539 final int start = Integer.parseInt(locationParts[1]); 540 final int end = Integer.parseInt(locationParts[2]); 541 return new SVInterval(contig, start, end); 542 } 543 strandedLocationFromStringRep(final String locationStr)544 private static StrandedInterval strandedLocationFromStringRep(final String locationStr) { 545 final String[] locationParts = locationStr.split("[\\[\\]:]"); 546 validateArg(locationParts.length == 4, "Could not parse StrandedInterval from string"); 547 final int contig = Integer.parseInt(locationParts[0]); 548 final int start = Integer.parseInt(locationParts[1]); 549 final int end = Integer.parseInt(locationParts[2]); 550 final boolean strand = locationParts[3].equals("1"); 551 return new StrandedInterval(new SVInterval(contig, start, end), strand); 552 } 553 554 } 555 556 private static class ClassifierAccuracyData extends JsonMatrixLoader { 557 final EvidenceFeatures[] features; 558 final double[] probability; 559 ClassifierAccuracyData(final String jsonFileName)560 ClassifierAccuracyData(final String jsonFileName) { 561 try(final InputStream inputStream = new FileInputStream(jsonFileName)) { 562 final JsonNode testDataNode = new ObjectMapper().readTree(inputStream); 563 features = getFVecArrayFromJsonNode(testDataNode.get(FEATURES_NODE)); 564 probability = getDoubleArrayFromJsonNode(testDataNode.get(PROBABILITY_NODE)); 565 } catch(Exception e) { 566 throw new GATKException( 567 "Unable to load classifier test data from " + jsonFileName + ": " + e.getMessage() 568 ); 569 } 570 } 571 } 572 573 private static class FeaturesTestData extends JsonMatrixLoader { 574 final EvidenceFeatures[] features; 575 final String[] stringReps; 576 final double[] probability; 577 final float coverage; 578 final long[] template_size_cumulative_counts; 579 FeaturesTestData(final String jsonFileName)580 FeaturesTestData(final String jsonFileName) { 581 try(final InputStream inputStream = new FileInputStream(jsonFileName)) { 582 final JsonNode testDataNode = new ObjectMapper().readTree(inputStream); 583 features = getFVecArrayFromJsonNode(testDataNode.get(FEATURES_NODE)); 584 stringReps = getStringArrayFromJsonNode(testDataNode.get(STRING_REPS_NODE)); 585 probability = getDoubleArrayFromJsonNode(testDataNode.get(PROBABILITY_NODE)); 586 coverage = (float)testDataNode.get(MEAN_GENOME_COVERAGE_NODE).asDouble(); 587 template_size_cumulative_counts = getLongArrayFromJsonNode( 588 testDataNode.get(TEMPLATE_SIZE_CUMULATIVE_COUNTS_NODE) 589 ); 590 } catch(Exception e) { 591 throw new GATKException( 592 "Unable to load classifier test data from " + jsonFileName + ": " + e.getMessage() 593 ); 594 } 595 } 596 } 597 598 599 } 600