/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.impl

import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {

  import PeriodicRDDCheckpointerSuite._
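
  // A minimal usage sketch of PeriodicRDDCheckpointer, based only on the API exercised in
  // this suite (the constructor, update(), and deleteAllCheckpoints()); the per-iteration
  // transformation, loop bound, and directory path are illustrative:
  //
  //   sc.setCheckpointDir(checkpointDirPath)
  //   val checkpointer = new PeriodicRDDCheckpointer[Double](2, sc)
  //   var data: RDD[Double] = sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
  //   for (_ <- 1 to numIterations) {
  //     data = data.map(_ + 1.0)            // hypothetical iterative update
  //     checkpointer.update(data)           // persists data; checkpoints every 2nd update
  //     data.count()                        // an action materializes any pending checkpoint
  //   }
  //   checkpointer.deleteAllCheckpoints()   // remove remaining checkpoint files when done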

  test("Persisting") {
    var rddsToCheck = Seq.empty[RDDToCheck]

    val rdd1 = createRDD(sc)
    val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
    checkpointer.update(rdd1)
    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
    checkPersistence(rddsToCheck, 1)

    var iteration = 2
    while (iteration < 9) {
      val rdd = createRDD(sc)
      checkpointer.update(rdd)
      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
      checkPersistence(rddsToCheck, iteration)
      iteration += 1
    }
  }

  test("Checkpointing") {
    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString
    val checkpointInterval = 2
    var rddsToCheck = Seq.empty[RDDToCheck]
    sc.setCheckpointDir(path)
    val rdd1 = createRDD(sc)
    val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
    checkpointer.update(rdd1)
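    // Checkpointing is lazy: the data is written out during the next action on the RDD,
    // so force an action here to materialize the checkpoint before checking it.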
    rdd1.count()
    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
    checkCheckpoint(rddsToCheck, 1, checkpointInterval)

    var iteration = 2
    while (iteration < 9) {
      val rdd = createRDD(sc)
      checkpointer.update(rdd)
      rdd.count()
      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
      checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
      iteration += 1
    }

    checkpointer.deleteAllCheckpoints()
    rddsToCheck.foreach { rdd =>
      confirmCheckpointRemoved(rdd.rdd)
    }

    Utils.deleteRecursively(tempDir)
  }
}

private object PeriodicRDDCheckpointerSuite {

  case class RDDToCheck(rdd: RDD[Double], gIndex: Int)

  def createRDD(sc: SparkContext): RDD[Double] = {
    sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
  }

  def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
    rdds.foreach { g =>
      checkPersistence(g.rdd, g.gIndex, iteration)
    }
  }

  /**
   * Check the storage level of an RDD.
   * @param gIndex  1-based index of the RDD, in the order it was added to the checkpointer.
   * @param iteration  Total number of RDDs added to the checkpointer so far.
   */
  def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
    try {
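      // The checkpointer persists each RDD passed to update() and unpersists older ones,
      // so only the 3 most recently added RDDs should remain persisted (the
      // gIndex + 2 < iteration check below). For example, at iteration 5, RDDs 3-5 are
      // still persisted while RDDs 1-2 have been unpersisted.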
      if (gIndex + 2 < iteration) {
        assert(rdd.getStorageLevel == StorageLevel.NONE)
      } else {
        assert(rdd.getStorageLevel != StorageLevel.NONE)
      }
    } catch {
      case _: AssertionError =>
        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
          s"\t gIndex = $gIndex\n" +
          s"\t iteration = $iteration\n" +
          s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n")
    }
  }

  def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
    rdds.reverse.foreach { g =>
      checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
    }
  }

  def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
    // Note: We cannot check rdd.isCheckpointed since that value is never updated.
    //       Instead, we check for the presence of the checkpoint files.
    //       This test should continue to work even after this rdd.isCheckpointed issue
    //       is fixed (though it can then be simplified and not look for the files).
    val hadoopConf = rdd.sparkContext.hadoopConfiguration
    rdd.getCheckpointFile.foreach { checkpointFile =>
      val path = new Path(checkpointFile)
      val fs = path.getFileSystem(hadoopConf)
      assert(!fs.exists(path), "RDD checkpoint file should have been removed")
    }
  }

  /**
   * Check the checkpointed status of an RDD.
   * @param gIndex  1-based index of the RDD, in the order it was added to the checkpointer.
   * @param iteration  Total number of RDDs added to the checkpointer so far.
   */
  def checkCheckpoint(
      rdd: RDD[_],
      gIndex: Int,
      iteration: Int,
      checkpointInterval: Int): Unit = {
    try {
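      // The retention window below allows up to 2 checkpoints at a time. For example, with
      // checkpointInterval = 2 at iteration 8, the checkpoints for RDDs 6 and 8 should
      // exist, while the checkpoint for RDD 4 should already have been removed.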
      if (gIndex % checkpointInterval == 0) {
        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second
        // RDD) only AFTER PeriodicRDDCheckpointer decides whether to remove the previous
        // checkpoint.
        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
          assert(rdd.isCheckpointed, "RDD should be checkpointed")
          assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
        } else {
          confirmCheckpointRemoved(rdd)
        }
      } else {
        // RDD should never be checkpointed
        assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
        assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
      }
    } catch {
      case e: AssertionError =>
        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
          s"\t gIndex = $gIndex\n" +
          s"\t iteration = $iteration\n" +
          s"\t checkpointInterval = $checkpointInterval\n" +
          s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
          s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
          s"  AssertionError message: ${e.getMessage}")
    }
  }

}