/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

/**
 * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
 * requesting them from other nodes' block stores.
 */
private[spark] class BlockStoreShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    serializerManager: SerializerManager = SparkEnv.get.serializerManager,
    blockManager: BlockManager = SparkEnv.get.blockManager,
    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
  extends ShuffleReader[K, C] with Logging {

  private val dep = handle.dependency

  /** Read the combined key-values for this reduce task */
  override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

    // Wrap the streams for compression and encryption based on configuration
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }

    val serializerInstance = dep.serializer.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure it's compatible with this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }
}