/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.Serializer

private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
  override val index: Int = idx

  override def hashCode(): Int = index

  override def equals(other: Any): Boolean = super.equals(other)
}

/**
 * :: DeveloperApi ::
 * The resulting RDD from a shuffle (e.g. repartitioning of data).
 * @param prev the parent RDD.
 * @param part the partitioner used to partition the RDD
 * @tparam K the key class.
 * @tparam V the value class.
 * @tparam C the combiner class.
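 *
 * Instances are normally created internally by shuffle operations such as
 * `reduceByKey`, `sortByKey`, or `repartition` rather than constructed by hand.
 * A minimal sketch of direct construction (the pair RDD `pairs` is an assumed
 * example, not part of this API):
 * {{{
 *   val pairs: RDD[(Int, Int)] = sc.parallelize(1 to 100).map(x => (x % 4, x))
 *   val shuffled = new ShuffledRDD[Int, Int, Int](pairs, new HashPartitioner(4))
 *     .setMapSideCombine(false)
 *   shuffled.collect()  // same pairs, regrouped into 4 hash partitions
 * }}}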
 */
// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs
@DeveloperApi
class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
    @transient var prev: RDD[_ <: Product2[K, V]],
    part: Partitioner)
  extends RDD[(K, C)](prev.context, Nil) {

  private var userSpecifiedSerializer: Option[Serializer] = None

  private var keyOrdering: Option[Ordering[K]] = None

  private var aggregator: Option[Aggregator[K, V, C]] = None

  private var mapSideCombine: Boolean = false

  /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer). */
  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
    this.userSpecifiedSerializer = Option(serializer)
    this
  }

  /** Set key ordering for RDD's shuffle. */
  def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = {
    this.keyOrdering = Option(keyOrdering)
    this
  }

  /** Set aggregator for RDD's shuffle. */
  def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = {
    this.aggregator = Option(aggregator)
    this
  }

  /** Set mapSideCombine flag for RDD's shuffle. */
  def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = {
    this.mapSideCombine = mapSideCombine
    this
  }

  override def getDependencies: Seq[Dependency[_]] = {
    val serializer = userSpecifiedSerializer.getOrElse {
      val serializerManager = SparkEnv.get.serializerManager
      if (mapSideCombine) {
        serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
      } else {
        serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
      }
    }
    List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
  }

  override val partitioner = Some(part)

  override def getPartitions: Array[Partition] = {
    Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
  }

  override protected def getPreferredLocations(partition: Partition): Seq[String] = {
    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    tracker.getPreferredLocationsForShuffle(dep, partition.index)
  }

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    prev = null
  }
}
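
// A minimal sketch of how an upstream operator (e.g. combineByKeyWithClassTag)
// wires an Aggregator into this RDD for map-side combining; `sc`, `pairs`, and
// the choice of HashPartitioner(2) are assumed for illustration only:
//
//   val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
//   val agg = new Aggregator[String, Int, Int](v => v, _ + _, _ + _)
//   val sums = new ShuffledRDD[String, Int, Int](pairs, new HashPartitioner(2))
//     .setAggregator(agg)
//     .setMapSideCombine(true)  // pre-combine values per key before the shuffle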