/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.io.File

import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path

import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils

trait RDDCheckpointTester { self: SparkFunSuite =>

  protected val partitioner = new HashPartitioner(2)

  private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()

  /** Implementations of this trait must implement this method */
  protected def sparkContext: SparkContext

  /**
   * Test checkpointing of the RDD generated by the given operation. It tests whether the
   * serialized size of the RDD is reduced after checkpointing or not. This function should be
   * called on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection,
   * BlockRDD, etc.).
   *
   * @param op an operation to run on the RDD
   * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
   * @param collectFunc a function for collecting the values in the RDD, in case there are
   *                    non-comparable types like arrays that we want to convert to something
   *                    that supports ==
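   *
   * For example (mirroring the tests in CheckpointSuite below):
   * {{{
   *   testRDD(_.map(_.toString), reliableCheckpoint = true)
   * }}}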
   */
  protected def testRDD[U: ClassTag](
      op: (RDD[Int]) => RDD[U],
      reliableCheckpoint: Boolean,
      collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
    // Generate the final RDD using given RDD operation
    val baseRDD = generateFatRDD()
    val operatedRDD = op(baseRDD)
    val parentDependency = operatedRDD.dependencies.headOption.orNull
    val rddType = operatedRDD.getClass.getSimpleName
    val numPartitions = operatedRDD.partitions.length

    // Force initialization of all the data structures in RDDs
    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
    initializeRdd(operatedRDD)

    val partitionsBeforeCheckpoint = operatedRDD.partitions

    // Find serialized sizes before and after the checkpoint
    logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
    checkpoint(operatedRDD, reliableCheckpoint)
    val result = collectFunc(operatedRDD)
    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)

    // Test whether the checkpoint file has been created
    if (reliableCheckpoint) {
      assert(operatedRDD.getCheckpointFile.nonEmpty)
      val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)
      assert(collectFunc(recoveredRDD) === result)
      assert(recoveredRDD.partitioner === operatedRDD.partitioner)
    }

    // Test whether the dependencies have been changed from the earlier parent dependency
    assert(operatedRDD.dependencies.head != parentDependency)

    // Test whether the partitions have been changed from the earlier partitions
    assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList)

    // Test whether the partitions have been changed to the new Hadoop partitions
    assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)

    // Test whether the number of partitions is the same as before
    assert(operatedRDD.partitions.length === numPartitions)

    // Test whether the data in the checkpointed RDD is the same as the original
    assert(collectFunc(operatedRDD) === result)

    // Test whether the serialized size of the RDD has reduced.
    logInfo("Size of " + rddType +
      " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
    assert(
      rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
      "Size of " + rddType + " did not reduce after checkpointing" +
        " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
    )
  }

  /**
   * Test whether checkpointing of the parent of the generated RDD also
   * truncates the lineage or not. Some RDDs, like CoGroupedRDD, hold on to their parent
   * RDDs' partitions. So even if the parent RDD is checkpointed and its partitions changed,
   * the generated RDD will remember the partitions and therefore potentially the whole lineage.
   * This function should be called only on those RDDs whose partitions refer to the parent
   * RDDs' partitions (i.e., do not call it on simple RDDs like MapPartitionsRDD).
   *
   * @param op an operation to run on the RDD
   * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
   * @param collectFunc a function for collecting the values in the RDD, in case there are
   *                    non-comparable types like arrays that we want to convert to something
   *                    that supports ==
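   *
   * For example (mirroring the CoalescedRDD test below):
   * {{{
   *   testRDDPartitions(_.coalesce(2), reliableCheckpoint = true)
   * }}}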
   */
  protected def testRDDPartitions[U: ClassTag](
      op: (RDD[Int]) => RDD[U],
      reliableCheckpoint: Boolean,
      collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
    // Generate the final RDD using given RDD operation
    val baseRDD = generateFatRDD()
    val operatedRDD = op(baseRDD)
    val parentRDDs = operatedRDD.dependencies.map(_.rdd)
    val rddType = operatedRDD.getClass.getSimpleName

    // Force initialization of all the data structures in RDDs
    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
    initializeRdd(operatedRDD)

    // Find serialized sizes before and after the checkpoint
    logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
    // checkpoint the parent RDD, not the generated one
    parentRDDs.foreach { rdd =>
      checkpoint(rdd, reliableCheckpoint)
    }
    val result = collectFunc(operatedRDD) // force checkpointing
    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)

    // Test whether the data in the checkpointed RDD is the same as the original
    assert(collectFunc(operatedRDD) === result)

    // Test whether the serialized size of the partitions has reduced
    logInfo("Size of partitions of " + rddType +
      " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]")
    assert(
      partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint,
      "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" +
        " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]"
    )
  }

  /**
   * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
   * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
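   *
   * @return a pair of (serialized size of the RDD excluding its checkpoint data,
   *         serialized size of its partitions)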
   */
  private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
    val rddSize = Utils.serialize(rdd).size
    val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
    val rddPartitionSize = Utils.serialize(rdd.partitions).size
    val rddDependenciesSize = Utils.serialize(rdd.dependencies).size

    // Log the detailed sizes; helps in debugging
    logInfo("Serialized sizes of " + rdd +
      ": RDD = " + rddSize +
      ", RDD checkpoint data = " + rddCpDataSize +
      ", RDD partitions = " + rddPartitionSize +
      ", RDD dependencies = " + rddDependenciesSize
    )
    // This makes sure that serializing the RDD's checkpoint data does not
    // serialize the whole RDD as well
    assert(
      rddSize > rddCpDataSize,
      "RDD's checkpoint data (" + rddCpDataSize + ") is equal to or larger than the " +
        "whole RDD with checkpoint data (" + rddSize + ")"
    )
    (rddSize - rddCpDataSize, rddPartitionSize)
  }

  /**
   * Serialize and deserialize an object. This is useful to verify an object's
   * contents after deserialization (e.g., the contents of an RDD partition after
   * it is sent to an executor along with a task).
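   *
   * For example (as in the CartesianRDD test below):
   * {{{
   *   serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
   * }}}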
   */
  protected def serializeDeserialize[T](obj: T): T = {
    val bytes = Utils.serialize(obj)
    Utils.deserialize[T](bytes)
  }

  /**
   * Recursively force the initialization of all the members of an RDD and its parents.
   */
  private def initializeRdd(rdd: RDD[_]): Unit = {
    rdd.partitions // forces the initialization of the partitions
    rdd.dependencies.map(_.rdd).foreach(initializeRdd)
  }

  /** Checkpoint the RDD either locally or reliably. */
  protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = {
    if (reliableCheckpoint) {
      rdd.checkpoint()
    } else {
      rdd.localCheckpoint()
    }
  }

  /** Run a test twice, once with reliable checkpointing and once with local checkpointing. */
  protected def runTest(
      name: String,
      skipLocalCheckpoint: Boolean = false
    )(body: Boolean => Unit): Unit = {
    test(name + " [reliable checkpoint]")(body(true))
    if (!skipLocalCheckpoint) {
      test(name + " [local checkpoint]")(body(false))
    }
  }

  /**
   * Generate an RDD such that both the RDD and its partitions have large size.
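   * The size comes from the FatRDD and FatPartition helpers defined at the end of this file,
   * which carry large byte arrays purely to inflate their serialized sizes.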
   */
  protected def generateFatRDD(): RDD[Int] = {
    new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x)
  }

  /**
   * Generate a pair RDD (with a partitioner) such that both the RDD and its partitions
   * have large size.
   */
  protected def generateFatPairRDD(): RDD[(Int, Int)] = {
    new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
  }
}

/**
 * Test suite for end-to-end checkpointing functionality.
 * This tests both reliable checkpoints and local checkpoints.
 */
class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext {
  private var checkpointDir: File = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
    checkpointDir.delete()
    sc = new SparkContext("local", "test")
    sc.setCheckpointDir(checkpointDir.toString)
  }

  override def afterEach(): Unit = {
    try {
      Utils.deleteRecursively(checkpointDir)
    } finally {
      super.afterEach()
    }
  }

  override def sparkContext: SparkContext = sc

  runTest("basic checkpointing") { reliableCheckpoint: Boolean =>
    val parCollection = sc.makeRDD(1 to 4)
    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
    checkpoint(flatMappedRDD, reliableCheckpoint)
    assert(flatMappedRDD.dependencies.head.rdd === parCollection)
    val result = flatMappedRDD.collect()
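    // Running an action materializes the checkpoint, which replaces the original lineage.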
    assert(flatMappedRDD.dependencies.head.rdd != parCollection)
    assert(flatMappedRDD.collect() === result)
  }

  runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean =>

    def testPartitionerCheckpointing(
        partitioner: Partitioner,
        corruptPartitionerFile: Boolean = false
      ): Unit = {
      val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner)
      rddWithPartitioner.checkpoint()
      rddWithPartitioner.count()
      assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty,
        "checkpointing was not successful")

      if (corruptPartitionerFile) {
        // Overwrite the partitioner file with garbage data
        val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get)
        val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration)
        val partitionerFile = fs.listStatus(checkpointDir)
            .find(_.getPath.getName.contains("partitioner"))
            .map(_.getPath)
        require(partitionerFile.nonEmpty, "could not find the partitioner file for testing")
        val output = fs.create(partitionerFile.get, true)
        output.write(100)
        output.close()
      }

      val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get)
      assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered")

      if (!corruptPartitionerFile) {
        assert(newRDD.partitioner != None, "partitioner not recovered")
        assert(newRDD.partitioner === rddWithPartitioner.partitioner,
          "recovered partitioner does not match")
      } else {
        assert(newRDD.partitioner == None, "partitioner unexpectedly recovered")
      }
    }

    testPartitionerCheckpointing(partitioner)

    // Test that a corrupted partitioner file does not prevent recovery of the RDD
    testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true)
  }

  runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
    testRDD(_.map(x => x.toString), reliableCheckpoint)
    testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)
    testRDD(_.filter(_ % 2 == 0), reliableCheckpoint)
    testRDD(_.sample(false, 0.5, 0), reliableCheckpoint)
    testRDD(_.glom(), reliableCheckpoint)
    testRDD(_.mapPartitions(_.map(_.toString)), reliableCheckpoint)
    testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), reliableCheckpoint)
    testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x),
      reliableCheckpoint)
    testRDD(_.pipe(Seq("cat")), reliableCheckpoint)
  }

  runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean =>
    val parCollection = sc.makeRDD(1 to 4, 2)
    val numPartitions = parCollection.partitions.size
    checkpoint(parCollection, reliableCheckpoint)
    assert(parCollection.dependencies === Nil)
    val result = parCollection.collect()
    if (reliableCheckpoint) {
      assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
    }
    assert(parCollection.dependencies != Nil)
    assert(parCollection.partitions.length === numPartitions)
    assert(parCollection.partitions.toList ===
      parCollection.checkpointData.get.getPartitions.toList)
    assert(parCollection.collect() === result)
  }

  runTest("BlockRDD") { reliableCheckpoint: Boolean =>
    val blockId = TestBlockId("id")
    val blockManager = SparkEnv.get.blockManager
    blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
    val blockRDD = new BlockRDD[String](sc, Array(blockId))
    val numPartitions = blockRDD.partitions.size
    checkpoint(blockRDD, reliableCheckpoint)
    val result = blockRDD.collect()
    if (reliableCheckpoint) {
      assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
    }
    assert(blockRDD.dependencies != Nil)
    assert(blockRDD.partitions.length === numPartitions)
    assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList)
    assert(blockRDD.collect() === result)
  }

  runTest("ShuffleRDD") { reliableCheckpoint: Boolean =>
    testRDD(rdd => {
      // Create a ShuffledRDD directly, because PairRDDFunctions.combineByKey produces a
      // MapPartitionsRDD
      new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
    }, reliableCheckpoint)
  }

  runTest("UnionRDD") { reliableCheckpoint: Boolean =>
    def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
    testRDD(_.union(otherRDD), reliableCheckpoint)
    testRDDPartitions(_.union(otherRDD), reliableCheckpoint)
  }

  runTest("CartesianRDD") { reliableCheckpoint: Boolean =>
    def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
    testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
    testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)

    // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
    // the parent RDD has been checkpointed and parent partitions have been changed.
    // Note that this test is very specific to the current implementation of CartesianRDD.
    val ones = sc.makeRDD(1 to 100, 10).map(x => x)
    checkpoint(ones, reliableCheckpoint) // checkpoint the mapped RDD
    val cartesian = new CartesianRDD(sc, ones, ones)
    val splitBeforeCheckpoint =
      serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
    cartesian.count() // do the checkpointing
    val splitAfterCheckpoint =
      serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
    assert(
      (splitAfterCheckpoint.s1.getClass != splitBeforeCheckpoint.s1.getClass) &&
        (splitAfterCheckpoint.s2.getClass != splitBeforeCheckpoint.s2.getClass),
      "CartesianRDD.s1 and CartesianRDD.s2 not updated after parent RDD is checkpointed"
    )
  }

  runTest("CoalescedRDD") { reliableCheckpoint: Boolean =>
    testRDD(_.coalesce(2), reliableCheckpoint)
    testRDDPartitions(_.coalesce(2), reliableCheckpoint)

    // Test that the CoalescedRDDPartition updates parent partitions
    // (CoalescedRDDPartition.parents) after the parent RDD has been checkpointed and parent
    // partitions have been changed. Note that this test is very specific to the current
    // implementation of CoalescedRDDPartition.
    val ones = sc.makeRDD(1 to 100, 10).map(x => x)
    checkpoint(ones, reliableCheckpoint) // checkpoint the mapped RDD
    val coalesced = new CoalescedRDD(ones, 2)
    val splitBeforeCheckpoint =
      serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
    coalesced.count() // do the checkpointing
    val splitAfterCheckpoint =
      serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
    assert(
      splitAfterCheckpoint.parents.head.getClass != splitBeforeCheckpoint.parents.head.getClass,
      "CoalescedRDDPartition.parents not updated after parent RDD is checkpointed"
    )
  }

  runTest("CoGroupedRDD") { reliableCheckpoint: Boolean =>
    val longLineageRDD1 = generateFatPairRDD()

    // Collect the RDD as sequences instead of arrays to enable equality tests in testRDD
    val seqCollectFunc = (rdd: RDD[(Int, Array[Iterable[Int]])]) =>
      rdd.map { case (p, a) => (p, a.toSeq) }.collect(): Any

    testRDD(rdd => {
      CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
    }, reliableCheckpoint, seqCollectFunc)

    val longLineageRDD2 = generateFatPairRDD()
    testRDDPartitions(rdd => {
      CheckpointSuite.cogroup(
        longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
    }, reliableCheckpoint, seqCollectFunc)
  }

  runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean =>
    testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
    testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)

    // Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have
    // been checkpointed and parent partitions have been changed.
    // Note that this test is very specific to the implementation of ZippedPartitionsRDD.
    val rdd = generateFatRDD()
    val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]]
    checkpoint(zippedRDD.rdd1, reliableCheckpoint)
    checkpoint(zippedRDD.rdd2, reliableCheckpoint)
    val partitionBeforeCheckpoint =
      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
    zippedRDD.count()
    val partitionAfterCheckpoint =
      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
    assert(
      partitionAfterCheckpoint.partitions(0).getClass !=
        partitionBeforeCheckpoint.partitions(0).getClass &&
      partitionAfterCheckpoint.partitions(1).getClass !=
        partitionBeforeCheckpoint.partitions(1).getClass,
      "ZippedPartitionsRDD partition 0 (or 1) not updated after parent RDDs are checkpointed"
    )
  }

  runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean =>
    testRDD(rdd => {
      new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
        generateFatPairRDD(),
        rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
      ))
    }, reliableCheckpoint)

    testRDDPartitions(rdd => {
      new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
        generateFatPairRDD(),
        rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
      ))
    }, reliableCheckpoint)

    // Test that the PartitionerAwareUnionRDD updates parent partitions
    // (PartitionerAwareUnionRDD.parents) after the parent RDD has been checkpointed and parent
    // partitions have been changed. Note that this test is very specific to the current
    // implementation of PartitionerAwareUnionRDD.
    val pairRDD = generateFatPairRDD()
    checkpoint(pairRDD, reliableCheckpoint)
    val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD))
    val partitionBeforeCheckpoint = serializeDeserialize(
      unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
    pairRDD.count()
    val partitionAfterCheckpoint = serializeDeserialize(
      unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
    assert(
      partitionBeforeCheckpoint.parents.head.getClass !=
        partitionAfterCheckpoint.parents.head.getClass,
      "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed"
    )
  }

  runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean =>
    val rdd = new BlockRDD[Int](sc, Array.empty[BlockId])
    assert(rdd.partitions.size === 0)
    assert(rdd.isCheckpointed === false)
    assert(rdd.isCheckpointedAndMaterialized === false)
    checkpoint(rdd, reliableCheckpoint)
    assert(rdd.isCheckpointed === false)
    assert(rdd.isCheckpointedAndMaterialized === false)
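    // The checkpoint is only materialized once an action runs.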
    assert(rdd.count() === 0)
    assert(rdd.isCheckpointed === true)
    assert(rdd.isCheckpointedAndMaterialized === true)
    assert(rdd.partitions.size === 0)
  }

  runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean =>
    testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true)
    testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false)
  }

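  /**
   * Check that materializing a checkpointed RDD also materializes the checkpoints of its marked
   * ancestors if and only if RDD.CHECKPOINT_ALL_MARKED_ANCESTORS is set to true.
   */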
  private def testCheckpointAllMarkedAncestors(
      reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = {
    sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString)
    try {
      val rdd1 = sc.parallelize(1 to 10)
      checkpoint(rdd1, reliableCheckpoint)
      val rdd2 = rdd1.map(_ + 1)
      checkpoint(rdd2, reliableCheckpoint)
      rdd2.count()
      assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors)
      assert(rdd2.isCheckpointed === true)
    } finally {
      sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null)
    }
  }
}

/** RDD partition that has large serialized size. */
class FatPartition(val partition: Partition) extends Partition {
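  // Never read; exists only to inflate the serialized size of the partition.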
  val bigData = new Array[Byte](10000)
  def index: Int = partition.index
}

/** RDD that has large serialized size. */
class FatRDD(parent: RDD[Int]) extends RDD[Int](parent) {
  val bigData = new Array[Byte](100000)

  protected def getPartitions: Array[Partition] = {
    parent.partitions.map(p => new FatPartition(p))
  }

  def compute(split: Partition, context: TaskContext): Iterator[Int] = {
    parent.compute(split.asInstanceOf[FatPartition].partition, context)
  }
}

/** Pair RDD that has large serialized size. */
class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, Int)](parent) {
  val bigData = new Array[Byte](100000)

  protected def getPartitions: Array[Partition] = {
    parent.partitions.map(p => new FatPartition(p))
  }

  @transient override val partitioner = Some(_partitioner)

  def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
    parent.compute(split.asInstanceOf[FatPartition].partition, context).map(x => (x, x))
  }
}

object CheckpointSuite {
  // This is a custom cogroup function that does not use mapValues like
  // PairRDDFunctions.cogroup() does,
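  // so the tests exercise the CoGroupedRDD itself rather than a MapPartitionsRDD wrapper.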
  def cogroup[K: ClassTag, V: ClassTag](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner)
    : RDD[(K, Array[Iterable[V]])] = {
    new CoGroupedRDD[K](
      Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]),
      part
    ).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
  }
}