/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.impl

import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


/**
 * Tests for [[PeriodicGraphCheckpointer]].
 *
 * Checks that only the most recently inserted graphs stay persisted, that every
 * `checkpointInterval`-th graph is checkpointed, and that `deleteAllCheckpoints()`
 * removes the checkpoint files.
 */
class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {

  import PeriodicGraphCheckpointerSuite._

  test("Persisting") {
    var graphsToCheck = Seq.empty[GraphToCheck]

    val graph1 = createGraph(sc)
    val checkpointer =
      new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
    checkpointer.update(graph1)
    graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
    checkPersistence(graphsToCheck, 1)

    // Insert graphs 2..8 and re-check every graph inserted so far after each update.
    for (iteration <- 2 until 9) {
      val graph = createGraph(sc)
      checkpointer.update(graph)
      graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
      checkPersistence(graphsToCheck, iteration)
    }
  }

  test("Checkpointing") {
    val tempDir = Utils.createTempDir()
    try {
      val path = tempDir.toURI.toString
      val checkpointInterval = 2
      var graphsToCheck = Seq.empty[GraphToCheck]
      sc.setCheckpointDir(path)
      val graph1 = createGraph(sc)
      val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
        checkpointInterval, graph1.vertices.sparkContext)
      checkpointer.update(graph1)
      // Run actions on the graph: RDD checkpoints are only written once a job runs on them.
      graph1.edges.count()
      graph1.vertices.count()
      graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
      checkCheckpoint(graphsToCheck, 1, checkpointInterval)

      for (iteration <- 2 until 9) {
        val graph = createGraph(sc)
        checkpointer.update(graph)
        graph.vertices.count()
        graph.edges.count()
        graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
        checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
      }

      checkpointer.deleteAllCheckpoints()
      graphsToCheck.foreach { graph =>
        confirmCheckpointRemoved(graph.graph)
      }
    } finally {
      // Always remove the checkpoint directory, even when an assertion above fails.
      Utils.deleteRecursively(tempDir)
    }
  }
}

private object PeriodicGraphCheckpointerSuite {

  /** A graph handed to the checkpointer, paired with its 1-based insertion index. */
  case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)

  val edges = Seq(
    Edge[Double](0, 1, 0),
    Edge[Double](1, 2, 0),
    Edge[Double](2, 3, 0),
    Edge[Double](3, 4, 0))

  /** Builds a small path graph 0 - 1 - 2 - 3 - 4 with all edge and vertex attributes set to 0. */
  def createGraph(sc: SparkContext): Graph[Double, Double] = {
    Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
  }

  /** Checks the storage level of every graph in `graphs` after `iteration` insertions. */
  def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
    graphs.foreach { g =>
      checkPersistence(g.graph, g.gIndex, iteration)
    }
  }

  /**
   * Check storage level of graph.
   *
   * Only the 3 most recently inserted graphs (gIndex in [iteration - 2, iteration]) are expected
   * to be persisted; every older graph must have been unpersisted (StorageLevel.NONE).
   *
   * @param gIndex Index of graph in order inserted into checkpointer (from 1).
   * @param iteration Total number of graphs inserted into checkpointer.
   * @throws Exception describing the graph's state, with the failed AssertionError as its cause.
   */
  def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = {
    try {
      if (gIndex + 2 < iteration) {
        assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
        assert(graph.edges.getStorageLevel == StorageLevel.NONE)
      } else {
        assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
        assert(graph.edges.getStorageLevel != StorageLevel.NONE)
      }
    } catch {
      case e: AssertionError =>
        // Chain the original error so its stack trace is not lost.
        throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" +
          s"\t gIndex = $gIndex\n" +
          s"\t iteration = $iteration\n" +
          s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" +
          s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n", e)
    }
  }

  /** Checks the checkpoint status of every graph in `graphs`, newest first. */
  def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = {
    graphs.reverse.foreach { g =>
      checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval)
    }
  }

  /** Asserts that none of the graph's checkpoint files exist on its file system anymore. */
  def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = {
    // Note: We cannot check graph.isCheckpointed since that value is never updated.
    // Instead, we check for the presence of the checkpoint files.
    // This test should continue to work even after this graph.isCheckpointed issue
    // is fixed (though it can then be simplified and not look for the files).
    val hadoopConf = graph.vertices.sparkContext.hadoopConfiguration
    graph.getCheckpointFiles.foreach { checkpointFile =>
      val path = new Path(checkpointFile)
      val fs = path.getFileSystem(hadoopConf)
      assert(!fs.exists(path),
        "Graph checkpoint file should have been removed")
    }
  }

  /**
   * Check checkpointed status of graph.
   *
   * Graphs whose index is a multiple of `checkpointInterval` must be checkpointed while they are
   * within the last 2 intervals, and must have had their checkpoint files removed afterwards.
   * All other graphs must never be checkpointed.
   *
   * @param gIndex Index of graph in order inserted into checkpointer (from 1).
   * @param iteration Total number of graphs inserted into checkpointer.
   * @param checkpointInterval Checkpoint interval the checkpointer was created with.
   * @throws Exception describing the graph's state, with the failed AssertionError as its cause.
   */
  def checkCheckpoint(
      graph: Graph[_, _],
      gIndex: Int,
      iteration: Int,
      checkpointInterval: Int): Unit = {
    try {
      if (gIndex % checkpointInterval == 0) {
        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph)
        // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint.
        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
          assert(graph.isCheckpointed, "Graph should be checkpointed")
          assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files")
        } else {
          confirmCheckpointRemoved(graph)
        }
      } else {
        // Graph should never be checkpointed
        assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
        assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
      }
    } catch {
      case e: AssertionError =>
        // Chain the original error so its stack trace is not lost.
        throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" +
          s"\t gIndex = $gIndex\n" +
          s"\t iteration = $iteration\n" +
          s"\t checkpointInterval = $checkpointInterval\n" +
          s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" +
          s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" +
          s"  AssertionError message: ${e.getMessage}", e)
    }
  }

}