1 package org.broadinstitute.hellbender.utils.spark; 2 3 import com.google.common.collect.AbstractIterator; 4 import com.google.common.collect.Iterators; 5 import com.google.common.collect.Lists; 6 import com.google.common.collect.PeekingIterator; 7 import htsjdk.samtools.SAMFileHeader; 8 import htsjdk.samtools.SAMSequenceRecord; 9 import htsjdk.samtools.SAMTextHeaderCodec; 10 import htsjdk.samtools.util.BinaryCodec; 11 import htsjdk.samtools.util.BlockCompressedOutputStream; 12 import htsjdk.samtools.util.BlockCompressedStreamConstants; 13 import htsjdk.samtools.util.RuntimeIOException; 14 import org.apache.commons.io.FileUtils; 15 import org.apache.hadoop.fs.FileSystem; 16 import org.apache.hadoop.fs.Path; 17 import org.apache.logging.log4j.LogManager; 18 import org.apache.logging.log4j.Logger; 19 import org.apache.spark.api.java.JavaPairRDD; 20 import org.apache.spark.api.java.JavaRDD; 21 import org.apache.spark.api.java.JavaSparkContext; 22 import org.apache.spark.broadcast.Broadcast; 23 import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink; 24 import org.broadinstitute.hellbender.exceptions.GATKException; 25 import org.broadinstitute.hellbender.exceptions.UserException; 26 import org.broadinstitute.hellbender.utils.Utils; 27 import org.broadinstitute.hellbender.utils.read.*; 28 import scala.Tuple2; 29 30 import java.io.*; 31 import java.net.URI; 32 import java.util.*; 33 34 /** 35 * Miscellaneous Spark-related utilities 36 */ 37 public final class SparkUtils { 38 private static final Logger logger = LogManager.getLogger(SparkUtils.class); 39 40 /** Sometimes Spark has trouble destroying a broadcast variable, but we'd like the app to continue anyway. */ destroyBroadcast(final Broadcast<T> broadcast, final String whatBroadcast )41 public static <T> void destroyBroadcast(final Broadcast<T> broadcast, final String whatBroadcast ) { 42 try { 43 broadcast.destroy(); 44 } catch ( final Exception e ) { 45 logger.warn("Failed to destroy broadcast for " + whatBroadcast, e); 46 } 47 } 48 SparkUtils()49 private SparkUtils() {} 50 51 /** 52 * Converts a headerless Hadoop bam shard (eg., a part0000, part0001, etc. file produced by 53 * {@link org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink}) into a readable bam file 54 * by adding a header and a BGZF terminator. 55 * 56 * This method is not intended for use with Hadoop bam shards that already have a header -- these shards are 57 * already readable using samtools. Currently {@link ReadsSparkSink} saves the "shards" with a header for the 58 * {@link ReadsWriteFormat#SHARDED} case, and without a header for the {@link ReadsWriteFormat#SINGLE} case. 59 * 60 * @param bamShard The headerless Hadoop bam shard to convert 61 * @param header header for the BAM file to be created 62 * @param destination path to which to write the new BAM file 63 */ convertHeaderlessHadoopBamShardToBam( final File bamShard, final SAMFileHeader header, final File destination )64 public static void convertHeaderlessHadoopBamShardToBam( final File bamShard, final SAMFileHeader header, final File destination ) { 65 try ( FileOutputStream outStream = new FileOutputStream(destination) ) { 66 writeBAMHeaderToStream(header, outStream); 67 FileUtils.copyFile(bamShard, outStream); 68 outStream.write(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK); 69 } 70 catch ( IOException e ) { 71 throw new UserException("Error writing to " + destination.getAbsolutePath(), e); 72 } 73 } 74 75 /** 76 * Private helper method for {@link #convertHeaderlessHadoopBamShardToBam} that takes a SAMFileHeader and writes it 77 * to the provided `OutputStream`, correctly encoded for the BAM format and preceded by the BAM magic bytes. 78 * 79 * @param samFileHeader SAM header to write 80 * @param outputStream stream to write the SAM header to 81 */ writeBAMHeaderToStream( final SAMFileHeader samFileHeader, final OutputStream outputStream )82 private static void writeBAMHeaderToStream( final SAMFileHeader samFileHeader, final OutputStream outputStream ) { 83 final BlockCompressedOutputStream blockCompressedOutputStream = new BlockCompressedOutputStream(outputStream, (File)null); 84 final BinaryCodec outputBinaryCodec = new BinaryCodec(new DataOutputStream(blockCompressedOutputStream)); 85 86 final String headerString; 87 final Writer stringWriter = new StringWriter(); 88 new SAMTextHeaderCodec().encode(stringWriter, samFileHeader, true); 89 headerString = stringWriter.toString(); 90 91 outputBinaryCodec.writeBytes(ReadUtils.BAM_MAGIC); 92 93 // calculate and write the length of the SAM file header text and the header text 94 outputBinaryCodec.writeString(headerString, true, false); 95 96 // write the sequences binarily. This is redundant with the text header 97 outputBinaryCodec.writeInt(samFileHeader.getSequenceDictionary().size()); 98 for (final SAMSequenceRecord sequenceRecord: samFileHeader.getSequenceDictionary().getSequences()) { 99 outputBinaryCodec.writeString(sequenceRecord.getSequenceName(), true, true); 100 outputBinaryCodec.writeInt(sequenceRecord.getSequenceLength()); 101 } 102 103 try { 104 blockCompressedOutputStream.flush(); 105 } catch (final IOException ioe) { 106 throw new RuntimeIOException(ioe); 107 } 108 } 109 110 /** 111 * Determine if the <code>targetPath</code> exists. 112 * @param ctx JavaSparkContext 113 * @param targetURI the <code>org.apache.hadoop.fs.Path</code> URI to check 114 * @return true if the targetPath exists, otherwise false 115 */ hadoopPathExists(final JavaSparkContext ctx, final URI targetURI)116 public static boolean hadoopPathExists(final JavaSparkContext ctx, final URI targetURI) { 117 Utils.nonNull(ctx); 118 Utils.nonNull(targetURI); 119 try { 120 final Path targetHadoopPath = new Path(targetURI); 121 final FileSystem fs = targetHadoopPath.getFileSystem(ctx.hadoopConfiguration()); 122 return fs.exists(targetHadoopPath); 123 } catch (IOException e) { 124 throw new UserException("Error validating existence of path " + targetURI + ": " + e.getMessage()); 125 } 126 } 127 128 /** 129 * Do a total sort of an RDD of {@link GATKRead} according to the sort order in the header. 130 * @param reads a JavaRDD of reads which may or may not be sorted 131 * @param header a header which specifies the desired new sort order. 132 * Only {@link SAMFileHeader.SortOrder#coordinate} and {@link SAMFileHeader.SortOrder#queryname} are supported. 133 * All others will result in {@link GATKException} 134 * @param numReducers number of reducers to use when sorting 135 * @return a new JavaRDD or reads which is globally sorted in a way that is consistent with the sort order given in the header 136 */ sortReadsAccordingToHeader(final JavaRDD<GATKRead> reads, final SAMFileHeader header, final int numReducers)137 public static JavaRDD<GATKRead> sortReadsAccordingToHeader(final JavaRDD<GATKRead> reads, final SAMFileHeader header, final int numReducers){ 138 final SAMFileHeader.SortOrder order = header.getSortOrder(); 139 switch (order){ 140 case coordinate: 141 return sortUsingElementsAsKeys(reads, new ReadCoordinateComparator(header), numReducers); 142 case queryname: 143 final JavaRDD<GATKRead> sortedReads = sortUsingElementsAsKeys(reads, new ReadQueryNameComparator(), numReducers); 144 return putReadsWithTheSameNameInTheSamePartition(header, sortedReads, JavaSparkContext.fromSparkContext(reads.context())); 145 default: 146 throw new GATKException("Sort order: " + order + " is not supported."); 147 } 148 } 149 150 /** 151 * Do a global sort of an RDD using the given comparator. 152 * This method uses the RDD elements themselves as the keys in the spark key/value sort. This may be inefficient 153 * if the comparator only uses looks at a small fraction of the element to perform the comparison. 154 */ sortUsingElementsAsKeys(JavaRDD<T> elements, Comparator<T> comparator, int numReducers)155 public static <T> JavaRDD<T> sortUsingElementsAsKeys(JavaRDD<T> elements, Comparator<T> comparator, int numReducers) { 156 Utils.nonNull(comparator); 157 Utils.nonNull(elements); 158 159 // Turn into key-value pairs so we can sort (by key). Values are null so there is no overhead in the amount 160 // of data going through the shuffle. 161 final JavaPairRDD<T, Void> rddReadPairs = elements.mapToPair(read -> new Tuple2<>(read, (Void) null)); 162 163 final JavaPairRDD<T, Void> readVoidPairs; 164 if (numReducers > 0) { 165 readVoidPairs = rddReadPairs.sortByKey(comparator, true, numReducers); 166 } else { 167 readVoidPairs = rddReadPairs.sortByKey(comparator); 168 } 169 return readVoidPairs.keys(); 170 } 171 172 /** 173 * Ensure all reads with the same name appear in the same partition of a queryname sorted RDD. 174 * This avoids a global shuffle and only transfers the leading elements from each partition which is fast in most 175 * cases. 176 * 177 * The RDD must be queryname sorted. If there are so many reads with the same name that they span multiple partitions 178 * this will throw {@link GATKException}. 179 */ putReadsWithTheSameNameInTheSamePartition( final SAMFileHeader header, final JavaRDD<GATKRead> reads, final JavaSparkContext ctx )180 public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition( final SAMFileHeader header, 181 final JavaRDD<GATKRead> reads, 182 final JavaSparkContext ctx ) { 183 Utils.validateArg(ReadUtils.isReadNameGroupedBam(header), () -> "Reads must be queryname grouped or sorted. " + 184 "Actual sort:" + header.getSortOrder() + " Actual grouping:" + header.getGroupOrder()); 185 186 // Find the first group in each partition 187 final List<List<GATKRead>> firstReadNameGroupInEachPartition = reads 188 .mapPartitions(it -> { 189 if ( !it.hasNext() ) { 190 return Iterators.singletonIterator(Collections.<GATKRead>emptyList()); 191 } 192 final List<GATKRead> firstGroup = new ArrayList<>(2); 193 final GATKRead firstRead = it.next(); 194 firstGroup.add(firstRead); 195 final String groupName = firstRead.getName(); 196 while ( it.hasNext() ) { 197 final GATKRead read = it.next(); 198 if ( !groupName.equals(read.getName()) ) { 199 break; 200 } 201 firstGroup.add(read); 202 } 203 return Iterators.singletonIterator(firstGroup); 204 }) 205 .collect(); 206 207 // Shift left, so that each partition will be zipped with the first read group from the _next_ partition 208 final int numPartitions = reads.getNumPartitions(); 209 final List<List<GATKRead>> firstGroupFromNextPartition = 210 new ArrayList<>(firstReadNameGroupInEachPartition.subList(1, numPartitions)); 211 firstGroupFromNextPartition.add(Collections.emptyList()); // the last partition does not have any reads to add to it 212 213 // Take care of the situation where an entire partition contains reads with the same name 214 // (unlikely, but could happen with very long reads, or very small partitions). 215 for ( int idx = numPartitions - 1; idx >= 1; --idx ) { 216 final List<GATKRead> curGroup = firstGroupFromNextPartition.get(idx); 217 if ( !curGroup.isEmpty() ) { 218 final String groupName = curGroup.get(0).getName(); 219 int idx2 = idx; 220 while ( --idx2 >= 0 ) { 221 final List<GATKRead> prevGroup = firstGroupFromNextPartition.get(idx2); 222 if ( !prevGroup.isEmpty() ) { 223 if ( groupName.equals(prevGroup.get(0).getName()) ) { 224 prevGroup.addAll(curGroup); 225 curGroup.clear(); 226 } 227 break; 228 } 229 } 230 } 231 } 232 233 // Peel off the first group in each partition 234 final int[] firstGroupSizes = firstReadNameGroupInEachPartition.stream().mapToInt(List::size).toArray(); 235 firstGroupSizes[0] = 0; // first partition has no predecessor to handle its first group of reads 236 JavaRDD<GATKRead> readsSansFirstGroup = reads.mapPartitionsWithIndex( (idx, itr) -> 237 { int groupSize = firstGroupSizes[idx]; 238 while ( itr.hasNext() && groupSize-- > 0 ) { 239 itr.next(); 240 } 241 return itr; }, true); 242 243 // Zip up the remaining reads with the first read group from the _next_ partition 244 return readsSansFirstGroup.zipPartitions(ctx.parallelize(firstGroupFromNextPartition, numPartitions), 245 (it1, it2) -> Iterators.concat(it1, it2.next().iterator())); 246 } 247 248 /** 249 * Like <code>groupByKey</code>, but assumes that values are already sorted by key, so no shuffle is needed, 250 * which is much faster. 251 * @param rdd the input RDD 252 * @param <K> type of keys 253 * @param <V> type of values 254 * @return an RDD where each the values for each key are grouped into an iterable collection 255 */ spanByKey(JavaPairRDD<K, V> rdd)256 public static <K, V> JavaPairRDD<K, Iterable<V>> spanByKey(JavaPairRDD<K, V> rdd) { 257 return rdd.mapPartitionsToPair(SparkUtils::getSpanningIterator); 258 } 259 260 /** 261 * An iterator that groups values having the same key into iterable collections. 262 * @param iterator an iterator over key-value pairs 263 * @param <K> type of keys 264 * @param <V> type of values 265 * @return an iterator over pairs of keys and grouped values 266 */ getSpanningIterator(Iterator<Tuple2<K, V>> iterator)267 public static <K, V> Iterator<Tuple2<K, Iterable<V>>> getSpanningIterator(Iterator<Tuple2<K, V>> iterator) { 268 final PeekingIterator<Tuple2<K, V>> iter = Iterators.peekingIterator(iterator); 269 return new AbstractIterator<Tuple2<K, Iterable<V>>>() { 270 @Override 271 protected Tuple2<K, Iterable<V>> computeNext() { 272 K key = null; 273 List<V> group = Lists.newArrayList(); 274 while (iter.hasNext()) { 275 if (key == null) { 276 Tuple2<K, V> next = iter.next(); 277 key = next._1(); 278 V value = next._2(); 279 group.add(value); 280 continue; 281 } 282 K nextKey = iter.peek()._1(); // don't advance... 283 if (nextKey.equals(key)) { 284 group.add(iter.next()._2()); // .. unless the keys match 285 } else { 286 return new Tuple2<>(key, group); 287 } 288 } 289 if (key != null) { 290 return new Tuple2<>(key, group); 291 } 292 return endOfData(); 293 } 294 }; 295 } 296 297 /** 298 * Sort reads into queryname order if they are not already sorted 299 */ 300 public static JavaRDD<GATKRead> querynameSortReadsIfNecessary(JavaRDD<GATKRead> reads, int numReducers, SAMFileHeader header) { 301 JavaRDD<GATKRead> sortedReadsForMarking; 302 if (ReadUtils.isReadNameGroupedBam(header)) { 303 sortedReadsForMarking = reads; 304 } else { 305 header.setSortOrder(SAMFileHeader.SortOrder.queryname); 306 sortedReadsForMarking = sortReadsAccordingToHeader(reads, header, numReducers); 307 } 308 return sortedReadsForMarking; 309 } 310 } 311