1 package org.broadinstitute.hellbender.metrics; 2 3 import htsjdk.samtools.SAMReadGroupRecord; 4 import htsjdk.samtools.SAMRecord; 5 import htsjdk.samtools.SamReader; 6 import htsjdk.samtools.SamReaderFactory; 7 import htsjdk.samtools.metrics.MetricsFile; 8 import htsjdk.samtools.reference.ReferenceSequence; 9 import htsjdk.samtools.util.CloserUtil; 10 import org.testng.Assert; 11 import org.testng.annotations.DataProvider; 12 import org.testng.annotations.Test; 13 14 import java.io.File; 15 import java.io.Serializable; 16 import java.util.*; 17 18 import static htsjdk.samtools.util.CollectionUtil.makeSet; 19 20 public final class MultiLevelCollectorTest { 21 22 public static File TESTFILE = new File("src/test/resources/org/broadinstitute/hellbender/metrics/test.sam"); 23 noneOrStr(final String str)24 public String noneOrStr(final String str) { 25 final String out; 26 if(str == null) { 27 out = ""; 28 } else { 29 out = str; 30 } 31 return out; 32 } 33 34 class TestArg { 35 public final SAMRecord samRecord; 36 public final ReferenceSequence refSeq; 37 TestArg(final SAMRecord samRecord, final ReferenceSequence refSeq)38 public TestArg(final SAMRecord samRecord, final ReferenceSequence refSeq) { 39 this.samRecord = samRecord; 40 this.refSeq = refSeq; 41 } 42 } 43 44 /** We will just Tally up the number of times records were added to this metric and change FINISHED 45 * to true when FINISHED is called 46 */ 47 class TotalNumberMetric extends MultiLevelMetrics implements Serializable { 48 private static final long serialVersionUID = 1L; 49 50 /** The number of these encountered **/ 51 public Integer TALLY = 0; 52 public boolean FINISHED = false; 53 } 54 55 class RecordCountMultiLevelCollector extends MultiLevelCollector<TotalNumberMetric, Integer, TestArg> { 56 private static final long serialVersionUID = 1L; 57 RecordCountMultiLevelCollector(final Set<MetricAccumulationLevel> accumulationLevels, final List<SAMReadGroupRecord> samRgRecords)58 public RecordCountMultiLevelCollector(final Set<MetricAccumulationLevel> accumulationLevels, final List<SAMReadGroupRecord> samRgRecords) { 59 setup(accumulationLevels, samRgRecords); 60 } 61 62 //The number of times records were accepted by a RecordCountPerUnitCollectors (note since the same 63 //samRecord might be aggregated by multiple PerUnit collectors, this may be greater than the number of 64 //records in the file 65 private int numProcessed = 0; 66 getNumProcessed()67 public int getNumProcessed() { 68 return numProcessed; 69 } 70 71 private final Map<String, TotalNumberMetric> unitsToMetrics = new LinkedHashMap<>(); 72 getUnitsToMetrics()73 public Map<String, TotalNumberMetric> getUnitsToMetrics() { 74 return unitsToMetrics; 75 } 76 77 @Override makeArg(final SAMRecord samRec, final ReferenceSequence refSeq)78 protected TestArg makeArg(final SAMRecord samRec, final ReferenceSequence refSeq) { 79 return new TestArg(samRec, refSeq); 80 } 81 82 @Override makeChildCollector(final String sample, final String library, final String readGroup)83 protected PerUnitMetricCollector<TotalNumberMetric, Integer, TestArg> makeChildCollector(final String sample, final String library, final String readGroup) { 84 return new RecordCountPerUnitCollector(sample, library, readGroup); 85 } 86 87 private class RecordCountPerUnitCollector implements PerUnitMetricCollector<TotalNumberMetric, Integer, TestArg>{ 88 private static final long serialVersionUID = 1L; 89 private final TotalNumberMetric metric; 90 RecordCountPerUnitCollector(final String sample, final String library, final String readGroup)91 public RecordCountPerUnitCollector(final String sample, final String library, final String readGroup) { 92 metric = new TotalNumberMetric(); 93 metric.SAMPLE = sample; 94 metric.LIBRARY = library; 95 metric.READ_GROUP = readGroup; 96 unitsToMetrics.put(noneOrStr(sample) + "_" + noneOrStr(library) + "_" + noneOrStr(readGroup), metric); 97 } 98 99 @Override acceptRecord(final TestArg args)100 public void acceptRecord(final TestArg args) { 101 numProcessed += 1; 102 metric.TALLY += 1; 103 if(metric.SAMPLE != null) { 104 Assert.assertEquals(metric.SAMPLE, args.samRecord.getReadGroup().getSample()); 105 } 106 if(metric.LIBRARY != null) { 107 Assert.assertEquals(metric.LIBRARY, args.samRecord.getReadGroup().getLibrary()); 108 } 109 110 if(metric.READ_GROUP != null) { 111 Assert.assertEquals(metric.READ_GROUP, args.samRecord.getReadGroup().getPlatformUnit()); 112 } 113 } 114 115 @Override finish()116 public void finish() { 117 metric.FINISHED = true; 118 } 119 120 @Override addMetricsToFile(final MetricsFile<TotalNumberMetric, Integer> totalNumberMetricIntegerMetricsFile)121 public void addMetricsToFile(final MetricsFile<TotalNumberMetric, Integer> totalNumberMetricIntegerMetricsFile) { 122 totalNumberMetricIntegerMetricsFile.addMetric(metric); 123 } 124 } 125 } 126 127 public static final Map<MetricAccumulationLevel, Map<String, Integer>> accumulationLevelToPerUnitReads = new LinkedHashMap<>(); 128 static { 129 HashMap<String, Integer> curMap = new LinkedHashMap<>(); 130 curMap.put("__", 19); accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.ALL_READS, curMap)131 accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.ALL_READS, curMap); 132 133 curMap = new LinkedHashMap<>(); 134 curMap.put("Ma__", 10); 135 curMap.put("Pa__", 9); accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.SAMPLE, curMap)136 accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.SAMPLE, curMap); 137 138 curMap = new LinkedHashMap<>(); 139 curMap.put("Ma_whatever_", 10); 140 curMap.put("Pa_lib1_", 4); 141 curMap.put("Pa_lib2_", 5); accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.LIBRARY, curMap)142 accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.LIBRARY, curMap); 143 144 145 curMap = new LinkedHashMap<>(); 146 curMap.put("Ma_whatever_me", 10); 147 curMap.put("Pa_lib1_myself", 4); 148 curMap.put("Pa_lib2_i", 3); 149 curMap.put("Pa_lib2_i2", 2); accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.READ_GROUP, curMap)150 accumulationLevelToPerUnitReads.put(MetricAccumulationLevel.READ_GROUP, curMap); 151 } 152 153 @DataProvider(name = "variedAccumulationLevels") variedAccumulationLevels()154 public Object[][] variedAccumulationLevels() { 155 return new Object[][] { 156 {makeSet(MetricAccumulationLevel.ALL_READS)}, 157 {makeSet(MetricAccumulationLevel.ALL_READS, MetricAccumulationLevel.SAMPLE)}, 158 {makeSet(MetricAccumulationLevel.SAMPLE, MetricAccumulationLevel.LIBRARY)}, 159 {makeSet(MetricAccumulationLevel.READ_GROUP, MetricAccumulationLevel.LIBRARY)}, 160 {makeSet(MetricAccumulationLevel.SAMPLE, MetricAccumulationLevel.LIBRARY, MetricAccumulationLevel.READ_GROUP)}, 161 {makeSet(MetricAccumulationLevel.SAMPLE, MetricAccumulationLevel.LIBRARY, MetricAccumulationLevel.READ_GROUP, MetricAccumulationLevel.ALL_READS)}, 162 }; 163 } 164 165 @Test(dataProvider = "variedAccumulationLevels") multilevelCollectorTest(final Set<MetricAccumulationLevel> accumulationLevels)166 public void multilevelCollectorTest(final Set<MetricAccumulationLevel> accumulationLevels) { 167 final SamReader in = SamReaderFactory.makeDefault().open(TESTFILE); 168 final RecordCountMultiLevelCollector collector = new RecordCountMultiLevelCollector(accumulationLevels, in.getFileHeader().getReadGroups()); 169 170 for (final SAMRecord rec : in) { 171 collector.acceptRecord(rec, null); 172 } 173 174 collector.finish(); 175 176 int totalProcessed = 0; 177 int totalMetrics = 0; 178 for(final MetricAccumulationLevel level : accumulationLevels) { 179 final Map<String, Integer> keyToMetrics = accumulationLevelToPerUnitReads.get(level); 180 for(final Map.Entry<String, Integer> entry : keyToMetrics.entrySet()) { 181 final TotalNumberMetric metric = collector.getUnitsToMetrics().get(entry.getKey()); 182 Assert.assertEquals(entry.getValue(), metric.TALLY); 183 Assert.assertTrue(metric.FINISHED); 184 totalProcessed += metric.TALLY; 185 totalMetrics += 1; 186 } 187 } 188 189 Assert.assertEquals(collector.getUnitsToMetrics().size(), totalMetrics); 190 Assert.assertEquals(totalProcessed, collector.getNumProcessed()); 191 CloserUtil.close(in); 192 } 193 } 194