/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.impl

import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


/**
 * Tests for PeriodicRDDCheckpointer: feeds a sequence of RDDs through a checkpointer and,
 * after every update, verifies the storage level / checkpoint state of every RDD inserted
 * so far.
 */
class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {

  import PeriodicRDDCheckpointerSuite._

  // Inserts 8 RDDs into a checkpointer created with interval 10 and checks, after each
  // insertion, the storage level of every RDD inserted so far.
  test("Persisting") {
    var rddsToCheck = Seq.empty[RDDToCheck]

    val rdd1 = createRDD(sc)
    val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
    checkpointer.update(rdd1)
    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
    checkPersistence(rddsToCheck, 1)

    (2 until 9).foreach { iteration =>
      val rdd = createRDD(sc)
      checkpointer.update(rdd)
      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
      checkPersistence(rddsToCheck, iteration)
    }
  }

  // Inserts 8 RDDs into a checkpointer created with interval 2, running an action on each
  // one, and checks the checkpoint state of every RDD inserted so far. Finally deletes all
  // checkpoints and confirms that the checkpoint files are gone.
  test("Checkpointing") {
    val tempDir = Utils.createTempDir()
    // try/finally so that the temporary directory is removed even when an assertion fails.
    try {
      val path = tempDir.toURI.toString
      val checkpointInterval = 2
      var rddsToCheck = Seq.empty[RDDToCheck]

      sc.setCheckpointDir(path)
      val rdd1 = createRDD(sc)
      val checkpointer =
        new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
      checkpointer.update(rdd1)
      rdd1.count()
      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
      checkCheckpoint(rddsToCheck, 1, checkpointInterval)

      (2 until 9).foreach { iteration =>
        val rdd = createRDD(sc)
        checkpointer.update(rdd)
        rdd.count()
        rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
        checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
      }

      checkpointer.deleteAllCheckpoints()
      rddsToCheck.foreach { rdd =>
        confirmCheckpointRemoved(rdd.rdd)
      }
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}

private object PeriodicRDDCheckpointerSuite {

  /**
   * An RDD paired with its insertion index.
   * @param rdd  RDD handed to the checkpointer.
   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
   */
  case class RDDToCheck(rdd: RDD[Double], gIndex: Int)

  /** Create a small RDD holding the four values 0.0, 1.0, 2.0 and 3.0. */
  def createRDD(sc: SparkContext): RDD[Double] = {
    sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
  }

  /**
   * Check the storage level of every rdd inserted so far.
   * @param rdds  All rdds inserted into the checkpointer, with their insertion indices.
   * @param iteration  Total number of rdds inserted into checkpointer.
   */
  def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
    rdds.foreach { g =>
      checkPersistence(g.rdd, g.gIndex, iteration)
    }
  }

  /**
   * Check storage level of rdd.
   *
   * An rdd with `gIndex + 2 < iteration` is required to have storage level NONE; any other
   * rdd is required to have a storage level other than NONE. A failed check is rethrown as
   * an Exception that describes the inputs and carries the AssertionError as its cause.
   *
   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
   * @param iteration  Total number of rdds inserted into checkpointer.
   */
  def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
    try {
      if (gIndex + 2 < iteration) {
        assert(rdd.getStorageLevel == StorageLevel.NONE)
      } else {
        assert(rdd.getStorageLevel != StorageLevel.NONE)
      }
    } catch {
      case e: AssertionError =>
        // Keep the original error as the cause so its stack trace is not lost.
        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
          s"\t gIndex = $gIndex\n" +
          s"\t iteration = $iteration\n" +
          s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n", e)
    }
  }

  /**
   * Check the checkpointed status of every rdd inserted so far, most recent first.
   * @param rdds  All rdds inserted into the checkpointer, with their insertion indices.
   * @param iteration  Total number of rdds inserted into checkpointer.
   * @param checkpointInterval  Interval the checkpointer was created with.
   */
  def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
    rdds.reverse.foreach { g =>
      checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
    }
  }

  /**
   * Assert that the checkpoint file of rdd, if the rdd reports one, no longer exists on the
   * file system. Does nothing when the rdd reports no checkpoint file.
   */
  def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
    // Note: We cannot check rdd.isCheckpointed since that value is never updated.
    //       Instead, we check for the presence of the checkpoint files.
    //       This test should continue to work even after this rdd.isCheckpointed issue
    //       is fixed (though it can then be simplified and not look for the files).
    val hadoopConf = rdd.sparkContext.hadoopConfiguration
    rdd.getCheckpointFile.foreach { checkpointFile =>
      val path = new Path(checkpointFile)
      val fs = path.getFileSystem(hadoopConf)
      assert(!fs.exists(path), "RDD checkpoint file should have been removed")
    }
  }

  /**
   * Check checkpointed status of rdd.
   *
   * Only rdds whose index is a multiple of checkpointInterval are expected to have been
   * checkpointed at all. Of those, the ones inside the most recent two intervals must
   * still be checkpointed, while older ones must have had their checkpoint file removed.
   * A failed check is rethrown as an Exception that describes the inputs and carries the
   * AssertionError as its cause.
   *
   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
   * @param iteration  Total number of rdds inserted into checkpointer.
   * @param checkpointInterval  Interval the checkpointer was created with.
   */
  def checkCheckpoint(
      rdd: RDD[_],
      gIndex: Int,
      iteration: Int,
      checkpointInterval: Int): Unit = {
    try {
      if (gIndex % checkpointInterval == 0) {
        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd)
        // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint.
        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
          assert(rdd.isCheckpointed, "RDD should be checkpointed")
          assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
        } else {
          confirmCheckpointRemoved(rdd)
        }
      } else {
        // RDD should never be checkpointed
        assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
        assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
      }
    } catch {
      case e: AssertionError =>
        // Keep the original error as the cause so its stack trace is not lost.
        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
          s"\t gIndex = $gIndex\n" +
          s"\t iteration = $iteration\n" +
          s"\t checkpointInterval = $checkpointInterval\n" +
          s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
          s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
          s"  AssertionError message: ${e.getMessage}", e)
    }
  }

}