/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.util.concurrent.Semaphore

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.Future

import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
import org.apache.spark.util.ThreadUtils

/**
 * Test suite for cancelling running jobs. We test cancellation of jobs launched by a single-job
 * action (e.g. count) as well as a multi-job action (e.g. take), using both the local and
 * cluster schedulers in FIFO and fair scheduling modes.
 */
class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAfter
  with LocalSparkContext {

  override def afterEach() {
    try {
      resetSparkContext()
    } finally {
      super.afterEach()
    }
  }

  test("local mode, FIFO scheduler") {
    val conf = new SparkConf().set("spark.scheduler.mode", "FIFO")
    sc = new SparkContext("local[2]", "test", conf)
    testCount()
    testTake()
    // Make sure we can still launch tasks.
    assert(sc.parallelize(1 to 10, 2).count === 10)
  }

  test("local mode, fair scheduler") {
    val conf = new SparkConf().set("spark.scheduler.mode", "FAIR")
    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
    conf.set("spark.scheduler.allocation.file", xmlPath)
    sc = new SparkContext("local[2]", "test", conf)
    testCount()
    testTake()
    // Make sure we can still launch tasks.
    assert(sc.parallelize(1 to 10, 2).count === 10)
  }

  test("cluster mode, FIFO scheduler") {
    val conf = new SparkConf().set("spark.scheduler.mode", "FIFO")
    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
    testCount()
    testTake()
    // Make sure we can still launch tasks.
    assert(sc.parallelize(1 to 10, 2).count === 10)
  }

  test("cluster mode, fair scheduler") {
    val conf = new SparkConf().set("spark.scheduler.mode", "FAIR")
    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
    conf.set("spark.scheduler.allocation.file", xmlPath)
    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
    testCount()
    testTake()
    // Make sure we can still launch tasks.
    assert(sc.parallelize(1 to 10, 2).count === 10)
  }

  test("do not put partially executed partitions into cache") {
    // In this test case, we create a scenario in which a partition is only partially executed,
    // and make sure CacheManager does not put that partially executed partition into the
    // BlockManager.
    import JobCancellationSuite._
    sc = new SparkContext("local", "test")

    // Process elements 1 to 10 normally, then block and wait for the task to be killed.
    val rdd = sc.parallelize(1 to 1000, 2).map { x =>
      if (x > 10) {
        taskStartedSemaphore.release()
        taskCancelledSemaphore.acquire()
      }
      x
    }.cache()

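    // Counting rdd1 below computes rdd's partitions through the caching code path.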
    val rdd1 = rdd.map(x => x)

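    // Once a task signals that it has started, cancel all jobs. Releasing plenty of permits
    // ensures that neither the cancelled tasks nor the later rdd.count() block on the semaphore.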
    Future {
      taskStartedSemaphore.acquire()
      sc.cancelAllJobs()
      taskCancelledSemaphore.release(100000)
    }

    intercept[SparkException] { rdd1.count() }
    // If a partially computed partition had been cached, rdd.count() would return less than 1000.
    assert(rdd.count() === 1000)
  }

  test("job group") {
    sc = new SparkContext("local[2]", "test")

    // Add a listener to release the semaphore once any tasks are launched.
    val sem = new Semaphore(0)
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart) {
        sem.release()
      }
    })

    // jobA is the one to be cancelled.
    val jobA = Future {
      sc.setJobGroup("jobA", "this is a job to be cancelled")
      sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
    }

    // Block until both tasks of job A have started and cancel job A.
    sem.acquire(2)

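    // jobB runs outside the job group, so it should not be affected by the cancellation below.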
    sc.clearJobGroup()
    val jobB = sc.parallelize(1 to 100, 2).countAsync()
    sc.cancelJobGroup("jobA")
    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, Duration.Inf) }.getCause
    assert(e.getMessage contains "cancel")

    // Once A is cancelled, job B should finish fairly quickly.
    assert(jobB.get() === 100)
  }

  test("inherited job group (SPARK-6629)") {
    sc = new SparkContext("local[2]", "test")

    // Add a listener to release the semaphore once any tasks are launched.
    val sem = new Semaphore(0)
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart) {
        sem.release()
      }
    })

    sc.setJobGroup("jobA", "this is a job to be cancelled")
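    // Written by the spawned job thread below and read by the test thread, hence @volatile.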
    @volatile var exception: Exception = null
    val jobA = new Thread() {
      // The job group should be inherited by this thread
      override def run(): Unit = {
        exception = intercept[SparkException] {
          sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
        }
      }
    }
    jobA.start()

    // Block until both tasks of job A have started and cancel job A.
    sem.acquire(2)
    sc.cancelJobGroup("jobA")
    jobA.join(10000)
    assert(!jobA.isAlive)
    assert(exception.getMessage contains "cancel")

    // Once A is cancelled, job B should finish fairly quickly.
    val jobB = sc.parallelize(1 to 100, 2).countAsync()
    assert(jobB.get() === 100)
  }

  test("job group with interruption") {
    sc = new SparkContext("local[2]", "test")

    // Add a listener to release the semaphore once any tasks are launched.
    val sem = new Semaphore(0)
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart) {
        sem.release()
      }
    })

    // jobA is the one to be cancelled.
    val jobA = Future {
      sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
      sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100000); i }.count()
    }

    // Block until both tasks of job A have started and cancel job A.
    sem.acquire(2)

    sc.clearJobGroup()
    val jobB = sc.parallelize(1 to 100, 2).countAsync()
    sc.cancelJobGroup("jobA")
    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 5.seconds) }.getCause
    assert(e.getMessage contains "cancel")

    // Once A is cancelled, job B should finish fairly quickly.
    assert(jobB.get() === 100)
  }

  test("task reaper kills JVM if killed tasks keep running for too long") {
    val conf = new SparkConf()
      .set("spark.task.reaper.enabled", "true")
      .set("spark.task.reaper.killTimeout", "5s")
    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)

    // Add a listener to release the semaphore once any tasks are launched.
    val sem = new Semaphore(0)
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart) {
        sem.release()
      }
    })

    // jobA is the one to be cancelled.
    val jobA = Future {
      sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
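      // The tasks spin forever and ignore interruption, so the reaper is expected to kill the
      // executor JVM once killTimeout (5s) expires.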
      sc.parallelize(1 to 10000, 2).map { i =>
        while (true) { }
      }.count()
    }

    // Block until both tasks of job A have started and cancel job A.
    sem.acquire(2)
    // Small delay to ensure tasks actually start executing the task body
    Thread.sleep(1000)

    sc.clearJobGroup()
    val jobB = sc.parallelize(1 to 100, 2).countAsync()
    sc.cancelJobGroup("jobA")
    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause
    assert(e.getMessage contains "cancel")

    // Once A is cancelled, job B should finish fairly quickly.
    assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
  }

  test("task reaper will not kill JVM if spark.task.reaper.killTimeout == -1") {
    val conf = new SparkConf()
      .set("spark.task.reaper.enabled", "true")
      .set("spark.task.reaper.killTimeout", "-1")
      .set("spark.task.reaper.pollingInterval", "1s")
      .set("spark.deploy.maxExecutorRetries", "1")
    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)

    // Add a listener to release the semaphore once any tasks are launched.
    val sem = new Semaphore(0)
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart) {
        sem.release()
      }
    })

    // jobA is the one to be cancelled.
    val jobA = Future {
      sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
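      // Each task busy-waits for ~10 seconds and ignores interruption; with killTimeout set to
      // -1 the reaper should not kill the executor JVM.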
      sc.parallelize(1 to 2, 2).map { i =>
        val startTime = System.currentTimeMillis()
        while (System.currentTimeMillis() < startTime + 10000) { }
      }.count()
    }

    // Block until both tasks of job A have started and cancel job A.
    sem.acquire(2)
    // Small delay to ensure tasks actually start executing the task body
    Thread.sleep(1000)

    sc.clearJobGroup()
    val jobB = sc.parallelize(1 to 100, 2).countAsync()
    sc.cancelJobGroup("jobA")
    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause
    assert(e.getMessage contains "cancel")

    // Once A is cancelled, job B should finish fairly quickly.
    assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
  }

  test("two jobs sharing the same stage") {
    // sem1: make sure cancel is issued after some tasks are launched
    // twoJobsSharingStageSemaphore:
    //   make sure the first stage is not finished until cancel is issued
    val sem1 = new Semaphore(0)

    sc = new SparkContext("local[2]", "test")
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart) {
        sem1.release()
      }
    })

    // Create two actions that would share the same stages.
    val rdd = sc.parallelize(1 to 10, 2).map { i =>
      JobCancellationSuite.twoJobsSharingStageSemaphore.acquire()
      (i, i)
    }.reduceByKey(_ + _)
    val f1 = rdd.collectAsync()
    val f2 = rdd.countAsync()

    // Kill one of the actions.
    Future {
      sem1.acquire()
      f1.cancel()
      JobCancellationSuite.twoJobsSharingStageSemaphore.release(10)
    }

    // Expect f1 to fail due to cancellation,
    intercept[SparkException] { f1.get() }
    // but f2 should not be affected
    f2.get()
  }

  def testCount() {
    // Cancel before launching any tasks
    {
      val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
      Future { f.cancel() }
      val e = intercept[SparkException] { f.get() }.getCause
      assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
    }

    // Cancel after some tasks have been launched
    {
      // Add a listener to release the semaphore once any tasks are launched.
      val sem = new Semaphore(0)
      sc.addSparkListener(new SparkListener {
        override def onTaskStart(taskStart: SparkListenerTaskStart) {
          sem.release()
        }
      })

      val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
      Future {
        // Wait until some tasks have been launched before we cancel the job.
        sem.acquire()
        f.cancel()
      }
      val e = intercept[SparkException] { f.get() }.getCause
      assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
    }
  }

  def testTake() {
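    // takeAsync(5000) may launch more than one job as take scans additional partitions, so this
    // also exercises cancellation of a multi-job action.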
    // Cancel before launching any tasks
    {
      val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
      Future { f.cancel() }
      val e = intercept[SparkException] { f.get() }.getCause
      assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
    }

    // Cancel after some tasks have been launched
    {
      // Add a listener to release the semaphore once any tasks are launched.
      val sem = new Semaphore(0)
      sc.addSparkListener(new SparkListener {
        override def onTaskStart(taskStart: SparkListenerTaskStart) {
          sem.release()
        }
      })
      val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
      Future {
        sem.acquire()
        f.cancel()
      }
      val e = intercept[SparkException] { f.get() }.getCause
      assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
    }
  }
}


object JobCancellationSuite {
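  // These tests run their tasks in the same JVM as the driver (local / local[2] masters), so the
  // semaphores below are shared between the driver thread and the task threads.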
  val taskStartedSemaphore = new Semaphore(0)
  val taskCancelledSemaphore = new Semaphore(0)
  val twoJobsSharingStageSemaphore = new Semaphore(0)
}