/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.lang.ref.WeakReference

import scala.collection.mutable.HashSet
import scala.language.existentials
import scala.util.Random

import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.PatienceConfiguration
import org.scalatest.time.SpanSugar._

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.Utils

/**
 * An abstract base class for context cleaner tests. It sets up a SparkContext with a config
 * suitable for cleaner tests and provides some utility functions. Subclasses can use
 * different config options, in particular a different shuffle manager class.
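 *
 * For example, a subclass could plug in an alternative shuffle manager
 * (`MyShuffleManager` below is purely illustrative, not a real class):
 * {{{
 *   class MyShuffleCleanerSuite extends ContextCleanerSuiteBase(classOf[MyShuffleManager])
 * }}}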
 */
abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[SortShuffleManager])
  extends SparkFunSuite with BeforeAndAfter with LocalSparkContext
{
  implicit val defaultTimeout: PatienceConfiguration.Timeout = timeout(10000.millis)
  val conf = new SparkConf()
    .setMaster("local[2]")
    .setAppName("ContextCleanerSuite")
    .set("spark.cleaner.referenceTracking.blocking", "true")
    .set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
    .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
    .set("spark.shuffle.manager", shuffleManager.getName)

  before {
    sc = new SparkContext(conf)
  }

  after {
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }

  // ------ Helper functions ------

  protected def newRDD() = sc.makeRDD(1 to 10)
  protected def newPairRDD() = newRDD().map(_ -> 1)
  protected def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
  protected def newBroadcast() = sc.broadcast(1 to 100)

  protected def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
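    // Recursively walk the dependency graph, collecting the dependencies of this RDD
    // and of all of its ancestors.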
    def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
      rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
        getAllDependencies(dep.rdd)
      }
    }
    val rdd = newShuffleRDD()

    // Get all the shuffle dependencies
    val shuffleDeps = getAllDependencies(rdd)
      .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
      .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
    (rdd, shuffleDeps)
  }

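  /**
   * Create a random RDD (plain, shuffled, or joined), optionally persist it, and materialize
   * it with an action so that its blocks and shuffle output actually exist.
   */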
  protected def randomRdd() = {
    val rdd: RDD[_] = Random.nextInt(3) match {
      case 0 => newRDD()
      case 1 => newShuffleRDD()
      case 2 => newPairRDD().join(newPairRDD())
    }
    if (Random.nextBoolean()) rdd.persist()
    rdd.count()
    rdd
  }

  /** Run GC and make sure it actually has run */
  protected def runGC(): Unit = {
    val weakRef = new WeakReference(new Object())
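    // Once the referent has been collected, weakRef.get returns null, which proves that
    // at least one full GC cycle has completed.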
    val startTime = System.currentTimeMillis
    System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
    // Wait until a weak reference object has been GCed
    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
      System.gc()
      Thread.sleep(200)
    }
  }

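  /** Shorthand for the ContextCleaner attached to the active SparkContext. */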
  protected def cleaner = sc.cleaner.get
}


/**
 * Basic ContextCleanerSuite, which uses sort-based shuffle
 */
class ContextCleanerSuite extends ContextCleanerSuiteBase {
  test("cleanup RDD") {
    val rdd = newRDD().persist()
    val collected = rdd.collect().toList
    val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))

    // Explicit cleanup
    cleaner.doCleanupRDD(rdd.id, blocking = true)
    tester.assertCleanup()

    // Verify that RDDs can be re-executed after cleaning up
    assert(rdd.collect().toList === collected)
  }

  test("cleanup shuffle") {
    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
    val collected = rdd.collect().toList
    val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))

    // Explicit cleanup
    shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
    tester.assertCleanup()

    // Verify that shuffles can be re-executed after cleaning up
    assert(rdd.collect().toList === collected)
  }

  test("cleanup broadcast") {
    val broadcast = newBroadcast()
    val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))

    // Explicit cleanup
    cleaner.doCleanupBroadcast(broadcast.id, blocking = true)
    tester.assertCleanup()
  }

  test("automatically cleanup RDD") {
    var rdd = newRDD().persist()
    rdd.count()

    // Test that GC does not cause RDD cleanup due to a strong reference
    val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
    runGC()
    intercept[Exception] {
      preGCTester.assertCleanup()(timeout(1000.millis))
    }

    // Test that GC causes RDD cleanup after dereferencing the RDD
    // Note rdd is used after previous GC to avoid early collection by the JVM
    val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
    rdd = null // Make RDD out of scope
    runGC()
    postGCTester.assertCleanup()
  }

  test("automatically cleanup shuffle") {
    var rdd = newShuffleRDD()
    rdd.count()

    // Test that GC does not cause shuffle cleanup due to a strong reference
    val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
    runGC()
    intercept[Exception] {
      preGCTester.assertCleanup()(timeout(1000.millis))
    }
    rdd.count()  // Defeat early collection by the JVM

    // Test that GC causes shuffle cleanup after dereferencing the RDD
    val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
    rdd = null  // Make RDD out of scope, so that corresponding shuffle goes out of scope
    runGC()
    postGCTester.assertCleanup()
  }

  test("automatically cleanup broadcast") {
    var broadcast = newBroadcast()

    // Test that GC does not cause broadcast cleanup due to a strong reference
    val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
    runGC()
    intercept[Exception] {
      preGCTester.assertCleanup()(timeout(1000.millis))
    }

    // Test that GC causes broadcast cleanup after dereferencing the broadcast variable
    // Note broadcast is used after previous GC to avoid early collection by the JVM
    val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
    broadcast = null  // Make broadcast variable out of scope
    runGC()
    postGCTester.assertCleanup()
  }

  test("automatically cleanup normal checkpoint") {
    val checkpointDir = Utils.createTempDir()
    checkpointDir.delete()
    var rdd = newPairRDD()
    sc.setCheckpointDir(checkpointDir.toString)
    rdd.checkpoint()
    rdd.cache()
    rdd.collect()
    var rddId = rdd.id

    // Confirm the checkpoint directory exists
    assert(ReliableRDDCheckpointData.checkpointPath(sc, rddId).isDefined)
    val path = ReliableRDDCheckpointData.checkpointPath(sc, rddId).get
    val fs = path.getFileSystem(sc.hadoopConfiguration)
    assert(fs.exists(path))

    // Test that GC causes checkpoint data cleanup after dereferencing the RDD. The base
    // suite config enables spark.cleaner.referenceTracking.cleanCheckpoints, so the
    // checkpoint files should be deleted.
    var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId))
    rdd = null // Make RDD out of scope, ok if collected earlier
    runGC()
    postGCTester.assertCleanup()
    assert(!fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))

    // Verify that checkpoints are NOT cleaned up if the config is not enabled
    sc.stop()
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("cleanupCheckpoint")
      .set("spark.cleaner.referenceTracking.cleanCheckpoints", "false")
    sc = new SparkContext(conf)
    rdd = newPairRDD()
    sc.setCheckpointDir(checkpointDir.toString)
    rdd.checkpoint()
    rdd.cache()
    rdd.collect()
    rddId = rdd.id

    // Confirm the checkpoint directory exists
    assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))

    // Reference rdd to defeat any early collection by the JVM
    rdd.count()

    // Test that GC does NOT cause checkpoint data cleanup, since the config is disabled:
    // the RDD itself is cleaned up, but the checkpoint files must remain on disk
    postGCTester = new CleanerTester(sc, Seq(rddId))
    rdd = null // Make RDD out of scope
    runGC()
    postGCTester.assertCleanup()
    assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))
  }

  test("automatically clean up local checkpoint") {
    // Note that this test is similar to the RDD cleanup
    // test because the same underlying mechanism is used!
    var rdd = newPairRDD().localCheckpoint()
    assert(rdd.checkpointData.isDefined)
    assert(rdd.checkpointData.get.checkpointRDD.isEmpty)
    rdd.count()
    assert(rdd.checkpointData.get.checkpointRDD.isDefined)

    // Test that GC does not cause checkpoint cleanup due to a strong reference
    val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
    runGC()
    intercept[Exception] {
      preGCTester.assertCleanup()(timeout(1000.millis))
    }

    // Test that RDD going out of scope does cause the checkpoint blocks to be cleaned up
    val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
    rdd = null
    runGC()
    postGCTester.assertCleanup()
  }

  test("automatically cleanup RDD + shuffle + broadcast") {
    val numRdds = 100
    val numBroadcasts = 4 // Broadcasts are more costly
    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
    val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer
    val rddIds = sc.persistentRdds.keys.toSeq
    val shuffleIds = 0 until sc.newShuffleId
    val broadcastIds = broadcastBuffer.map(_.id)

    val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
    runGC()
    intercept[Exception] {
      preGCTester.assertCleanup()(timeout(1000.millis))
    }

    // Test that GC triggers the cleanup of all variables after dereferencing them
    val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
    broadcastBuffer.clear()
    rddBuffer.clear()
    runGC()
    postGCTester.assertCleanup()

    // Make sure the broadcasted task closure no longer exists after GC.
    val taskClosureBroadcastId = broadcastIds.max + 1
    assert(sc.env.blockManager.master.getMatchingBlockIds({
      case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
      case _ => false
    }, askSlaves = true).isEmpty)
  }

  test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
    sc.stop()

    val conf2 = new SparkConf()
      .setMaster("local-cluster[2, 1, 1024]")
      .setAppName("ContextCleanerSuite")
      .set("spark.cleaner.referenceTracking.blocking", "true")
      .set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
      .set("spark.shuffle.manager", shuffleManager.getName)
    sc = new SparkContext(conf2)

    val numRdds = 10
    val numBroadcasts = 4 // Broadcasts are more costly
    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
    val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer
    val rddIds = sc.persistentRdds.keys.toSeq
    val shuffleIds = 0 until sc.newShuffleId
    val broadcastIds = broadcastBuffer.map(_.id)

    val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
    runGC()
    intercept[Exception] {
      preGCTester.assertCleanup()(timeout(1000.millis))
    }

    // Test that GC triggers the cleanup of all variables after dereferencing them
    val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
    broadcastBuffer.clear()
    rddBuffer.clear()
    runGC()
    postGCTester.assertCleanup()

    // Make sure the broadcasted task closure no longer exists after GC.
    val taskClosureBroadcastId = broadcastIds.max + 1
    assert(sc.env.blockManager.master.getMatchingBlockIds({
      case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
      case _ => false
    }, askSlaves = true).isEmpty)
  }
}


/**
 * Class to test whether RDDs, shuffles, etc. have been successfully cleaned.
 * The checkpoint here refers only to normal (reliable) checkpoints, not local checkpoints.
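 *
 * Typical usage, as a sketch based on the tests in this file:
 * {{{
 *   val rdd = sc.makeRDD(1 to 10).persist()
 *   rdd.count()
 *   val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
 *   sc.cleaner.get.doCleanupRDD(rdd.id, blocking = true)
 *   tester.assertCleanup()
 * }}}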
 */
class CleanerTester(
    sc: SparkContext,
    rddIds: Seq[Int] = Seq.empty,
    shuffleIds: Seq[Int] = Seq.empty,
    broadcastIds: Seq[Long] = Seq.empty,
    checkpointIds: Seq[Long] = Seq.empty)
  extends Logging {

  val toBeCleanedRDDIds = new HashSet[Int] ++= rddIds
  val toBeCleanedShuffleIds = new HashSet[Int] ++= shuffleIds
  val toBeCleanedBroadcastIds = new HashSet[Long] ++= broadcastIds
  val toBeCheckpointIds = new HashSet[Long] ++= checkpointIds
  val isDistributed = !sc.isLocal

  val cleanerListener = new CleanerListener {
    def rddCleaned(rddId: Int): Unit = {
      toBeCleanedRDDIds.synchronized { toBeCleanedRDDIds -= rddId }
      logInfo("RDD " + rddId + " cleaned")
    }

    def shuffleCleaned(shuffleId: Int): Unit = {
      toBeCleanedShuffleIds.synchronized { toBeCleanedShuffleIds -= shuffleId }
      logInfo("Shuffle " + shuffleId + " cleaned")
    }

    def broadcastCleaned(broadcastId: Long): Unit = {
      toBeCleanedBroadcastIds.synchronized { toBeCleanedBroadcastIds -= broadcastId }
      logInfo("Broadcast " + broadcastId + " cleaned")
    }

    def accumCleaned(accId: Long): Unit = {
      logInfo("Accumulator " + accId + " cleaned")
    }

    def checkpointCleaned(rddId: Long): Unit = {
      toBeCheckpointIds.synchronized { toBeCheckpointIds -= rddId }
      logInfo("Checkpoint " + rddId + " cleaned")
    }
  }

  val MAX_VALIDATION_ATTEMPTS = 10
  val VALIDATION_ATTEMPT_INTERVAL = 100

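  // Validate the initial state and register for cleanup events as soon as the tester
  // is constructed, so that no cleanup notification can be missed.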
  logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString)
  preCleanupValidate()
  sc.cleaner.get.attachListener(cleanerListener)

  /** Assert that all the tracked resources have been cleaned up */
  def assertCleanup()(implicit waitTimeout: PatienceConfiguration.Timeout): Unit = {
    try {
      eventually(waitTimeout, interval(100.millis)) {
        assert(isAllCleanedUp,
          "The following resources were not cleaned up:\n" + uncleanedResourcesToString)
      }
      postCleanupValidate()
    } finally {
      logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString)
    }
  }

  /** Verify that RDDs, shuffles, etc. occupy resources */
  private def preCleanupValidate(): Unit = {
    assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty ||
      checkpointIds.nonEmpty, "Nothing to clean up")

    // Verify the RDDs have been persisted and blocks are present
    rddIds.foreach { rddId =>
      assert(
        sc.persistentRdds.contains(rddId),
        "RDD " + rddId + " has not been persisted, cannot start cleaner test"
      )

      assert(
        getRDDBlocks(rddId).nonEmpty,
        "Blocks of RDD " + rddId + " cannot be found in block manager, " +
          "cannot start cleaner test"
      )
    }

    // Verify the shuffle ids are registered and blocks are present
    shuffleIds.foreach { shuffleId =>
      assert(
        mapOutputTrackerMaster.containsShuffle(shuffleId),
        "Shuffle " + shuffleId + " has not been registered, cannot start cleaner test"
      )

      assert(
        getShuffleBlocks(shuffleId).nonEmpty,
        "Blocks of shuffle " + shuffleId + " cannot be found in block manager, " +
          "cannot start cleaner test"
      )
    }

    // Verify that the broadcast blocks are present
    broadcastIds.foreach { broadcastId =>
      assert(
        getBroadcastBlocks(broadcastId).nonEmpty,
        "Blocks of broadcast " + broadcastId + " cannot be found in block manager, " +
          "cannot start cleaner test"
      )
    }
  }

  /**
   * Verify that RDDs, shuffles, etc. no longer occupy resources. Retried multiple times,
   * as there is no guarantee of how long it will take to clean up the resources.
   */
  private def postCleanupValidate(): Unit = {
    // Verify the RDDs have been unpersisted and their blocks removed
    rddIds.foreach { rddId =>
      assert(
        !sc.persistentRdds.contains(rddId),
        "RDD " + rddId + " was not cleared from sc.persistentRdds"
      )

      assert(
        getRDDBlocks(rddId).isEmpty,
        "Blocks of RDD " + rddId + " were not cleared from block manager"
      )
    }

    // Verify the shuffle ids have been deregistered and their blocks removed
    shuffleIds.foreach { shuffleId =>
      assert(
        !mapOutputTrackerMaster.containsShuffle(shuffleId),
        "Shuffle " + shuffleId + " was not deregistered from map output tracker"
      )

      assert(
        getShuffleBlocks(shuffleId).isEmpty,
        "Blocks of shuffle " + shuffleId + " were not cleared from block manager"
      )
    }

    // Verify that the broadcast blocks have been removed
    broadcastIds.foreach { broadcastId =>
      assert(
        getBroadcastBlocks(broadcastId).isEmpty,
        "Blocks of broadcast " + broadcastId + " were not cleared from block manager"
      )
    }
  }

  private def uncleanedResourcesToString = {
    val s1 = toBeCleanedRDDIds.synchronized {
      toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")
    }
    val s2 = toBeCleanedShuffleIds.synchronized {
      toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")
    }
    val s3 = toBeCleanedBroadcastIds.synchronized {
      toBeCleanedBroadcastIds.toSeq.sorted.mkString("[", ", ", "]")
    }
    val s4 = toBeCheckpointIds.synchronized {
      toBeCheckpointIds.toSeq.sorted.mkString("[", ", ", "]")
    }
    s"""
       |\tRDDs = $s1
       |\tShuffles = $s2
       |\tBroadcasts = $s3
       |\tCheckpoints = $s4
    """.stripMargin
  }

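  /** Whether every tracked RDD, shuffle, broadcast, and checkpoint has been reported clean. */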
  private def isAllCleanedUp =
    toBeCleanedRDDIds.synchronized { toBeCleanedRDDIds.isEmpty } &&
    toBeCleanedShuffleIds.synchronized { toBeCleanedShuffleIds.isEmpty } &&
    toBeCleanedBroadcastIds.synchronized { toBeCleanedBroadcastIds.isEmpty } &&
    toBeCheckpointIds.synchronized { toBeCheckpointIds.isEmpty }

  private def getRDDBlocks(rddId: Int): Seq[BlockId] = {
    blockManager.master.getMatchingBlockIds({
      case RDDBlockId(`rddId`, _) => true
      case _ => false
    }, askSlaves = true)
  }

  private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
    blockManager.master.getMatchingBlockIds({
      case ShuffleBlockId(`shuffleId`, _, _) => true
      case ShuffleIndexBlockId(`shuffleId`, _, _) => true
      case _ => false
    }, askSlaves = true)
  }

  private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = {
    blockManager.master.getMatchingBlockIds({
      case BroadcastBlockId(`broadcastId`, _) => true
      case _ => false
    }, askSlaves = true)
  }

  private def blockManager = sc.env.blockManager
  private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
}