/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils
/**
 * Class representing a partition of PartitionerAwareUnionRDD, which maintains the list of
 * corresponding partitions of the parent RDDs.
 */
private[spark]
class PartitionerAwareUnionRDDPartition(
    @transient val rdds: Seq[RDD[_]],
    override val index: Int
  ) extends Partition {
  var parents = rdds.map(_.partitions(index)).toArray

  // Partitions are compared by reference equality; the hash code is simply the partition index
  override def hashCode(): Int = index

  override def equals(other: Any): Boolean = super.equals(other)

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the references to the parent partitions at the time of task serialization
    parents = rdds.map(_.partitions(index)).toArray
    oos.defaultWriteObject()
  }
}

/**
 * Class representing an RDD that can take multiple RDDs partitioned by the same partitioner and
 * unify them into a single RDD while preserving the partitioner. So m RDDs with p partitions each
 * will be unified to a single RDD with p partitions and the same partitioner. The preferred
 * location for each partition of the unified RDD will be the most common preferred location
 * of the corresponding partitions of the parent RDDs. For example, the preferred location of
 * partition 0 of the unified RDD will be where the parent RDDs' partition 0 is most commonly
 * located.
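 *
 * A minimal usage sketch (illustrative only; assumes `sc` is an in-scope SparkContext and the
 * variable names are not part of any API):
 * {{{
 *   val part = new org.apache.spark.HashPartitioner(4)
 *   val rdd1 = sc.parallelize(Seq((1, "a"), (2, "b"))).partitionBy(part)
 *   val rdd2 = sc.parallelize(Seq((1, "c"), (3, "d"))).partitionBy(part)
 *   // Unlike a plain union (which would have 8 partitions and no partitioner),
 *   // this keeps 4 partitions and preserves `part`.
 *   val unioned = new PartitionerAwareUnionRDD(sc, Seq(rdd1, rdd2))
 * }}}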
 */
private[spark]
class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
  require(rdds.nonEmpty)
  require(rdds.forall(_.partitioner.isDefined))
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  override val partitioner = rdds.head.partitioner

  override def getPartitions: Array[Partition] = {
    // One unified partition per partition of the shared partitioner
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map { index =>
      new PartitionerAwareUnionRDDPartition(rdds, index)
    }.toArray
  }

  // Get the location where most of the partitions of the parent RDDs are located
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    val locations = rdds.zip(parentPartitions).flatMap {
      case (rdd, part) =>
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
        parentLocations
    }
    val location = if (locations.isEmpty) {
      None
    } else {
      // Find the location that the maximum number of parent partitions prefer
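      // e.g. Seq("hostA", "hostB", "hostA") groups to hostA -> 2, hostB -> 1 and picks hostA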
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    }
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)
    location.toSeq
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
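    // Lazily concatenate, for each parent RDD, the iterator over its corresponding partition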
    rdds.zip(parentPartitions).iterator.flatMap {
      case (rdd, p) => rdd.iterator(p, context)
    }
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl => tl.host)
  }
}