1 package org.broadinstitute.hellbender.metrics;
2 
3 import htsjdk.samtools.SAMReadGroupRecord;
4 import htsjdk.samtools.SAMRecord;
5 import htsjdk.samtools.SamReader;
6 import htsjdk.samtools.SamReaderFactory;
7 import htsjdk.samtools.metrics.MetricsFile;
8 import htsjdk.samtools.reference.ReferenceSequence;
9 import htsjdk.samtools.util.CloserUtil;
10 import org.testng.Assert;
11 import org.testng.annotations.DataProvider;
12 import org.testng.annotations.Test;
13 
14 import java.io.File;
15 import java.io.Serializable;
16 import java.util.*;
17 
18 import static htsjdk.samtools.util.CollectionUtil.makeSet;
19 
20 public final class MultiLevelCollectorTest {
21 
22     public static File TESTFILE = new File("src/test/resources/org/broadinstitute/hellbender/metrics/test.sam");
23 
noneOrStr(final String str)24     public String noneOrStr(final String str) {
25         final String out;
26         if(str == null) {
27             out = "";
28         } else {
29             out = str;
30         }
31         return out;
32     }
33 
34     class TestArg {
35         public final SAMRecord samRecord;
36         public final ReferenceSequence refSeq;
37 
TestArg(final SAMRecord samRecord, final ReferenceSequence refSeq)38         public TestArg(final SAMRecord samRecord, final ReferenceSequence refSeq) {
39             this.samRecord = samRecord;
40             this.refSeq    = refSeq;
41         }
42     }
43 
44     /** We will just Tally up the number of times records were added to this metric and change FINISHED
45      * to true when FINISHED is called
46      */
47     class TotalNumberMetric extends MultiLevelMetrics implements Serializable {
48         private static final long serialVersionUID = 1L;
49 
50         /** The number of these encountered **/
51         public Integer TALLY = 0;
52         public boolean FINISHED = false;
53     }
54 
55     class RecordCountMultiLevelCollector extends MultiLevelCollector<TotalNumberMetric, Integer, TestArg> {
56         private static final long serialVersionUID = 1L;
57 
RecordCountMultiLevelCollector(final Set<MetricAccumulationLevel> accumulationLevels, final List<SAMReadGroupRecord> samRgRecords)58         public RecordCountMultiLevelCollector(final Set<MetricAccumulationLevel> accumulationLevels, final List<SAMReadGroupRecord> samRgRecords) {
59             setup(accumulationLevels, samRgRecords);
60         }
61 
62         //The number of times records were accepted by a RecordCountPerUnitCollectors (note since the same
63         //samRecord might be aggregated by multiple PerUnit collectors, this may be greater than the number of
64         //records in the file
65         private int numProcessed = 0;
66 
getNumProcessed()67         public int getNumProcessed() {
68             return numProcessed;
69         }
70 
71         private final Map<String, TotalNumberMetric> unitsToMetrics = new LinkedHashMap<>();
72 
getUnitsToMetrics()73         public Map<String, TotalNumberMetric> getUnitsToMetrics() {
74             return unitsToMetrics;
75         }
76 
77         @Override
makeArg(final SAMRecord samRec, final ReferenceSequence refSeq)78         protected TestArg makeArg(final SAMRecord samRec, final ReferenceSequence refSeq) {
79             return new TestArg(samRec, refSeq);
80         }
81 
82         @Override
makeChildCollector(final String sample, final String library, final String readGroup)83         protected PerUnitMetricCollector<TotalNumberMetric, Integer, TestArg> makeChildCollector(final String sample, final String library, final String readGroup) {
84             return new RecordCountPerUnitCollector(sample, library, readGroup);
85         }
86 
87         private class RecordCountPerUnitCollector implements PerUnitMetricCollector<TotalNumberMetric, Integer, TestArg>{
88             private static final long serialVersionUID = 1L;
89             private final TotalNumberMetric metric;
90 
RecordCountPerUnitCollector(final String sample, final String library, final String readGroup)91             public RecordCountPerUnitCollector(final String sample, final String library, final String readGroup) {
92                 metric = new TotalNumberMetric();
93                 metric.SAMPLE     = sample;
94                 metric.LIBRARY    = library;
95                 metric.READ_GROUP = readGroup;
96                 unitsToMetrics.put(noneOrStr(sample) + "_" + noneOrStr(library) + "_" + noneOrStr(readGroup), metric);
97             }
98 
99             @Override
acceptRecord(final TestArg args)100             public void acceptRecord(final TestArg args) {
101                 numProcessed += 1;
102                 metric.TALLY += 1;
103                 if(metric.SAMPLE != null) {
104                     Assert.assertEquals(metric.SAMPLE, args.samRecord.getReadGroup().getSample());
105                 }
106                 if(metric.LIBRARY != null) {
107                     Assert.assertEquals(metric.LIBRARY, args.samRecord.getReadGroup().getLibrary());
108                 }
109 
110                 if(metric.READ_GROUP != null) {
111                     Assert.assertEquals(metric.READ_GROUP, args.samRecord.getReadGroup().getPlatformUnit());
112                 }
113             }
114 
115             @Override
finish()116             public void finish() {
117                 metric.FINISHED = true;
118             }
119 
120             @Override
addMetricsToFile(final MetricsFile<TotalNumberMetric, Integer> totalNumberMetricIntegerMetricsFile)121             public void addMetricsToFile(final MetricsFile<TotalNumberMetric, Integer> totalNumberMetricIntegerMetricsFile) {
122                 totalNumberMetricIntegerMetricsFile.addMetric(metric);
123             }
124         }
125     }
126 
/**
 * Expected tallies, per accumulation level, keyed by unit. Unit keys have the form
 * "sample_library_readGroup", with an empty string for every part the level does not specify.
 */
public static final Map<MetricAccumulationLevel, Map<String, Integer>> accumulationLevelToPerUnitReads = new LinkedHashMap<>();
static {
    final Map<String, Integer> allReads = new LinkedHashMap<>();
    allReads.put("__", 19);
    accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.ALL_READS, allReads);

    final Map<String, Integer> perSample = new LinkedHashMap<>();
    perSample.put("Ma__", 10);
    perSample.put("Pa__", 9);
    accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.SAMPLE, perSample);

    final Map<String, Integer> perLibrary = new LinkedHashMap<>();
    perLibrary.put("Ma_whatever_", 10);
    perLibrary.put("Pa_lib1_", 4);
    perLibrary.put("Pa_lib2_", 5);
    accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.LIBRARY, perLibrary);

    final Map<String, Integer> perReadGroup = new LinkedHashMap<>();
    perReadGroup.put("Ma_whatever_me", 10);
    perReadGroup.put("Pa_lib1_myself", 4);
    perReadGroup.put("Pa_lib2_i", 3);
    perReadGroup.put("Pa_lib2_i2", 2);
    accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.READ_GROUP, perReadGroup);
}
152 
153     @DataProvider(name = "variedAccumulationLevels")
variedAccumulationLevels()154     public Object[][] variedAccumulationLevels() {
155         return new Object[][] {
156             {makeSet(MetricAccumulationLevel.ALL_READS)},
157             {makeSet(MetricAccumulationLevel.ALL_READS,    MetricAccumulationLevel.SAMPLE)},
158             {makeSet(MetricAccumulationLevel.SAMPLE,       MetricAccumulationLevel.LIBRARY)},
159             {makeSet(MetricAccumulationLevel.READ_GROUP,   MetricAccumulationLevel.LIBRARY)},
160             {makeSet(MetricAccumulationLevel.SAMPLE,       MetricAccumulationLevel.LIBRARY, MetricAccumulationLevel.READ_GROUP)},
161             {makeSet(MetricAccumulationLevel.SAMPLE,       MetricAccumulationLevel.LIBRARY, MetricAccumulationLevel.READ_GROUP, MetricAccumulationLevel.ALL_READS)},
162         };
163     }
164 
165     @Test(dataProvider = "variedAccumulationLevels")
multilevelCollectorTest(final Set<MetricAccumulationLevel> accumulationLevels)166     public void multilevelCollectorTest(final Set<MetricAccumulationLevel> accumulationLevels) {
167         final SamReader in = SamReaderFactory.makeDefault().open(TESTFILE);
168         final RecordCountMultiLevelCollector collector = new RecordCountMultiLevelCollector(accumulationLevels, in.getFileHeader().getReadGroups());
169 
170         for (final SAMRecord rec : in) {
171             collector.acceptRecord(rec, null);
172         }
173 
174         collector.finish();
175 
176         int totalProcessed = 0;
177         int totalMetrics = 0;
178         for(final MetricAccumulationLevel level : accumulationLevels) {
179             final Map<String, Integer> keyToMetrics = accumulationLevelToPerUnitReads.get(level);
180             for(final Map.Entry<String, Integer> entry : keyToMetrics.entrySet()) {
181                 final TotalNumberMetric metric = collector.getUnitsToMetrics().get(entry.getKey());
182                 Assert.assertEquals(entry.getValue(), metric.TALLY);
183                 Assert.assertTrue(metric.FINISHED);
184                 totalProcessed += metric.TALLY;
185                 totalMetrics   += 1;
186             }
187         }
188 
189         Assert.assertEquals(collector.getUnitsToMetrics().size(), totalMetrics);
190         Assert.assertEquals(totalProcessed, collector.getNumProcessed());
191         CloserUtil.close(in);
192     }
193 }
194