/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

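/**
 * A [[Partition]] used by [[ZippedPartitionsBaseRDD]]. For a given partition index it captures
 * the corresponding partition of every parent RDD, together with the preferred locations
 * computed for that index. The parent partition references are refreshed when the partition is
 * serialized for a task, so the most up-to-date references (e.g. after a parent RDD has been
 * checkpointed) are the ones shipped to executors.
 */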
private[spark] class ZippedPartitionsPartition(
    idx: Int,
    @transient private val rdds: Seq[RDD[_]],
    @transient val preferredLocations: Seq[String])
  extends Partition {

  override val index: Int = idx
  var partitionValues = rdds.map(rdd => rdd.partitions(idx))
  def partitions: Seq[Partition] = partitionValues

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the references to the parent partitions at the time of task serialization
    partitionValues = rdds.map(rdd => rdd.partitions(idx))
    oos.defaultWriteObject()
  }
}

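/**
 * Base class for the RDDs returned by the `zipPartitions` family of operators. All parent RDDs
 * must have the same number of partitions; partition i of this RDD is computed from partition i
 * of every parent. The preferred locations for a partition are the hosts common to all parents
 * for that index, falling back to the union of the parents' preferred locations when no common
 * host exists.
 */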
private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[_]],
    preservesPartitioning: Boolean = false)
  extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {

  override val partitioner =
    if (preservesPartitioning) firstParent[Any].partitioner else None

  override def getPartitions: Array[Partition] = {
    val numParts = rdds.head.partitions.length
    if (!rdds.forall(rdd => rdd.partitions.length == numParts)) {
      throw new IllegalArgumentException(
        s"Can't zip RDDs with unequal numbers of partitions: ${rdds.map(_.partitions.length)}")
    }
    Array.tabulate[Partition](numParts) { i =>
      val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i)))
      // Check whether there are any hosts that match all RDDs; otherwise return the union
      val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
      val locs = if (exactMatchLocations.nonEmpty) exactMatchLocations else prefs.flatten.distinct
      new ZippedPartitionsPartition(i, rdds, locs)
    }
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    s.asInstanceOf[ZippedPartitionsPartition].preferredLocations
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }
}

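/**
 * Zips the corresponding partitions of two parent RDDs and applies a user-supplied function to
 * the pair of partition iterators. This is the RDD behind the two-RDD `zipPartitions` operator.
 * A minimal usage sketch, assuming `rdd1` and `rdd2` are `RDD[Int]`s with the same number of
 * partitions:
 * {{{
 *   val sums = rdd1.zipPartitions(rdd2) { (it1, it2) =>
 *     it1.zip(it2).map { case (a, b) => a + b }
 *   }
 * }}}
 */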
private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
    sc: SparkContext,
    var f: (Iterator[A], Iterator[B]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {

  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
    f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    f = null
  }
}

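/**
 * Three-RDD variant of [[ZippedPartitionsRDD2]]: zips the corresponding partitions of three
 * parent RDDs through a user-supplied function over their partition iterators.
 */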
private[spark] class ZippedPartitionsRDD3
  [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
    sc: SparkContext,
    var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B],
    var rdd3: RDD[C],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {

  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
    f(rdd1.iterator(partitions(0), context),
      rdd2.iterator(partitions(1), context),
      rdd3.iterator(partitions(2), context))
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    rdd3 = null
    f = null
  }
}

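/**
 * Four-RDD variant of [[ZippedPartitionsRDD2]]: zips the corresponding partitions of four
 * parent RDDs through a user-supplied function over their partition iterators.
 */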
private[spark] class ZippedPartitionsRDD4
  [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag](
    sc: SparkContext,
    var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B],
    var rdd3: RDD[C],
    var rdd4: RDD[D],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {

  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
    f(rdd1.iterator(partitions(0), context),
      rdd2.iterator(partitions(1), context),
      rdd3.iterator(partitions(2), context),
      rdd4.iterator(partitions(3), context))
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    rdd3 = null
    rdd4 = null
    f = null
  }
}