/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

private[spark] class ZippedPartitionsPartition(
    idx: Int,
    @transient private val rdds: Seq[RDD[_]],
    @transient val preferredLocations: Seq[String])
  extends Partition {

  override val index: Int = idx
  var partitionValues = rdds.map(rdd => rdd.partitions(idx))
  def partitions: Seq[Partition] = partitionValues

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    partitionValues = rdds.map(rdd => rdd.partitions(idx))
    oos.defaultWriteObject()
  }
}

private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[_]],
    preservesPartitioning: Boolean = false)
  extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {

  override val partitioner =
    if (preservesPartitioning) firstParent[Any].partitioner else None

  override def getPartitions: Array[Partition] = {
    val numParts = rdds.head.partitions.length
    if (!rdds.forall(rdd => rdd.partitions.length == numParts)) {
      throw new IllegalArgumentException(
        s"Can't zip RDDs with unequal numbers of partitions: ${rdds.map(_.partitions.length)}")
    }
    Array.tabulate[Partition](numParts) { i =>
      val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i)))
      // Check whether there are any hosts that match all RDDs; otherwise return the union
      val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
      val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct
      new ZippedPartitionsPartition(i, rdds, locs)
    }
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    s.asInstanceOf[ZippedPartitionsPartition].preferredLocations
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }
}

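/**
 * Zips together the corresponding partitions of two parent RDDs and applies a user-supplied
 * function to the pair of partition iterators. The three- and four-ary variants below follow
 * the same pattern. These are the RDDs typically constructed by the public `RDD.zipPartitions`
 * overloads; all parents must have the same number of partitions.
 */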
private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
    sc: SparkContext,
    var f: (Iterator[A], Iterator[B]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {

  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
    f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    f = null
  }
}

private[spark] class ZippedPartitionsRDD3
  [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
    sc: SparkContext,
    var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B],
    var rdd3: RDD[C],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {

  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
    f(rdd1.iterator(partitions(0), context),
      rdd2.iterator(partitions(1), context),
      rdd3.iterator(partitions(2), context))
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    rdd3 = null
    f = null
  }
}

private[spark] class ZippedPartitionsRDD4
  [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag](
    sc: SparkContext,
    var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B],
    var rdd3: RDD[C],
    var rdd4: RDD[D],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {

  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
    f(rdd1.iterator(partitions(0), context),
      rdd2.iterator(partitions(1), context),
      rdd3.iterator(partitions(2), context),
      rdd4.iterator(partitions(3), context))
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    rdd3 = null
    rdd4 = null
    f = null
  }
}
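
// Illustrative usage sketch (assumes the standard `RDD.zipPartitions` API, which constructs a
// ZippedPartitionsRDD2 behind the scenes): given two RDDs with the same number of partitions,
// co-located elements can be combined without a shuffle, since the dependency is one-to-one:
//
//   val a = sc.parallelize(1 to 100, numSlices = 4)
//   val b = sc.parallelize(101 to 200, numSlices = 4)
//   val sums = a.zipPartitions(b) { (ia, ib) =>
//     ia.zip(ib).map { case (x, y) => x + y }
//   }
//
// If the partition counts differ, getPartitions above throws IllegalArgumentException.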