1 package org.broadinstitute.hellbender.tools.spark.transforms.markduplicates; 2 3 import com.esotericsoftware.kryo.Kryo; 4 import com.google.api.client.util.Lists; 5 import com.google.common.collect.ImmutableList; 6 import htsjdk.samtools.*; 7 import org.apache.spark.SparkConf; 8 import org.apache.spark.SparkException; 9 import org.apache.spark.api.java.JavaRDD; 10 import org.apache.spark.api.java.JavaSparkContext; 11 import org.apache.spark.serializer.KryoRegistrator; 12 import org.broadinstitute.hellbender.GATKBaseTest; 13 import org.broadinstitute.hellbender.engine.spark.SAMRecordSerializer; 14 import org.broadinstitute.hellbender.engine.spark.SparkContextFactory; 15 import org.broadinstitute.hellbender.exceptions.UserException; 16 import org.broadinstitute.hellbender.testutils.SparkTestUtils; 17 import org.broadinstitute.hellbender.utils.read.ArtificialReadUtils; 18 import org.broadinstitute.hellbender.utils.read.GATKRead; 19 import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter; 20 import org.broadinstitute.hellbender.utils.read.markduplicates.MarkDuplicatesScoringStrategy; 21 import org.broadinstitute.hellbender.utils.spark.SparkUtils; 22 import org.testng.Assert; 23 import org.testng.annotations.Test; 24 import picard.sam.markduplicates.MarkDuplicates; 25 import picard.sam.markduplicates.util.OpticalDuplicateFinder; 26 import scala.Tuple2; 27 28 import java.io.IOException; 29 import java.util.*; 30 31 public class MarkDuplicatesSparkUtilsUnitTest extends GATKBaseTest { 32 @Test(groups = "spark") testSpanningIterator()33 public void testSpanningIterator() { 34 check(Collections.emptyIterator(), Collections.emptyList()); 35 check(ImmutableList.of(pair(1, "a")).iterator(), 36 ImmutableList.of(pairIterable(1, "a"))); 37 check(ImmutableList.of(pair(1, "a"), pair(1, "b")).iterator(), 38 ImmutableList.of(pairIterable(1, "a", "b"))); 39 check(ImmutableList.of(pair(1, "a"), pair(2, "b")).iterator(), 40 ImmutableList.of(pairIterable(1, "a"), pairIterable(2, "b"))); 41 check(ImmutableList.of(pair(1, "a"), pair(1, "b"), pair(2, "c")).iterator(), 42 ImmutableList.of(pairIterable(1, "a", "b"), pairIterable(2, "c"))); 43 check(ImmutableList.of(pair(1, "a"), pair(2, "b"), pair(2, "c")).iterator(), 44 ImmutableList.of(pairIterable(1, "a"), pairIterable(2, "b", "c"))); 45 check(ImmutableList.of(pair(1, "a"), pair(2, "b"), pair(1, "c")).iterator(), 46 ImmutableList.of(pairIterable(1, "a"), pairIterable(2, "b"), pairIterable(1, "c"))); 47 } 48 49 getReadGroupId(final SAMFileHeader header, final int index)50 private String getReadGroupId(final SAMFileHeader header, final int index) { 51 return header.getReadGroups().get(index).getReadGroupId(); 52 } 53 check(Iterator<Tuple2<K, V>> it, List<Tuple2<K, Iterable<V>>> expected)54 private static <K, V> void check(Iterator<Tuple2<K, V>> it, List<Tuple2<K, Iterable<V>>> expected) { 55 Iterator<Tuple2<K, Iterable<V>>> spanning = SparkUtils.getSpanningIterator(it); 56 ArrayList<Tuple2<K, Iterable<V>>> actual = Lists.newArrayList(spanning); 57 Assert.assertEquals(actual, expected); 58 } 59 pair(K key, V value)60 private static <K, V> Tuple2<K, V> pair(K key, V value) { 61 return new Tuple2<>(key, value); 62 } 63 pairIterable(int i, String... s)64 private static Tuple2<Integer, Iterable<String>> pairIterable(int i, String... s) { 65 return new Tuple2<>(i, ImmutableList.copyOf(s)); 66 } 67 pairIterable(String key, GATKRead... reads)68 private static Tuple2<String, Iterable<GATKRead>> pairIterable(String key, GATKRead... reads) { 69 return new Tuple2<>(key, ImmutableList.copyOf(reads)); 70 } 71 72 @Test(expectedExceptions = UserException.BadInput.class) testHeaderMissingReadGroupFilds()73 public void testHeaderMissingReadGroupFilds() { 74 JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); 75 76 SAMRecordSetBuilder samRecordSetBuilder = new SAMRecordSetBuilder(true, SAMFileHeader.SortOrder.queryname, 77 true, SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH, SAMRecordSetBuilder.DEFAULT_DUPLICATE_SCORING_STRATEGY); 78 79 JavaRDD<GATKRead> reads = ctx.parallelize(new ArrayList<>(), 2); 80 SAMFileHeader header = samRecordSetBuilder.getHeader(); 81 header.setReadGroups(new ArrayList<>()); 82 83 MarkDuplicatesSparkUtils.transformToDuplicateNames(header, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES, null, reads, 2, false).collect(); 84 } 85 86 @Test testReadsMissingReadGroups()87 public void testReadsMissingReadGroups() { 88 JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); 89 90 SAMRecordSetBuilder samRecordSetBuilder = new SAMRecordSetBuilder(true, SAMFileHeader.SortOrder.queryname, 91 true, SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH, SAMRecordSetBuilder.DEFAULT_DUPLICATE_SCORING_STRATEGY); 92 samRecordSetBuilder.addFrag("READ" , 0, 10000, false); 93 94 JavaRDD<GATKRead> reads = ctx.parallelize(Lists.newArrayList(samRecordSetBuilder.getRecords()), 2).map(SAMRecordToGATKReadAdapter::new); 95 reads = reads.map(r -> {r.setReadGroup(null); return r;}); 96 SAMFileHeader header = samRecordSetBuilder.getHeader(); 97 98 try { 99 MarkDuplicatesSparkUtils.transformToDuplicateNames(header, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES, null, reads, 2, false).collect(); 100 Assert.fail("Should have thrown an exception"); 101 } catch (Exception e){ 102 Assert.assertTrue(e instanceof SparkException); 103 Assert.assertTrue(e.getCause() instanceof UserException.ReadMissingReadGroup); 104 } 105 } 106 107 @Test 108 // Test that asserts the duplicate marking is sorting agnostic, specifically this is testing that when reads are scrambled across 109 // partitions in the input that all reads in a group are getting properly duplicate marked together as they are for queryname sorted bams testSortOrderPartitioningCorrectness()110 public void testSortOrderPartitioningCorrectness() throws IOException { 111 112 JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); 113 JavaRDD<GATKRead> unsortedReads = generateReadsWithDuplicates(10000,3, ctx, 99, true); 114 JavaRDD<GATKRead> pariedEndsQueryGrouped = generateReadsWithDuplicates(10000,3, ctx,1, false); //Use only one partition to avoid having to do edge fixing. 115 116 // Create headers reflecting the respective sort ordering of the trial reads 117 SAMReadGroupRecord readGroup1 = new SAMReadGroupRecord("1"); 118 readGroup1.setAttribute(SAMReadGroupRecord.READ_GROUP_SAMPLE_TAG, "test"); 119 120 SAMFileHeader unsortedHeader = hg19Header.clone(); 121 unsortedHeader.addReadGroup(readGroup1); 122 unsortedHeader.setSortOrder(SAMFileHeader.SortOrder.unsorted); 123 SAMFileHeader sortedHeader = hg19Header.clone(); 124 sortedHeader.addReadGroup(readGroup1); 125 sortedHeader.setSortOrder(SAMFileHeader.SortOrder.queryname); 126 127 // Using the header flagged as unsorted will result in the reads being sorted again 128 JavaRDD<GATKRead> unsortedReadsMarked = MarkDuplicatesSpark.mark(unsortedReads,unsortedHeader, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES,new OpticalDuplicateFinder(),100,true,MarkDuplicates.DuplicateTaggingPolicy.DontTag); 129 JavaRDD<GATKRead> sortedReadsMarked = MarkDuplicatesSpark.mark(pariedEndsQueryGrouped,sortedHeader, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES,new OpticalDuplicateFinder(),1, true, MarkDuplicates.DuplicateTaggingPolicy.DontTag); 130 131 Iterator<GATKRead> sortedReadsFinal = sortedReadsMarked.sortBy(GATKRead::commonToString, false, 1).collect().iterator(); 132 Iterator<GATKRead> unsortedReadsFinal = unsortedReadsMarked.sortBy(GATKRead::commonToString, false, 1).collect().iterator(); 133 134 // Comparing the output reads to ensure they are all duplicate marked correctly 135 while (sortedReadsFinal.hasNext()) { 136 GATKRead read1 = sortedReadsFinal.next(); 137 GATKRead read2 = unsortedReadsFinal.next(); 138 Assert.assertEquals(read1.getName(), read2.getName()); 139 Assert.assertEquals(read1.isDuplicate(), read2.isDuplicate()); 140 } 141 } 142 143 // This helper method is used to generate groups reads that will be duplicate marked. It does this by generating numDuplicatesPerGroup 144 // pairs of reads starting at randomly selected starting locations. The start locations are random so that if the resulting RDD is 145 // coordinate sorted that it is more or less guaranteed that a large portion of the reads will reside on separate partitions from 146 // their mates. It also handles sorting of the reads into either queryname or coordinate orders. generateReadsWithDuplicates(int numReadGroups, int numDuplicatesPerGroup, JavaSparkContext ctx, int numPartitions, boolean coordinate)147 private JavaRDD<GATKRead> generateReadsWithDuplicates(int numReadGroups, int numDuplicatesPerGroup, JavaSparkContext ctx, int numPartitions, boolean coordinate) { 148 int readNameCounter = 0; 149 SAMRecordSetBuilder samRecordSetBuilder = new SAMRecordSetBuilder(true, SAMFileHeader.SortOrder.coordinate, 150 true, SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH, SAMRecordSetBuilder.DEFAULT_DUPLICATE_SCORING_STRATEGY); 151 152 Random rand = new Random(10); 153 for (int i = 0; i < numReadGroups; i++ ) { 154 int start1 = rand.nextInt(SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH); 155 int start2 = rand.nextInt(SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH); 156 for (int j = 0; j < numDuplicatesPerGroup; j++) { 157 samRecordSetBuilder.addPair("READ" + readNameCounter++, 0, start1, start2); 158 } 159 } 160 List<SAMRecord> records = Lists.newArrayList(samRecordSetBuilder.getRecords()); 161 if (coordinate) { 162 records.sort(new SAMRecordCoordinateComparator()); 163 } else { 164 records.sort(new SAMRecordQueryNameComparator()); 165 } 166 167 return ctx.parallelize(records, numPartitions).map(SAMRecordToGATKReadAdapter::new); 168 } 169 170 @Test testChangingContigsOnHeaderlessSAMRecord()171 public void testChangingContigsOnHeaderlessSAMRecord() { 172 final SparkConf conf = new SparkConf().set("spark.kryo.registrator", 173 "org.broadinstitute.hellbender.tools.spark.transforms.markduplicates.MarkDuplicatesSparkUtilsUnitTest$TestGATKRegistrator"); 174 final SAMRecord read = ((SAMRecordToGATKReadAdapter) ArtificialReadUtils.createHeaderlessSamBackedRead("read1", "1", 100, 50)).getEncapsulatedSamRecord(); 175 final OpticalDuplicateFinder finder = new OpticalDuplicateFinder(OpticalDuplicateFinder.DEFAULT_READ_NAME_REGEX,2500, null); 176 177 final OpticalDuplicateFinder roundTrippedRead = SparkTestUtils.roundTripInKryo(finder, OpticalDuplicateFinder.class, conf); 178 Assert.assertEquals(roundTrippedRead.opticalDuplicatePixelDistance, finder.opticalDuplicatePixelDistance); 179 } 180 181 public static class TestGATKRegistrator implements KryoRegistrator { 182 @SuppressWarnings("unchecked") 183 @Override registerClasses(Kryo kryo)184 public void registerClasses(Kryo kryo) { 185 kryo.register(SAMRecord.class, new SAMRecordSerializer()); 186 } 187 } 188 189 190 } 191