/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.io.File

import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path

import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils

trait RDDCheckpointTester { self: SparkFunSuite =>

  protected val partitioner = new HashPartitioner(2)

  private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()

  /** Implementations of this trait must implement this method */
  protected def sparkContext: SparkContext

  /**
   * Test checkpointing of the RDD generated by the given operation. It tests whether the
   * serialized size of the RDD is reduced after checkpointing. This function should be called
   * on all RDDs that have a parent RDD (i.e., do not call it on ParallelCollection, BlockRDD,
   * etc.).
   *
   * @param op an operation to run on the RDD
   * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
   * @param collectFunc a function for collecting the values in the RDD, in case there are
   *                    non-comparable types like arrays that we want to convert to something
   *                    that supports ==
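   *
   * For example, a concrete suite can exercise a one-to-one dependency like this
   * (an illustrative sketch mirroring the calls made in CheckpointSuite below):
   * {{{
   *   testRDD(_.map(x => x.toString), reliableCheckpoint = true)
   * }}}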
   */
  protected def testRDD[U: ClassTag](
      op: (RDD[Int]) => RDD[U],
      reliableCheckpoint: Boolean,
      collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
    // Generate the final RDD using given RDD operation
    val baseRDD = generateFatRDD()
    val operatedRDD = op(baseRDD)
    val parentDependency = operatedRDD.dependencies.headOption.orNull
    val rddType = operatedRDD.getClass.getSimpleName
    val numPartitions = operatedRDD.partitions.length

    // Force initialization of all the data structures in RDDs.
    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD.
    initializeRdd(operatedRDD)

    val partitionsBeforeCheckpoint = operatedRDD.partitions

    // Find serialized sizes before and after the checkpoint
    logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
    checkpoint(operatedRDD, reliableCheckpoint)
    val result = collectFunc(operatedRDD)
    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)

    // Test whether the checkpoint file has been created
    if (reliableCheckpoint) {
      assert(operatedRDD.getCheckpointFile.nonEmpty)
      val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)
      assert(collectFunc(recoveredRDD) === result)
      assert(recoveredRDD.partitioner === operatedRDD.partitioner)
    }

    // Test whether the dependencies have been changed from the earlier parent dependency
    assert(operatedRDD.dependencies.head != parentDependency)

    // Test whether the partitions have been changed from the earlier partitions
    assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList)

    // Test whether the partitions have been changed to the new Hadoop partitions
    assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)

    // Test whether the number of partitions is same as before
    assert(operatedRDD.partitions.length === numPartitions)

    // Test whether the data in the checkpointed RDD is same as original
    assert(collectFunc(operatedRDD) === result)

    // Test whether serialized size of the RDD has reduced
    logInfo("Size of " + rddType +
      " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
    assert(
      rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
      "Size of " + rddType + " did not reduce after checkpointing " +
        " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
    )
  }

  /**
   * Test whether checkpointing the parents of the generated RDD also truncates the lineage.
   * Some RDDs, like CoGroupedRDD, hold on to their parent RDDs' partitions. So even if the
   * parent RDD is checkpointed and its partitions changed, the generated RDD will remember
   * the partitions and therefore potentially the whole lineage. This function should be
   * called only on those RDDs whose partitions refer to the parent RDDs' partitions
   * (i.e., do not call it on a simple RDD like MappedRDD).
   *
   * @param op an operation to run on the RDD
   * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
   * @param collectFunc a function for collecting the values in the RDD, in case there are
   *                    non-comparable types like arrays that we want to convert to something
   *                    that supports ==
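   *
   * For example, a suite can exercise an RDD whose partitions capture parent partitions
   * (an illustrative sketch mirroring the UnionRDD test in CheckpointSuite below):
   * {{{
   *   testRDDPartitions(_.union(otherRDD), reliableCheckpoint = true)
   * }}}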
   */
  protected def testRDDPartitions[U: ClassTag](
      op: (RDD[Int]) => RDD[U],
      reliableCheckpoint: Boolean,
      collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
    // Generate the final RDD using given RDD operation
    val baseRDD = generateFatRDD()
    val operatedRDD = op(baseRDD)
    val parentRDDs = operatedRDD.dependencies.map(_.rdd)
    val rddType = operatedRDD.getClass.getSimpleName

    // Force initialization of all the data structures in RDDs.
    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD.
    initializeRdd(operatedRDD)

    // Find serialized sizes before and after the checkpoint
    logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
    // Checkpoint the parent RDDs, not the generated one
    parentRDDs.foreach { rdd =>
      checkpoint(rdd, reliableCheckpoint)
    }
    val result = collectFunc(operatedRDD) // force checkpointing
    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)

    // Test whether the data in the checkpointed RDD is same as original
    assert(collectFunc(operatedRDD) === result)

    // Test whether serialized size of the partitions has reduced
    logInfo("Size of partitions of " + rddType +
      " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]")
    assert(
      partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint,
      "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" +
        " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]"
    )
  }

  /**
   * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
   * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
   */
  private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
    val rddSize = Utils.serialize(rdd).size
    val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
    val rddPartitionSize = Utils.serialize(rdd.partitions).size
    val rddDependenciesSize = Utils.serialize(rdd.dependencies).size

    // Print detailed sizes; helps in debugging
    logInfo("Serialized sizes of " + rdd +
      ": RDD = " + rddSize +
      ", RDD checkpoint data = " + rddCpDataSize +
      ", RDD partitions = " + rddPartitionSize +
      ", RDD dependencies = " + rddDependenciesSize
    )
    // This makes sure that serializing the RDD's checkpoint data does not
    // serialize the whole RDD as well
    assert(
      rddSize > rddCpDataSize,
      "RDD's checkpoint data (" + rddCpDataSize + ") is equal to or larger than the " +
        "whole RDD with checkpoint data (" + rddSize + ")"
    )
    (rddSize - rddCpDataSize, rddPartitionSize)
  }

  /**
   * Serialize and deserialize an object. This is useful for verifying an object's
   * contents after deserialization (e.g., the contents of an RDD split after
   * it is sent to an executor along with a task).
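   *
   * For example (an illustrative round trip):
   * {{{
   *   val copy = serializeDeserialize(Seq(1, 2, 3))
   *   assert(copy === Seq(1, 2, 3))
   * }}}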
   */
  protected def serializeDeserialize[T](obj: T): T = {
    val bytes = Utils.serialize(obj)
    Utils.deserialize[T](bytes)
  }

  /**
   * Recursively force the initialization of all the members of an RDD and its parents.
   */
  private def initializeRdd(rdd: RDD[_]): Unit = {
    rdd.partitions // forces the initialization of the partitions
    rdd.dependencies.map(_.rdd).foreach(initializeRdd)
  }

  /** Checkpoint the RDD either locally or reliably. */
  protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = {
    if (reliableCheckpoint) {
      rdd.checkpoint()
    } else {
      rdd.localCheckpoint()
    }
  }

  /**
   * Run a test twice, once for local checkpointing and once for reliable checkpointing.
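   *
   * For example, the following registers two tests, "foo [reliable checkpoint]" and
   * "foo [local checkpoint]" (an illustrative sketch):
   * {{{
   *   runTest("foo") { reliableCheckpoint: Boolean =>
   *     // test body, parameterized by the checkpoint mode
   *   }
   * }}}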
   */
  protected def runTest(
      name: String,
      skipLocalCheckpoint: Boolean = false
    )(body: Boolean => Unit): Unit = {
    test(name + " [reliable checkpoint]")(body(true))
    if (!skipLocalCheckpoint) {
      test(name + " [local checkpoint]")(body(false))
    }
  }

  /**
   * Generate an RDD such that both the RDD and its partitions have large size.
   */
  protected def generateFatRDD(): RDD[Int] = {
    new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x)
  }

  /**
   * Generate a pair RDD (with a partitioner) such that both the RDD and its partitions
   * have large size.
   */
  protected def generateFatPairRDD(): RDD[(Int, Int)] = {
    new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
  }
}

/**
 * Test suite for end-to-end checkpointing functionality.
 * This tests both reliable checkpoints and local checkpoints.
 */
class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext {
  private var checkpointDir: File = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
    checkpointDir.delete()
    sc = new SparkContext("local", "test")
    sc.setCheckpointDir(checkpointDir.toString)
  }

  override def afterEach(): Unit = {
    try {
      Utils.deleteRecursively(checkpointDir)
    } finally {
      super.afterEach()
    }
  }

  override def sparkContext: SparkContext = sc

  runTest("basic checkpointing") { reliableCheckpoint: Boolean =>
    val parCollection = sc.makeRDD(1 to 4)
    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
    checkpoint(flatMappedRDD, reliableCheckpoint)
    assert(flatMappedRDD.dependencies.head.rdd === parCollection)
    val result = flatMappedRDD.collect()
    assert(flatMappedRDD.dependencies.head.rdd != parCollection)
    assert(flatMappedRDD.collect() === result)
  }

  runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean =>

    def testPartitionerCheckpointing(
        partitioner: Partitioner,
        corruptPartitionerFile: Boolean = false): Unit = {
      val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner)
      rddWithPartitioner.checkpoint()
      rddWithPartitioner.count()
      assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty,
        "checkpointing was not successful")

      if (corruptPartitionerFile) {
        // Overwrite the partitioner file with garbage data
        val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get)
        val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration)
        val partitionerFile = fs.listStatus(checkpointDir)
          .find(_.getPath.getName.contains("partitioner"))
          .map(_.getPath)
        require(partitionerFile.nonEmpty, "could not find the partitioner file for testing")
        val output = fs.create(partitionerFile.get, true)
        output.write(100)
        output.close()
      }

      val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get)
      assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered")

      if (!corruptPartitionerFile) {
        assert(newRDD.partitioner != None, "partitioner not recovered")
        assert(newRDD.partitioner === rddWithPartitioner.partitioner,
          "recovered partitioner does not match")
      } else {
        assert(newRDD.partitioner == None, "partitioner unexpectedly recovered")
      }
    }

    testPartitionerCheckpointing(partitioner)

    // Test that a corrupted partitioner file does not prevent recovery of the RDD
    testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true)
  }

  runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
    testRDD(_.map(x => x.toString), reliableCheckpoint)
    testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)
    testRDD(_.filter(_ % 2 == 0), reliableCheckpoint)
    testRDD(_.sample(false, 0.5, 0), reliableCheckpoint)
    testRDD(_.glom(), reliableCheckpoint)
    testRDD(_.mapPartitions(_.map(_.toString)), reliableCheckpoint)
    testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), reliableCheckpoint)
    testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x),
      reliableCheckpoint)
    testRDD(_.pipe(Seq("cat")), reliableCheckpoint)
  }
  runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean =>
    val parCollection = sc.makeRDD(1 to 4, 2)
    val numPartitions = parCollection.partitions.size
    checkpoint(parCollection, reliableCheckpoint)
    assert(parCollection.dependencies === Nil)
    val result = parCollection.collect()
    if (reliableCheckpoint) {
      assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
    }
    assert(parCollection.dependencies != Nil)
    assert(parCollection.partitions.length === numPartitions)
    assert(parCollection.partitions.toList ===
      parCollection.checkpointData.get.getPartitions.toList)
    assert(parCollection.collect() === result)
  }

  runTest("BlockRDD") { reliableCheckpoint: Boolean =>
    val blockId = TestBlockId("id")
    val blockManager = SparkEnv.get.blockManager
    blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
    val blockRDD = new BlockRDD[String](sc, Array(blockId))
    val numPartitions = blockRDD.partitions.size
    checkpoint(blockRDD, reliableCheckpoint)
    val result = blockRDD.collect()
    if (reliableCheckpoint) {
      assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
    }
    assert(blockRDD.dependencies != Nil)
    assert(blockRDD.partitions.length === numPartitions)
    assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList)
    assert(blockRDD.collect() === result)
  }

  runTest("ShuffleRDD") { reliableCheckpoint: Boolean =>
    testRDD(rdd => {
      // Create a ShuffledRDD directly, as PairRDDFunctions.combineByKey produces a
      // MapPartitionsRDD on top of it
      new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
    }, reliableCheckpoint)
  }

  runTest("UnionRDD") { reliableCheckpoint: Boolean =>
    def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
    testRDD(_.union(otherRDD), reliableCheckpoint)
    testRDDPartitions(_.union(otherRDD), reliableCheckpoint)
  }

  runTest("CartesianRDD") { reliableCheckpoint: Boolean =>
    def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
    testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
    testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)

    // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
    // the parent RDD has been checkpointed and parent partitions have been changed.
    // Note that this test is very specific to the current implementation of CartesianRDD.
    val ones = sc.makeRDD(1 to 100, 10).map(x => x)
    checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD
    val cartesian = new CartesianRDD(sc, ones, ones)
    val splitBeforeCheckpoint =
      serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
    cartesian.count() // do the checkpointing
    val splitAfterCheckpoint =
      serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
    assert(
      (splitAfterCheckpoint.s1.getClass != splitBeforeCheckpoint.s1.getClass) &&
        (splitAfterCheckpoint.s2.getClass != splitBeforeCheckpoint.s2.getClass),
      "CartesianRDD.s1 and CartesianRDD.s2 not updated after parent RDD is checkpointed"
    )
  }

  runTest("CoalescedRDD") { reliableCheckpoint: Boolean =>
    testRDD(_.coalesce(2), reliableCheckpoint)
    testRDDPartitions(_.coalesce(2), reliableCheckpoint)

    // Test that the CoalescedRDDPartition updates parent partitions
    // (CoalescedRDDPartition.parents) after the parent RDD has been checkpointed and parent
    // partitions have been changed.
    // Note that this test is very specific to the current implementation of
    // CoalescedRDDPartition.
    val ones = sc.makeRDD(1 to 100, 10).map(x => x)
    checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD
    val coalesced = new CoalescedRDD(ones, 2)
    val splitBeforeCheckpoint =
      serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
    coalesced.count() // do the checkpointing
    val splitAfterCheckpoint =
      serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
    assert(
      splitAfterCheckpoint.parents.head.getClass != splitBeforeCheckpoint.parents.head.getClass,
      "CoalescedRDDPartition.parents not updated after parent RDD is checkpointed"
    )
  }

  runTest("CoGroupedRDD") { reliableCheckpoint: Boolean =>
    val longLineageRDD1 = generateFatPairRDD()

    // Collect the RDD as sequences instead of arrays to enable equality tests in testRDD
    val seqCollectFunc = (rdd: RDD[(Int, Array[Iterable[Int]])]) =>
      rdd.map { case (p, a) => (p, a.toSeq) }.collect(): Any

    testRDD(rdd => {
      CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
    }, reliableCheckpoint, seqCollectFunc)

    val longLineageRDD2 = generateFatPairRDD()
    testRDDPartitions(rdd => {
      CheckpointSuite.cogroup(
        longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
    }, reliableCheckpoint, seqCollectFunc)
  }

  runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean =>
    testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
    testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)

    // Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have
    // been checkpointed and parent partitions have been changed.
    // Note that this test is very specific to the implementation of ZippedPartitionsRDD.
    val rdd = generateFatRDD()
    val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]]
    checkpoint(zippedRDD.rdd1, reliableCheckpoint)
    checkpoint(zippedRDD.rdd2, reliableCheckpoint)
    val partitionBeforeCheckpoint =
      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
    zippedRDD.count()
    val partitionAfterCheckpoint =
      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
    assert(
      partitionAfterCheckpoint.partitions(0).getClass !=
        partitionBeforeCheckpoint.partitions(0).getClass &&
        partitionAfterCheckpoint.partitions(1).getClass !=
          partitionBeforeCheckpoint.partitions(1).getClass,
      "ZippedPartitionsRDD partition 0 (or 1) not updated after parent RDDs are checkpointed"
    )
  }

  runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean =>
    testRDD(rdd => {
      new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
        generateFatPairRDD(),
        rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
      ))
    }, reliableCheckpoint)

    testRDDPartitions(rdd => {
      new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
        generateFatPairRDD(),
        rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
      ))
    }, reliableCheckpoint)

    // Test that the PartitionerAwareUnionRDD updates parent partitions
    // (PartitionerAwareUnionRDD.parents) after the parent RDD has been checkpointed and parent
    // partitions have been changed. Note that this test is very specific to the current
    // implementation of PartitionerAwareUnionRDD.
    val pairRDD = generateFatPairRDD()
    checkpoint(pairRDD, reliableCheckpoint)
    val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD))
    val partitionBeforeCheckpoint = serializeDeserialize(
      unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
    pairRDD.count()
    val partitionAfterCheckpoint = serializeDeserialize(
      unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
    assert(
      partitionBeforeCheckpoint.parents.head.getClass !=
        partitionAfterCheckpoint.parents.head.getClass,
      "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed"
    )
  }

  runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean =>
    val rdd = new BlockRDD[Int](sc, Array.empty[BlockId])
    assert(rdd.partitions.size === 0)
    assert(rdd.isCheckpointed === false)
    assert(rdd.isCheckpointedAndMaterialized === false)
    checkpoint(rdd, reliableCheckpoint)
    assert(rdd.isCheckpointed === false)
    assert(rdd.isCheckpointedAndMaterialized === false)
    assert(rdd.count() === 0)
    assert(rdd.isCheckpointed === true)
    assert(rdd.isCheckpointedAndMaterialized === true)
    assert(rdd.partitions.size === 0)
  }

  runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean =>
    testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true)
    testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false)
  }

  private def testCheckpointAllMarkedAncestors(
      reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = {
    sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString)
    try {
      val rdd1 = sc.parallelize(1 to 10)
      checkpoint(rdd1, reliableCheckpoint)
      val rdd2 = rdd1.map(_ + 1)
      checkpoint(rdd2, reliableCheckpoint)
      rdd2.count()
      assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors)
      assert(rdd2.isCheckpointed === true)
    } finally {
      sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null)
    }
  }
}

/** RDD partition that has large serialized size. */
class FatPartition(val partition: Partition) extends Partition {
  val bigData = new Array[Byte](10000)
  def index: Int = partition.index
}

/** RDD that has large serialized size. */
class FatRDD(parent: RDD[Int]) extends RDD[Int](parent) {
  val bigData = new Array[Byte](100000)

  protected def getPartitions: Array[Partition] = {
    parent.partitions.map(p => new FatPartition(p))
  }

  def compute(split: Partition, context: TaskContext): Iterator[Int] = {
    parent.compute(split.asInstanceOf[FatPartition].partition, context)
  }
}

/** Pair RDD that has large serialized size. */
class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, Int)](parent) {
  val bigData = new Array[Byte](100000)

  protected def getPartitions: Array[Partition] = {
    parent.partitions.map(p => new FatPartition(p))
  }

  @transient override val partitioner = Some(_partitioner)

  def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
    parent.compute(split.asInstanceOf[FatPartition].partition, context).map(x => (x, x))
  }
}

object CheckpointSuite {
  // This is a custom cogroup function that does not use mapValues like
  // PairRDDFunctions.cogroup() does.
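  // Illustrative usage (a sketch; the inputs are hypothetical pair RDDs, and the result
  // type follows the signature below):
  //   val grouped: RDD[(Int, Array[Iterable[Int]])] =
  //     CheckpointSuite.cogroup(pairs1, pairs2, new HashPartitioner(2))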
  def cogroup[K: ClassTag, V: ClassTag](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner)
    : RDD[(K, Array[Iterable[V]])] = {
    new CoGroupedRDD[K](
      Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]),
      part
    ).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
  }
}