1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.mllib.impl
19
20import org.apache.hadoop.fs.Path
21
22import org.apache.spark.{SparkContext, SparkFunSuite}
23import org.apache.spark.graphx.{Edge, Graph}
24import org.apache.spark.mllib.util.MLlibTestSparkContext
25import org.apache.spark.storage.StorageLevel
26import org.apache.spark.util.Utils
27
28
class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {

  import PeriodicGraphCheckpointerSuite._

  /**
   * Inserts graphs 1 through 8 into a checkpointer whose interval (10) is never reached.
   * After each insertion it checks the storage level of every graph inserted so far
   * (see `checkPersistence` in the companion object for the expected levels).
   */
  test("Persisting") {
    var graphsToCheck = Seq.empty[GraphToCheck]

    val graph1 = createGraph(sc)
    val checkpointer =
      new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
    checkpointer.update(graph1)
    graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
    checkPersistence(graphsToCheck, 1)

    for (iteration <- 2 until 9) {
      val graph = createGraph(sc)
      checkpointer.update(graph)
      graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
      checkPersistence(graphsToCheck, iteration)
    }
  }

  /**
   * Inserts graphs 1 through 8 into a checkpointer with interval 2. After each insertion it
   * checks the checkpoint state of every graph inserted so far. At the end it calls
   * `deleteAllCheckpoints()` and confirms that no checkpoint files remain on disk.
   */
  test("Checkpointing") {
    val tempDir = Utils.createTempDir()
    // Clean up the checkpoint directory even when an assertion below fails.
    try {
      val path = tempDir.toURI.toString
      val checkpointInterval = 2
      var graphsToCheck = Seq.empty[GraphToCheck]
      sc.setCheckpointDir(path)
      val graph1 = createGraph(sc)
      val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
        checkpointInterval, graph1.vertices.sparkContext)
      checkpointer.update(graph1)
      // Run actions on both RDDs so that any checkpoint requested by update() can be
      // materialized before it is checked.
      graph1.edges.count()
      graph1.vertices.count()
      graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
      checkCheckpoint(graphsToCheck, 1, checkpointInterval)

      for (iteration <- 2 until 9) {
        val graph = createGraph(sc)
        checkpointer.update(graph)
        graph.vertices.count()
        graph.edges.count()
        graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
        checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
      }

      checkpointer.deleteAllCheckpoints()
      graphsToCheck.foreach { graph =>
        confirmCheckpointRemoved(graph.graph)
      }
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}
87
private object PeriodicGraphCheckpointerSuite {

  /**
   * A graph paired with the position at which it was inserted into the checkpointer.
   * @param graph  The graph passed to `PeriodicGraphCheckpointer.update`.
   * @param gIndex  Index of graph in order inserted into checkpointer (from 1).
   */
  final case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)

  /** Edges of a small path graph 0 -> 1 -> 2 -> 3 -> 4; every edge attribute is 0. */
  val edges = Seq(
    Edge[Double](0, 1, 0),
    Edge[Double](1, 2, 0),
    Edge[Double](2, 3, 0),
    Edge[Double](3, 4, 0))

  /** Builds a fresh graph from `edges`, using 0 as the default vertex attribute. */
  def createGraph(sc: SparkContext): Graph[Double, Double] = {
    Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
  }

  /**
   * Check storage level of every graph inserted so far.
   * @param graphs  Graphs inserted into the checkpointer, each tagged with its insertion index.
   * @param iteration  Total number of graphs inserted into checkpointer.
   */
  def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
    graphs.foreach { g =>
      checkPersistence(g.graph, g.gIndex, iteration)
    }
  }

  /**
   * Check storage level of graph.
   *
   * A graph with `gIndex + 2 < iteration` must be fully unpersisted. In other words, only the
   * three most recently inserted graphs may keep a storage level other than NONE, and those
   * must keep one on both vertices and edges.
   *
   * @param gIndex  Index of graph in order inserted into checkpointer (from 1).
   * @param iteration  Total number of graphs inserted into checkpointer.
   * @throws Exception  if the check fails; the original AssertionError is attached as the cause.
   */
  def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = {
    try {
      if (gIndex + 2 < iteration) {
        assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
        assert(graph.edges.getStorageLevel == StorageLevel.NONE)
      } else {
        assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
        assert(graph.edges.getStorageLevel != StorageLevel.NONE)
      }
    } catch {
      case e: AssertionError =>
        // Keep the original assertion as the cause so its message and stack trace survive.
        throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" +
          s"\t gIndex = $gIndex\n" +
          s"\t iteration = $iteration\n" +
          s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" +
          s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n" +
          s"  AssertionError message: ${e.getMessage}", e)
    }
  }

  /**
   * Check checkpointed status of every graph inserted so far, most recent first.
   * @param graphs  Graphs inserted into the checkpointer, each tagged with its insertion index.
   * @param iteration  Total number of graphs inserted into checkpointer.
   * @param checkpointInterval  Interval the checkpointer was configured with.
   */
  def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = {
    graphs.reverse.foreach { g =>
      checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval)
    }
  }

  /**
   * Asserts that none of the graph's checkpoint files still exist on their file system.
   * If `getCheckpointFiles` is empty, this check passes trivially.
   */
  def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = {
    // Note: We cannot check graph.isCheckpointed since that value is never updated.
    //       Instead, we check for the presence of the checkpoint files.
    //       This test should continue to work even after this graph.isCheckpointed issue
    //       is fixed (though it can then be simplified and not look for the files).
    val hadoopConf = graph.vertices.sparkContext.hadoopConfiguration
    graph.getCheckpointFiles.foreach { checkpointFile =>
      val path = new Path(checkpointFile)
      val fs = path.getFileSystem(hadoopConf)
      assert(!fs.exists(path),
        s"Graph checkpoint file should have been removed: $checkpointFile")
    }
  }

  /**
   * Check checkpointed status of graph.
   *
   * Graphs whose index is a multiple of `checkpointInterval` must still be checkpointed,
   * with 2 checkpoint files, while they fall within the last two intervals. Once they fall
   * outside that window, their files must have been removed. All other graphs must never have
   * been checkpointed.
   *
   * @param gIndex  Index of graph in order inserted into checkpointer (from 1).
   * @param iteration  Total number of graphs inserted into checkpointer.
   * @param checkpointInterval  Interval the checkpointer was configured with.
   * @throws Exception  if the check fails; the original AssertionError is attached as the cause.
   */
  def checkCheckpoint(
      graph: Graph[_, _],
      gIndex: Int,
      iteration: Int,
      checkpointInterval: Int): Unit = {
    try {
      if (gIndex % checkpointInterval == 0) {
        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph)
        // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint.
        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
          assert(graph.isCheckpointed, "Graph should be checkpointed")
          assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files")
        } else {
          confirmCheckpointRemoved(graph)
        }
      } else {
        // Graph should never be checkpointed
        assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
        assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
      }
    } catch {
      case e: AssertionError =>
        // Keep the original assertion as the cause so its stack trace survives.
        throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" +
          s"\t gIndex = $gIndex\n" +
          s"\t iteration = $iteration\n" +
          s"\t checkpointInterval = $checkpointInterval\n" +
          s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" +
          s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" +
          s"  AssertionError message: ${e.getMessage}", e)
    }
  }

}
190