package org.broadinstitute.hellbender.utils.spark;

import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.PeekingIterator;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.SAMTextHeaderCodec;
import htsjdk.samtools.util.BinaryCodec;
import htsjdk.samtools.util.BlockCompressedOutputStream;
import htsjdk.samtools.util.BlockCompressedStreamConstants;
import htsjdk.samtools.util.RuntimeIOException;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.*;
import scala.Tuple2;

import java.io.*;
import java.net.URI;
import java.util.*;

/**
 * Miscellaneous Spark-related utilities
 */
public final class SparkUtils {
    private static final Logger logger = LogManager.getLogger(SparkUtils.class);

    /** Sometimes Spark has trouble destroying a broadcast variable, but we'd like the app to continue anyway. */
    public static <T> void destroyBroadcast(final Broadcast<T> broadcast, final String whatBroadcast ) {
        try {
            broadcast.destroy();
        } catch ( final Exception e ) {
            logger.warn("Failed to destroy broadcast for " + whatBroadcast, e);
        }
    }

    private SparkUtils() {}

    /**
     * Converts a headerless Hadoop bam shard (e.g., a part0000, part0001, etc. file produced by
     * {@link org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink}) into a readable bam file
     * by adding a header and a BGZF terminator.
     *
     * This method is not intended for use with Hadoop bam shards that already have a header -- these shards are
     * already readable using samtools. Currently {@link ReadsSparkSink} saves the "shards" with a header for the
     * {@link ReadsWriteFormat#SHARDED} case, and without a header for the {@link ReadsWriteFormat#SINGLE} case.
     *
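     * A minimal usage sketch (the shard and output file names here are hypothetical):
     * <pre>{@code
     * SparkUtils.convertHeaderlessHadoopBamShardToBam(
     *         new File("reads_output.bam.parts/part-r-00000"),
     *         header,
     *         new File("reads_output.bam"));
     * }</pre>
     *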
     * @param bamShard The headerless Hadoop bam shard to convert
     * @param header header for the BAM file to be created
     * @param destination path to which to write the new BAM file
     */
    public static void convertHeaderlessHadoopBamShardToBam( final File bamShard, final SAMFileHeader header, final File destination ) {
        try ( FileOutputStream outStream = new FileOutputStream(destination) ) {
            writeBAMHeaderToStream(header, outStream);
            FileUtils.copyFile(bamShard, outStream);
            outStream.write(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK);
        }
        catch ( IOException e ) {
            throw new UserException("Error writing to " + destination.getAbsolutePath(), e);
        }
    }

    /**
     * Private helper method for {@link #convertHeaderlessHadoopBamShardToBam} that takes a SAMFileHeader and writes it
     * to the provided {@link OutputStream}, correctly encoded for the BAM format and preceded by the BAM magic bytes.
     *
     * @param samFileHeader SAM header to write
     * @param outputStream stream to write the SAM header to
     */
    private static void writeBAMHeaderToStream( final SAMFileHeader samFileHeader, final OutputStream outputStream ) {
        final BlockCompressedOutputStream blockCompressedOutputStream = new BlockCompressedOutputStream(outputStream, (File)null);
        final BinaryCodec outputBinaryCodec = new BinaryCodec(new DataOutputStream(blockCompressedOutputStream));

        final String headerString;
        final Writer stringWriter = new StringWriter();
        new SAMTextHeaderCodec().encode(stringWriter, samFileHeader, true);
        headerString = stringWriter.toString();

        outputBinaryCodec.writeBytes(ReadUtils.BAM_MAGIC);

        // calculate and write the length of the SAM file header text and the header text
        outputBinaryCodec.writeString(headerString, true, false);

        // write the sequence dictionary in binary form; this is redundant with the text header
        outputBinaryCodec.writeInt(samFileHeader.getSequenceDictionary().size());
        for (final SAMSequenceRecord sequenceRecord: samFileHeader.getSequenceDictionary().getSequences()) {
            outputBinaryCodec.writeString(sequenceRecord.getSequenceName(), true, true);
            outputBinaryCodec.writeInt(sequenceRecord.getSequenceLength());
        }

        try {
            blockCompressedOutputStream.flush();
        } catch (final IOException ioe) {
            throw new RuntimeIOException(ioe);
        }
    }

    /**
     * Determine if the <code>targetPath</code> exists.
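     *
     * For example (the URI here is hypothetical):
     * <pre>{@code
     * boolean exists = SparkUtils.hadoopPathExists(ctx, URI.create("hdfs://namenode/data/reads.bam"));
     * }</pre>
     *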
     * @param ctx JavaSparkContext
     * @param targetURI the <code>org.apache.hadoop.fs.Path</code> URI to check
     * @return true if the targetPath exists, otherwise false
     */
    public static boolean hadoopPathExists(final JavaSparkContext ctx, final URI targetURI) {
        Utils.nonNull(ctx);
        Utils.nonNull(targetURI);
        try {
            final Path targetHadoopPath = new Path(targetURI);
            final FileSystem fs = targetHadoopPath.getFileSystem(ctx.hadoopConfiguration());
            return fs.exists(targetHadoopPath);
        } catch (IOException e) {
            throw new UserException("Error validating existence of path " + targetURI + ": " + e.getMessage(), e);
        }
    }

    /**
     * Do a total sort of an RDD of {@link GATKRead} according to the sort order in the header.
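     *
     * A minimal sketch of typical use (assuming {@code reads} and {@code header} are already in hand):
     * <pre>{@code
     * header.setSortOrder(SAMFileHeader.SortOrder.coordinate);
     * JavaRDD<GATKRead> sorted = SparkUtils.sortReadsAccordingToHeader(reads, header, 0);
     * }</pre>
     *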
     * @param reads a JavaRDD of reads which may or may not be sorted
     * @param header a header which specifies the desired new sort order.
     *               Only {@link SAMFileHeader.SortOrder#coordinate} and {@link SAMFileHeader.SortOrder#queryname} are supported.
     *               All others will result in {@link GATKException}
     * @param numReducers number of reducers to use when sorting
     * @return a new JavaRDD of reads which is globally sorted in a way that is consistent with the sort order given in the header
     */
    public static JavaRDD<GATKRead> sortReadsAccordingToHeader(final JavaRDD<GATKRead> reads, final SAMFileHeader header, final int numReducers) {
        final SAMFileHeader.SortOrder order = header.getSortOrder();
        switch (order) {
            case coordinate:
                return sortUsingElementsAsKeys(reads, new ReadCoordinateComparator(header), numReducers);
            case queryname:
                final JavaRDD<GATKRead> sortedReads = sortUsingElementsAsKeys(reads, new ReadQueryNameComparator(), numReducers);
                return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context()));
            default:
                throw new GATKException("Sort order: " + order + " is not supported.");
        }
    }

    /**
     *   Do a global sort of an RDD using the given comparator.
     *   This method uses the RDD elements themselves as the keys in the Spark key/value sort.  This may be inefficient
     *   if the comparator only looks at a small fraction of the element to perform the comparison.
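     *
     *   For example, a global natural-order sort of strings (the RDD {@code lines} is hypothetical):
     *   <pre>{@code
     *   JavaRDD<String> sorted = SparkUtils.sortUsingElementsAsKeys(lines, Comparator.<String>naturalOrder(), 10);
     *   }</pre>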
     */
    public static <T> JavaRDD<T> sortUsingElementsAsKeys(JavaRDD<T> elements, Comparator<T> comparator, int numReducers) {
        Utils.nonNull(comparator);
        Utils.nonNull(elements);

        // Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount
        // of data going through the shuffle.
        final JavaPairRDD<T, Void> rddReadPairs = elements.mapToPair(read -> new Tuple2<>(read, (Void) null));

        final JavaPairRDD<T, Void> readVoidPairs;
        if (numReducers > 0) {
            readVoidPairs = rddReadPairs.sortByKey(comparator, true, numReducers);
        } else {
            readVoidPairs = rddReadPairs.sortByKey(comparator);
        }
        return readVoidPairs.keys();
    }

    /**
     * Ensure all reads with the same name appear in the same partition of a queryname sorted RDD.
     * This avoids a global shuffle and only transfers the leading elements from each partition, which is fast in most
     * cases.
     *
     * The RDD must be queryname sorted. If there are so many reads with the same name that they span multiple
     * partitions, this will throw {@link GATKException}.
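     *
     * A minimal sketch of typical use (assuming a queryname-sorted {@code reads} RDD and its {@code header}):
     * <pre>{@code
     * JavaRDD<GATKRead> repartitioned =
     *         SparkUtils.putReadsWithTheSameNameInTheSamePartition(header, reads, ctx);
     * }</pre>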
     */
    public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition( final SAMFileHeader header,
                                                                               final JavaRDD<GATKRead> reads,
                                                                               final JavaSparkContext ctx ) {
        Utils.validateArg(ReadUtils.isReadNameGroupedBam(header), () -> "Reads must be queryname grouped or sorted. " +
                "Actual sort:" + header.getSortOrder() + "  Actual grouping:" + header.getGroupOrder());

        // Find the first group in each partition
        final List<List<GATKRead>> firstReadNameGroupInEachPartition = reads
                .mapPartitions(it -> {
                    if ( !it.hasNext() ) {
                        return Iterators.singletonIterator(Collections.<GATKRead>emptyList());
                    }
                    final List<GATKRead> firstGroup = new ArrayList<>(2);
                    final GATKRead firstRead = it.next();
                    firstGroup.add(firstRead);
                    final String groupName = firstRead.getName();
                    while ( it.hasNext() ) {
                        final GATKRead read = it.next();
                        if ( !groupName.equals(read.getName()) ) {
                            break;
                        }
                        firstGroup.add(read);
                    }
                    return Iterators.singletonIterator(firstGroup);
                })
                .collect();

        // Shift left, so that each partition will be zipped with the first read group from the _next_ partition
        final int numPartitions = reads.getNumPartitions();
        final List<List<GATKRead>> firstGroupFromNextPartition =
                new ArrayList<>(firstReadNameGroupInEachPartition.subList(1, numPartitions));
        firstGroupFromNextPartition.add(Collections.emptyList()); // the last partition does not have any reads to add to it

        // Take care of the situation where an entire partition contains reads with the same name
        // (unlikely, but could happen with very long reads, or very small partitions).
        for ( int idx = numPartitions - 1; idx >= 1; --idx ) {
            final List<GATKRead> curGroup = firstGroupFromNextPartition.get(idx);
            if ( !curGroup.isEmpty() ) {
                final String groupName = curGroup.get(0).getName();
                int idx2 = idx;
                while ( --idx2 >= 0 ) {
                    final List<GATKRead> prevGroup = firstGroupFromNextPartition.get(idx2);
                    if ( !prevGroup.isEmpty() ) {
                        if ( groupName.equals(prevGroup.get(0).getName()) ) {
                            prevGroup.addAll(curGroup);
                            curGroup.clear();
                        }
                        break;
                    }
                }
            }
        }

        // Peel off the first group in each partition
        final int[] firstGroupSizes = firstReadNameGroupInEachPartition.stream().mapToInt(List::size).toArray();
        firstGroupSizes[0] = 0; // first partition has no predecessor to handle its first group of reads
        JavaRDD<GATKRead> readsSansFirstGroup = reads.mapPartitionsWithIndex( (idx, itr) ->
            { int groupSize = firstGroupSizes[idx];
              while ( itr.hasNext() && groupSize-- > 0 ) {
                  itr.next();
              }
              return itr; }, true);

        // Zip up the remaining reads with the first read group from the _next_ partition
        return readsSansFirstGroup.zipPartitions(ctx.parallelize(firstGroupFromNextPartition, numPartitions),
                (it1, it2) -> Iterators.concat(it1, it2.next().iterator()));
    }

    /**
     * Like <code>groupByKey</code>, but assumes that values are already sorted by key, so no shuffle is needed,
     * which is much faster.
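     *
     * For example, a key-sorted pair RDD holding {@code ("a",1), ("a",2), ("b",3)} becomes
     * {@code ("a",[1,2]), ("b",[3])} (the RDD {@code sortedPairs} is hypothetical):
     * <pre>{@code
     * JavaPairRDD<String, Iterable<Integer>> grouped = SparkUtils.spanByKey(sortedPairs);
     * }</pre>
     *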
     * @param rdd the input RDD
     * @param <K> type of keys
     * @param <V> type of values
     * @return an RDD where the values for each key are grouped into an iterable collection
     */
    public static <K, V> JavaPairRDD<K, Iterable<V>> spanByKey(JavaPairRDD<K, V> rdd) {
        return rdd.mapPartitionsToPair(SparkUtils::getSpanningIterator);
    }

    /**
     * An iterator that groups values having the same key into iterable collections.
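     *
     * For example (building the input inline to show the grouping; values are illustrative):
     * <pre>{@code
     * Iterator<Tuple2<String, Iterable<Integer>>> grouped = SparkUtils.getSpanningIterator(
     *         Arrays.asList(new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3)).iterator());
     * // yields ("a", [1, 2]) followed by ("b", [3])
     * }</pre>
     *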
     * @param iterator an iterator over key-value pairs
     * @param <K> type of keys
     * @param <V> type of values
     * @return an iterator over pairs of keys and grouped values
     */
    public static <K, V> Iterator<Tuple2<K, Iterable<V>>> getSpanningIterator(Iterator<Tuple2<K, V>> iterator) {
        final PeekingIterator<Tuple2<K, V>> iter = Iterators.peekingIterator(iterator);
        return new AbstractIterator<Tuple2<K, Iterable<V>>>() {
            @Override
            protected Tuple2<K, Iterable<V>> computeNext() {
                K key = null;
                List<V> group = Lists.newArrayList();
                while (iter.hasNext()) {
                    if (key == null) {
                        Tuple2<K, V> next = iter.next();
                        key = next._1();
                        V value = next._2();
                        group.add(value);
                        continue;
                    }
                    K nextKey = iter.peek()._1(); // don't advance...
                    if (nextKey.equals(key)) {
                        group.add(iter.next()._2()); // ...unless the keys match
                    } else {
                        return new Tuple2<>(key, group);
                    }
                }
                if (key != null) {
                    return new Tuple2<>(key, group);
                }
                return endOfData();
            }
        };
    }

    /**
     * Sort reads into queryname order if they are not already queryname grouped (as reported by the header).
     * Note that when a sort is performed, this method sets the header's sort order to
     * {@link SAMFileHeader.SortOrder#queryname} as a side effect.
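     *
     * A minimal sketch of typical use (assuming {@code reads} and {@code header} are already in hand):
     * <pre>{@code
     * JavaRDD<GATKRead> querynameSorted = SparkUtils.querynameSortReadsIfNecessary(reads, 0, header);
     * }</pre>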
     */
    public static JavaRDD<GATKRead> querynameSortReadsIfNecessary(JavaRDD<GATKRead> reads, int numReducers, SAMFileHeader header) {
        JavaRDD<GATKRead> sortedReadsForMarking;
        if (ReadUtils.isReadNameGroupedBam(header)) {
            sortedReadsForMarking = reads;
        } else {
            header.setSortOrder(SAMFileHeader.SortOrder.queryname);
            sortedReadsForMarking = sortReadsAccordingToHeader(reads, header, numReducers);
        }
        return sortedReadsForMarking;
    }
}