package org.broadinstitute.hellbender.tools.spark.transforms.markduplicates;

import com.esotericsoftware.kryo.Kryo;
import com.google.api.client.util.Lists;
import com.google.common.collect.ImmutableList;
import htsjdk.samtools.*;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.serializer.KryoRegistrator;
import org.broadinstitute.hellbender.GATKBaseTest;
import org.broadinstitute.hellbender.engine.spark.SAMRecordSerializer;
import org.broadinstitute.hellbender.engine.spark.SparkContextFactory;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.testutils.SparkTestUtils;
import org.broadinstitute.hellbender.utils.read.ArtificialReadUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.read.markduplicates.MarkDuplicatesScoringStrategy;
import org.broadinstitute.hellbender.utils.spark.SparkUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import picard.sam.markduplicates.MarkDuplicates;
import picard.sam.markduplicates.util.OpticalDuplicateFinder;
import scala.Tuple2;

import java.io.IOException;
import java.util.*;
public class MarkDuplicatesSparkUtilsUnitTest extends GATKBaseTest {
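    // SparkUtils.getSpanningIterator should collapse consecutive entries sharing a key into a single
    // (key, Iterable<value>) tuple, while keeping non-adjacent occurrences of the same key separate.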
    @Test(groups = "spark")
    public void testSpanningIterator() {
        check(Collections.emptyIterator(), Collections.emptyList());
        check(ImmutableList.of(pair(1, "a")).iterator(),
                ImmutableList.of(pairIterable(1, "a")));
        check(ImmutableList.of(pair(1, "a"), pair(1, "b")).iterator(),
                ImmutableList.of(pairIterable(1, "a", "b")));
        check(ImmutableList.of(pair(1, "a"), pair(2, "b")).iterator(),
                ImmutableList.of(pairIterable(1, "a"), pairIterable(2, "b")));
        check(ImmutableList.of(pair(1, "a"), pair(1, "b"), pair(2, "c")).iterator(),
                ImmutableList.of(pairIterable(1, "a", "b"), pairIterable(2, "c")));
        check(ImmutableList.of(pair(1, "a"), pair(2, "b"), pair(2, "c")).iterator(),
                ImmutableList.of(pairIterable(1, "a"), pairIterable(2, "b", "c")));
        check(ImmutableList.of(pair(1, "a"), pair(2, "b"), pair(1, "c")).iterator(),
                ImmutableList.of(pairIterable(1, "a"), pairIterable(2, "b"), pairIterable(1, "c")));
    }

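    // Convenience accessor for the ID of the read group at the given index in the header.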
    private String getReadGroupId(final SAMFileHeader header, final int index) {
        return header.getReadGroups().get(index).getReadGroupId();
    }

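    // Materializes a spanning iterator over the input pairs and asserts that it yields the expected groupings.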
    private static <K, V> void check(Iterator<Tuple2<K, V>> it, List<Tuple2<K, Iterable<V>>> expected) {
        Iterator<Tuple2<K, Iterable<V>>> spanning = SparkUtils.getSpanningIterator(it);
        ArrayList<Tuple2<K, Iterable<V>>> actual = Lists.newArrayList(spanning);
        Assert.assertEquals(actual, expected);
    }

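    // Factory helpers for building the (key, value) and (key, Iterable<value>) tuples asserted on above.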
    private static <K, V> Tuple2<K, V> pair(K key, V value) {
        return new Tuple2<>(key, value);
    }

    private static Tuple2<Integer, Iterable<String>> pairIterable(int i, String... s) {
        return new Tuple2<>(i, ImmutableList.copyOf(s));
    }

    private static Tuple2<String, Iterable<GATKRead>> pairIterable(String key, GATKRead... reads) {
        return new Tuple2<>(key, ImmutableList.copyOf(reads));
    }

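    // transformToDuplicateNames needs read group information from the header, so a header with an empty
    // read group list should be rejected up front with a UserException.BadInput.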
    @Test(expectedExceptions = UserException.BadInput.class)
    public void testHeaderMissingReadGroupFields() {
        JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

        SAMRecordSetBuilder samRecordSetBuilder = new SAMRecordSetBuilder(true, SAMFileHeader.SortOrder.queryname,
                true, SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH, SAMRecordSetBuilder.DEFAULT_DUPLICATE_SCORING_STRATEGY);

        JavaRDD<GATKRead> reads = ctx.parallelize(new ArrayList<>(), 2);
        SAMFileHeader header = samRecordSetBuilder.getHeader();
        header.setReadGroups(new ArrayList<>());

        MarkDuplicatesSparkUtils.transformToDuplicateNames(header, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES, null, reads, 2, false).collect();
    }

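    // Unlike the empty-header case above, a read missing its RG tag is only discovered while the reads are
    // being processed, so the failure surfaces during the Spark action rather than at submission time.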
    @Test
    public void testReadsMissingReadGroups() {
        JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

        SAMRecordSetBuilder samRecordSetBuilder = new SAMRecordSetBuilder(true, SAMFileHeader.SortOrder.queryname,
                true, SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH, SAMRecordSetBuilder.DEFAULT_DUPLICATE_SCORING_STRATEGY);
        samRecordSetBuilder.addFrag("READ", 0, 10000, false);

        JavaRDD<GATKRead> reads = ctx.parallelize(Lists.newArrayList(samRecordSetBuilder.getRecords()), 2).map(SAMRecordToGATKReadAdapter::new);
        reads = reads.map(r -> { r.setReadGroup(null); return r; });
        SAMFileHeader header = samRecordSetBuilder.getHeader();

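        // collect() forces evaluation; Spark wraps exceptions thrown on executors in a SparkException, so we
        // assert on the wrapper and its cause instead of using expectedExceptions on the @Test annotation.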
        try {
            MarkDuplicatesSparkUtils.transformToDuplicateNames(header, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES, null, reads, 2, false).collect();
            Assert.fail("Should have thrown an exception");
        } catch (Exception e) {
            Assert.assertTrue(e instanceof SparkException);
            Assert.assertTrue(e.getCause() instanceof UserException.ReadMissingReadGroup);
        }
    }

    @Test
    // Test asserting that duplicate marking is sort-order agnostic. Specifically, this tests that when reads are scrambled across
    // partitions in the input, all reads in a group are duplicate marked together, just as they are for queryname-sorted bams.
    public void testSortOrderPartitioningCorrectness() throws IOException {

        JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
        JavaRDD<GATKRead> unsortedReads = generateReadsWithDuplicates(10000, 3, ctx, 99, true);
        JavaRDD<GATKRead> pairedEndsQueryGrouped = generateReadsWithDuplicates(10000, 3, ctx, 1, false); // Use only one partition to avoid having to do edge fixing.

        // Create headers reflecting the respective sort ordering of the trial reads
        SAMReadGroupRecord readGroup1 = new SAMReadGroupRecord("1");
        readGroup1.setAttribute(SAMReadGroupRecord.READ_GROUP_SAMPLE_TAG, "test");

        SAMFileHeader unsortedHeader = hg19Header.clone();
        unsortedHeader.addReadGroup(readGroup1);
        unsortedHeader.setSortOrder(SAMFileHeader.SortOrder.unsorted);
        SAMFileHeader sortedHeader = hg19Header.clone();
        sortedHeader.addReadGroup(readGroup1);
        sortedHeader.setSortOrder(SAMFileHeader.SortOrder.queryname);

        // Using the header flagged as unsorted will result in the reads being sorted again
        JavaRDD<GATKRead> unsortedReadsMarked = MarkDuplicatesSpark.mark(unsortedReads, unsortedHeader, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES, new OpticalDuplicateFinder(), 100, true, MarkDuplicates.DuplicateTaggingPolicy.DontTag);
        JavaRDD<GATKRead> sortedReadsMarked = MarkDuplicatesSpark.mark(pairedEndsQueryGrouped, sortedHeader, MarkDuplicatesScoringStrategy.SUM_OF_BASE_QUALITIES, new OpticalDuplicateFinder(), 1, true, MarkDuplicates.DuplicateTaggingPolicy.DontTag);

        Iterator<GATKRead> sortedReadsFinal = sortedReadsMarked.sortBy(GATKRead::commonToString, false, 1).collect().iterator();
        Iterator<GATKRead> unsortedReadsFinal = unsortedReadsMarked.sortBy(GATKRead::commonToString, false, 1).collect().iterator();

        // Comparing the output reads to ensure they are all duplicate marked correctly
        while (sortedReadsFinal.hasNext()) {
            GATKRead read1 = sortedReadsFinal.next();
            GATKRead read2 = unsortedReadsFinal.next();
            Assert.assertEquals(read1.getName(), read2.getName());
            Assert.assertEquals(read1.isDuplicate(), read2.isDuplicate());
        }
    }

    // This helper method generates groups of reads that will be duplicate marked. It does this by generating numDuplicatesPerGroup
    // pairs of reads starting at randomly selected start locations. The start locations are random so that if the resulting RDD is
    // coordinate sorted, it is more or less guaranteed that a large portion of the reads will reside on separate partitions from
    // their mates. It also handles sorting of the reads into either queryname or coordinate order.
    private JavaRDD<GATKRead> generateReadsWithDuplicates(int numReadGroups, int numDuplicatesPerGroup, JavaSparkContext ctx, int numPartitions, boolean coordinate) {
        int readNameCounter = 0;
        SAMRecordSetBuilder samRecordSetBuilder = new SAMRecordSetBuilder(true, SAMFileHeader.SortOrder.coordinate,
                true, SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH, SAMRecordSetBuilder.DEFAULT_DUPLICATE_SCORING_STRATEGY);

        Random rand = new Random(10);
        for (int i = 0; i < numReadGroups; i++) {
            int start1 = rand.nextInt(SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH);
            int start2 = rand.nextInt(SAMRecordSetBuilder.DEFAULT_CHROMOSOME_LENGTH);
            for (int j = 0; j < numDuplicatesPerGroup; j++) {
                samRecordSetBuilder.addPair("READ" + readNameCounter++, 0, start1, start2);
            }
        }
        List<SAMRecord> records = Lists.newArrayList(samRecordSetBuilder.getRecords());
        if (coordinate) {
            records.sort(new SAMRecordCoordinateComparator());
        } else {
            records.sort(new SAMRecordQueryNameComparator());
        }

        return ctx.parallelize(records, numPartitions).map(SAMRecordToGATKReadAdapter::new);
    }

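    // Round-trips an OpticalDuplicateFinder through Kryo, configured with the custom registrator below so that
    // SAMRecords (including headerless ones like the read built here) are handled by SAMRecordSerializer.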
    @Test
    public void testChangingContigsOnHeaderlessSAMRecord() {
        final SparkConf conf = new SparkConf().set("spark.kryo.registrator",
                "org.broadinstitute.hellbender.tools.spark.transforms.markduplicates.MarkDuplicatesSparkUtilsUnitTest$TestGATKRegistrator");
        final SAMRecord read = ((SAMRecordToGATKReadAdapter) ArtificialReadUtils.createHeaderlessSamBackedRead("read1", "1", 100, 50)).getEncapsulatedSamRecord();
        final OpticalDuplicateFinder finder = new OpticalDuplicateFinder(OpticalDuplicateFinder.DEFAULT_READ_NAME_REGEX, 2500, null);

        final OpticalDuplicateFinder roundTrippedFinder = SparkTestUtils.roundTripInKryo(finder, OpticalDuplicateFinder.class, conf);
        Assert.assertEquals(roundTrippedFinder.opticalDuplicatePixelDistance, finder.opticalDuplicatePixelDistance);
    }

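    // Registrator used by the test above: serializes SAMRecord with GATK's SAMRecordSerializer.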
    public static class TestGATKRegistrator implements KryoRegistrator {
        @SuppressWarnings("unchecked")
        @Override
        public void registerClasses(Kryo kryo) {
            kryo.register(SAMRecord.class, new SAMRecordSerializer());
        }
    }
}