1 package org.broadinstitute.hellbender.utils.nio;
2 
import com.google.common.base.Stopwatch;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.GATKBaseTest;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SeekableByteChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
21 
22 /**
23  * Stress test for reading lots of data from the cloud using a very small prefetch buffer.
24  * Do not run this too often.
25  */
26 public final class ExtremeReadsTest extends GATKBaseTest {
27 
28     static final String fname = GCS_GATK_TEST_RESOURCES + "large/CEUTrio.HiSeq.WGS.b37.NA12878.20.21.bam";
29     static final int THREAD_COUNT = 1000;
30     static final int CHANNELS_PER_THREAD = 1000;
31     private static Logger logger = LogManager.getLogger(ExtremeReadsTest.class);
32 
33     static volatile int errors = 0;
34 
35     private static class Runner implements Runnable {
36         /**
37          * Read a bunch of bytes. Part of the manyParallelReads test.
38          */
39         @Override
run()40         public void run() {
41             try {
42                 Path path = IOUtils.getPath(ExtremeReadsTest.fname);
43                 ArrayList<SeekableByteChannel> chans = new ArrayList<SeekableByteChannel>();
44                 for (int i=0; i<CHANNELS_PER_THREAD; i++) {
45                     SeekableByteChannel chan = BucketUtils.addPrefetcher(2, Files.newByteChannel(path));
46                     // skip the first half
47                     chan.position(chan.position()/2);
48                     chans.add(chan);
49                 }
50                 long size = chans.get(0).size();
51                 ByteBuffer buf = ByteBuffer.allocate(1024*1024 - 5);
52                 while (!chans.isEmpty()) {
53                     SeekableByteChannel chan = chans.remove(0);
54                     buf.clear();
55                     int read = chan.read(buf);
56                     if (read>=0) {
57                         chans.add(chan);
58                         continue;
59                     }
60                     // EOF
61                     long position = chan.position();
62                     if (size != position) {
63                         logger.info("Done at wrong position! " + position + " != " + size);
64                         ExtremeReadsTest.errors++;
65                     }
66                 }
67             } catch (Exception x) {
68                 ExtremeReadsTest.errors++;
69                 logger.info("Caught: " + x.getMessage());
70                 x.printStackTrace();
71             }
72         }
73     }
74 
75     /**
76      * This test takes about a half hour and reads a fair amount of data.
77      * It definitely shouldn't be part of the normal test suite (that's why it's disabled)
78      * but it's kept here so we can manually run it should we need to investigate mysterious
79      * disconnects again.
80      **/
81     @Test(groups={"bucket"}, enabled=false)
manyParallelReads()82     public void manyParallelReads() throws InterruptedException {
83         final ExecutorService executor = Executors.newFixedThreadPool(THREAD_COUNT,
84             new ThreadFactory() {
85                 public Thread newThread(Runnable r) {
86                     Thread t = Executors.defaultThreadFactory().newThread(r);
87                     t.setDaemon(true);
88                     return t;
89                 }
90             });
91         Stopwatch sw = Stopwatch.createStarted();
92         errors = 0;
93         for (int i=0; i<THREAD_COUNT; i++) {
94             executor.execute(new Runner());
95         }
96         long parallel_reads = THREAD_COUNT * CHANNELS_PER_THREAD;
97         logger.info(parallel_reads + " parallel reads via " + THREAD_COUNT + " threads (this will take a while).");
98         executor.shutdown();
99         executor.awaitTermination(1, TimeUnit.DAYS);
100         sw.stop();
101         logger.info("All done. Elapsed: " + sw.elapsed(TimeUnit.MINUTES) + " min.");
102         logger.info("There were " + errors + " error(s).");
103         Assert.assertEquals(errors, 0);
104     }
105 }
106