/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

/**
 * Class representing a partition of PartitionerAwareUnionRDD, which maintains the list of
 * corresponding partitions of the parent RDDs.
 */
private[spark]
class PartitionerAwareUnionRDDPartition(
    @transient val rdds: Seq[RDD[_]],
    override val index: Int
  ) extends Partition {
  var parents = rdds.map(_.partitions(index)).toArray

  override def hashCode(): Int = index

  override def equals(other: Any): Boolean = super.equals(other)

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the references to the parent partitions at the time of task serialization
    parents = rdds.map(_.partitions(index)).toArray
    oos.defaultWriteObject()
  }
}

/**
 * Class representing an RDD that can take multiple RDDs partitioned by the same partitioner and
 * unify them into a single RDD while preserving the partitioner. So m RDDs with p partitions each
 * will be unified to a single RDD with p partitions and the same partitioner. The preferred
 * location for each partition of the unified RDD will be the most common preferred location
 * of the corresponding partitions of the parent RDDs. For example, the location of partition 0
 * of the unified RDD will be wherever most of the parent RDDs' partition 0 are located.
 */
private[spark]
class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
  require(rdds.nonEmpty)
  require(rdds.forall(_.partitioner.isDefined))
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  override val partitioner = rdds.head.partitioner

  override def getPartitions: Array[Partition] = {
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map { index =>
      new PartitionerAwareUnionRDDPartition(rdds, index)
    }.toArray
  }

  // Get the location where most of the corresponding partitions of the parent RDDs are located
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    val locations = rdds.zip(parentPartitions).flatMap {
      case (rdd, part) =>
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
        parentLocations
    }
    val location = if (locations.isEmpty) {
      None
    } else {
      // Find the location that the maximum number of parent partitions prefer
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    }
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)
    location.toSeq
  }

  // Compute a partition by concatenating the corresponding partition of each parent RDD, in order
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    rdds.zip(parentPartitions).iterator.flatMap {
      case (rdd, p) => rdd.iterator(p, context)
    }
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl => tl.host)
  }
}
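
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of the original file): since the class
// above is private[spark], user code reaches it indirectly. In Spark versions
// where SparkContext.union performs this dispatch, the union of RDDs that all
// declare the same partitioner is built as a PartitionerAwareUnionRDD rather
// than a plain UnionRDD. The object name and local-mode setup below are
// assumptions made only for this example.
private[spark] object PartitionerAwareUnionRDDExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "partitioner-aware-union-example")
    try {
      // Two parent RDDs co-partitioned by the same HashPartitioner (p = 4).
      val part = new org.apache.spark.HashPartitioner(4)
      val a = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(part)
      val b = sc.parallelize(Seq(1 -> "c", 3 -> "d")).partitionBy(part)

      // The unified RDD keeps the shared partitioner and has p partitions (4),
      // not the 2 * p (8) that a plain UnionRDD would produce, so a downstream
      // reduceByKey on the result needs no extra shuffle.
      val unioned = sc.union(Seq(a, b))
      assert(unioned.partitioner == Some(part))
      assert(unioned.partitions.length == 4)
    } finally {
      sc.stop()
    }
  }
}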