1 package org.broadinstitute.hellbender.utils;
2 
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
20 
21 /**
22  * Unit tests for {@link RandomDNA}.
23  */
24 public final class RandomDNAUnitTest {
25 
26     private static final int TEST_BASES_PER_LINE = 73;
27 
28     @Test(dataProvider="dictionaries")
testRandomFasta(final SAMSequenceDictionary dict)29     public void testRandomFasta(final SAMSequenceDictionary dict) throws IOException {
30         final RandomDNA randomDNA = new RandomDNA(111);
31         File fastaFile = null;
32         try {
33             fastaFile = randomDNA.nextFasta(dict, TEST_BASES_PER_LINE);
34             assertFastaFileAndDictMatch(fastaFile, TEST_BASES_PER_LINE, dict);
35         } finally {
36             try { if (fastaFile != null) fastaFile.delete(); } catch (final RuntimeException ex) {};
37         }
38     }
39 
40     @Test
testBufferMaxCapacity()41     public void testBufferMaxCapacity() {
42         Assert.assertTrue(RandomDNA.NEXT_BASES_MAX_CAPACITY >= RandomDNA.BASES_IN_AN_INT, "invalid value for NEXT_BASES_MAX_CAPACITY");
43     }
44 
45     @Test
testBasesInIntConstantConstraints()46     public void testBasesInIntConstantConstraints() {
47         Assert.assertTrue(RandomDNA.BASES_IN_AN_INT >= 3);
48         Assert.assertTrue(RandomDNA.BASES_IN_AN_INT <= Integer.SIZE / 2);
49     }
50 
51     /**
52      * Checks that several ways to compose an array of bases actually result on the same sequence of bases
53      * given a fixed seed.
54      */
55     @Test
testSequenceEquivalenceShort()56     public void testSequenceEquivalenceShort() {
57         final int seed = 1311;
58         final RandomDNA rdn1 = new RandomDNA(seed);
59         final RandomDNA rdn2 = new RandomDNA(seed);
60         final RandomDNA rdn3 = new RandomDNA(seed);
61         final RandomDNA rdn4 = new RandomDNA(seed);
62         final RandomDNA rdn5 = new RandomDNA(seed);
63 
64         final byte[] seq1 = rdn1.nextBases(132);
65 
66         final byte[] seq2 = new byte[132];
67         for (int i = 0; i < seq2.length; i++) {
68             seq2[i] = rdn2.nextBase();
69         }
70 
71         final byte[] seq3 = new byte[132];
72         rdn3.nextBases(seq3);
73 
74         final byte[] seq4 = new byte[132];
75         rdn4.nextBases(seq4, 0, 100);
76         rdn4.nextBases(seq4, 100, 32);
77 
78         final byte[] seq5 = new byte[132];
79         rdn5.nextBases(seq5, 0, 90);
80         for (int i = 0; i < 10; i++)
81             seq5[90 + i] = rdn5.nextBase();
82         rdn5.nextBases(seq5, 100, 10);
83         for (int i = 0; i < 22; i++) {
84             seq5[110 + i] = rdn5.nextBase();
85         }
86 
87         Assert.assertEquals(seq1, seq2);
88         Assert.assertEquals(seq1, seq3);
89         Assert.assertEquals(seq1, seq4);
90         Assert.assertEquals(seq1, seq5);
91     }
92 
93     @Test
testSequenceEquivalenceLong()94     public void testSequenceEquivalenceLong() {
95         final int seed = 1311;
96         final RandomDNA rdn1 = new RandomDNA(seed);
97         final RandomDNA rdn2 = new RandomDNA(seed);
98         final RandomDNA rdn3 = new RandomDNA(seed);
99         final RandomDNA rdn4 = new RandomDNA(seed);
100         final RandomDNA rdn5 = new RandomDNA(seed);
101 
102         final byte[] seq1 = rdn1.nextBases(1320);
103 
104         final byte[] seq2 = new byte[1320];
105         for (int i = 0; i < seq2.length; i++) {
106             seq2[i] = rdn2.nextBase();
107         }
108 
109         final byte[] seq3 = new byte[1320];
110         rdn3.nextBases(seq3);
111 
112         final byte[] seq4 = new byte[1320];
113         rdn4.nextBases(seq4, 0, 1000);
114         rdn4.nextBases(seq4, 1000, 320);
115 
116         final byte[] seq5 = new byte[1320];
117         rdn5.nextBases(seq5, 0, 900);
118         for (int i = 0; i < 100; i++)
119             seq5[900 + i] = rdn5.nextBase();
120         rdn5.nextBases(seq5, 1000, 100);
121         for (int i = 0; i < 220; i++) {
122             seq5[1100 + i] = rdn5.nextBase();
123         }
124 
125         for (int i = 0; i < 1320; i++) {
126             Assert.assertEquals(seq1[i], seq2[i], "" + i);
127         }
128         Assert.assertEquals(seq1, seq2);
129         Assert.assertEquals(seq1, seq3);
130         Assert.assertEquals(seq1, seq4);
131         Assert.assertEquals(seq1, seq5);
132     }
133 
134 
assertFastaFileAndDictMatch(final File fastaFile, final int basesPerLine, final SAMSequenceDictionary dict)135     private void assertFastaFileAndDictMatch(final File fastaFile, final int basesPerLine, final SAMSequenceDictionary dict) {
136         try (final BufferedReader reader = new BufferedReader(new FileReader(fastaFile))) {
137             final List<Pair<String, Nucleotide.Counter>> nameLengthAndFreqs = new ArrayList<>();
138             String line = reader.readLine();
139             if (dict.getSequences().isEmpty()) {
140                 Assert.assertNull(line);
141             } else {
142                 Assert.assertNotNull(line);
143                 do {
144                     Assert.assertTrue(line.matches("^>\\S.*$"));
145                     final String name = line.substring(1).split("\\s+")[0];
146                     final Nucleotide.Counter frequencies = new Nucleotide.Counter();
147                     line = reader.readLine();
148                     while (line != null && !line.matches("^>.*$")) {
149                         final String lineBases = line.trim();
150                         final String nextLine = reader.readLine();
151                         for (final byte base : lineBases.getBytes()) {
152                             final Nucleotide nuc = Nucleotide.decode(base);
153                             Assert.assertTrue(nuc.isStandard());
154                             frequencies.add(nuc);
155                         }
156                         if (nextLine != null && !nextLine.matches("^>.*$")){
157                             Assert.assertEquals(lineBases.length(), basesPerLine);
158                         } else {
159                             Assert.assertTrue(lineBases.length() <= basesPerLine);
160                         }
161                         line = nextLine;
162                     }
163                     nameLengthAndFreqs.add(new ImmutablePair<>(name, frequencies));
164                 } while (line != null);
165                 Assert.assertEquals(nameLengthAndFreqs.size(), dict.getSequences().size());
166                 for (int i = 0; i < nameLengthAndFreqs.size(); i++) {
167                     Assert.assertEquals(nameLengthAndFreqs.get(i).getLeft(), dict.getSequence(i).getSequenceName());
168                     Assert.assertEquals(nameLengthAndFreqs.get(i).getRight().sum(), dict.getSequence(i).getSequenceLength());
169                 }
170             }
171         } catch (final IOException ex) {
172             Assert.fail("exception thrown when openning fastaFile", ex);
173         }
174     }
175 
counts(final byte[] bytes)176     private int[] counts(final byte[] bytes){
177         final int[] b= new int[4];
178         for(int i=0; i < bytes.length; i++){
179             switch (bytes[i]){
180                 case 'A': b[0]++; break;
181                 case 'C': b[1]++; break;
182                 case 'G': b[2]++; break;
183                 case 'T': b[3]++; break;
184                 default: throw new IllegalStateException("illegal base:" + bytes[i]);
185             }
186         }
187         return b;
188     }
189     @Test
testBases1()190     public void testBases1(){
191         int[] results = new int[4];
192 
193         final int n = 1000;
194         final int m = 13;
195         for (int i= 0; i < n; i++) {
196             final byte[] b = new RandomDNA().nextBases(m);
197             final int[] b0 = counts(b);
198             results = pairwiseAdd(results, b0);
199         }
200 
201         checkResults(results, n, m);
202     }
203 
204     @Test
testBases()205     public void testBases(){
206         int[] results = new int[4];
207 
208         final int n = 1000;
209         final int m = 13;
210         for (int i= 0; i < n; i++) {
211             final byte[] b = new byte[m];
212             new RandomDNA().nextBases(b);
213             final int[] b0 = counts(b);
214             results = pairwiseAdd(results, b0);
215         }
216 
217         checkResults(results, n, m);
218     }
219 
checkResults(final int[] results, final int n, final int m)220     public void checkResults(final int[] results, final int n, final int m) {
221         final double mean = Arrays.stream(results).average().getAsDouble();
222         final double std = new StandardDeviation().evaluate(Arrays.stream(results).asDoubleStream().toArray());
223         final double expectedMean = (n*m)/4.0;
224         final double s = std; // not really because it's the population not the sample dtd but it'll do
225         Assert.assertTrue(mean < expectedMean + 2 * s / Math.sqrt(n * m), "unexpected mean:" + mean);
226         Assert.assertTrue(mean > expectedMean-2*s/Math.sqrt(n*m), "unexpected mean:" +mean);
227     }
228 
pairwiseAdd(int[] a, int[] b)229     private int[] pairwiseAdd(int[] a, int[] b) {
230         Utils.validateArg(a.length == b.length, "lengths must be equal");
231         return IntStream.range(0, a.length).map(n -> a[n] + b[n]).toArray();
232     }
233 
234     @DataProvider
dictionaries()235     public Object[][] dictionaries() {
236         return Arrays.asList(
237                 new SAMSequenceDictionary(),
238                 new SAMSequenceDictionary(Arrays.asList(
239                         new SAMSequenceRecord("seq1", 1000))),
240                 new SAMSequenceDictionary(Arrays.asList(
241                         new SAMSequenceRecord("chr20", 10_000),
242                         new SAMSequenceRecord("chrX", 1_000),
243                         new SAMSequenceRecord("MT_unknown", 10))),
244                 new SAMSequenceDictionary(Arrays.asList(
245                         new SAMSequenceRecord("1", 0),
246                         new SAMSequenceRecord("2", TEST_BASES_PER_LINE),
247                         new SAMSequenceRecord("3", TEST_BASES_PER_LINE - 1),
248                         new SAMSequenceRecord("mmm", TEST_BASES_PER_LINE + 1),
249                         new SAMSequenceRecord("xxx", TEST_BASES_PER_LINE * 13)))).stream()
250         .map(dict -> new Object[] { dict})
251         .toArray(Object[][]::new);
252     }
253 }
254