/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

/**
 * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
 * requesting them from other nodes' block stores.
 */
private[spark] class BlockStoreShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    serializerManager: SerializerManager = SparkEnv.get.serializerManager,
    blockManager: BlockManager = SparkEnv.get.blockManager,
    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
  extends ShuffleReader[K, C] with Logging {

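  // The ShuffleDependency for this shuffle; it carries the serializer, the optional
  // aggregator and the optional key ordering used when reading the map output.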
  private val dep = handle.dependency

  /** Read the combined key-values for this reduce task */
  override def read(): Iterator[Product2[K, C]] = {
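    // Fetch the blocks for partitions [startPartition, endPartition) from the local and
    // remote block managers, capping the bytes and the number of requests in flight.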
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

    // Wrap the streams for compression and encryption based on configuration
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }

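    // Serializer instance used to deserialize the fetched, wrapped streams.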
    val serializerInstance = dep.serializer.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

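    // Aggregate the records if the shuffle dependency defines an aggregator. With map-side
    // combine the fetched values are already partial combiners of type C and only need to be
    // merged; otherwise the raw values are combined into type C here on the reduce side.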
    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure it's compatible with this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }
}