/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.scheduler

import java.util.concurrent.Semaphore

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.scalatest.Matchers

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{ResetSystemProperties, RpcUtils}

class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers
  with ResetSystemProperties {

  /** Length of time to wait while draining listener events. */
  val WAIT_TIMEOUT_MILLIS = 10000

  val jobCompletionTime = 1421191296660L

  test("don't call sc.stop in listener") {
    sc = new SparkContext("local", "SparkListenerSuite", new SparkConf())
    val listener = new SparkContextStoppingListener(sc)
    val bus = new LiveListenerBus(sc)
    bus.addListener(listener)

    // Start the bus and post a job end event so the listener attempts to stop the context
    bus.start()
    bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
    bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

    bus.stop()
    assert(listener.sparkExSeen)
  }

  test("basic creation and shutdown of LiveListenerBus") {
    sc = new SparkContext("local", "SparkListenerSuite", new SparkConf())
    val counter = new BasicJobCounter
    val bus = new LiveListenerBus(sc)
    bus.addListener(counter)

    // Listener bus hasn't started yet, so posting events should not increment counter
    (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
    assert(counter.count === 0)

    // Starting listener bus should flush all buffered events
    bus.start()
    bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    assert(counter.count === 5)

    // After listener bus has stopped, posting events should not increment counter
    bus.stop()
    (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
    assert(counter.count === 5)

    // Listener bus must not be started twice
    intercept[IllegalStateException] {
      val bus = new LiveListenerBus(sc)
      bus.start()
      bus.start()
    }

    // ... or stopped before starting
    intercept[IllegalStateException] {
      val bus = new LiveListenerBus(sc)
      bus.stop()
    }
  }

  test("bus.stop() waits for the event queue to completely drain") {
    @volatile var drained = false

    // Released when the listener has started
    val listenerStarted = new Semaphore(0)

    // Tells the listener to stop blocking
    val listenerWait = new Semaphore(0)

    // Released when the stopper has started
    val stopperStarted = new Semaphore(0)

    // Released when the stopper has returned
    val stopperReturned = new Semaphore(0)

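    // A listener that blocks inside onJobEnd until listenerWait is released, so the event
    // queue cannot drain while it is waiting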
    class BlockingListener extends SparkListener {
      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
        listenerStarted.release()
        listenerWait.acquire()
        drained = true
      }
    }
    sc = new SparkContext("local", "SparkListenerSuite", new SparkConf())
    val bus = new LiveListenerBus(sc)
    val blockingListener = new BlockingListener

    bus.addListener(blockingListener)
    bus.start()
    bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))

    listenerStarted.acquire()
    // Listener should be blocked after start
    assert(!drained)

    new Thread("ListenerBusStopper") {
      override def run() {
        stopperStarted.release()
        // stop() will block until listenerWait is released below and the queue drains
        bus.stop()
        stopperReturned.release()
      }
    }.start()

    stopperStarted.acquire()
    // Listener should remain blocked after stopper started
    assert(!drained)

    // unblock Listener to let queue drain
    listenerWait.release()
    stopperReturned.acquire()
    assert(drained)
  }

  test("basic creation of StageInfo") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveStageAndTaskInfo
    sc.addSparkListener(listener)
    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
    rdd2.setName("Target RDD")
    rdd2.count()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

    listener.stageInfos.size should be {1}
    val (stageInfo, taskInfoMetrics) = listener.stageInfos.head
    stageInfo.rddInfos.size should be {2}
    stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
    stageInfo.rddInfos.exists(_.name == "Target RDD") should be {true}
    stageInfo.numTasks should be {4}
    stageInfo.submissionTime should be ('defined)
    stageInfo.completionTime should be ('defined)
    taskInfoMetrics.length should be {4}
  }

  test("basic creation of StageInfo with shuffle") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveStageAndTaskInfo
    sc.addSparkListener(listener)
    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.filter(_ % 2 == 0).map(i => (i, i))
    val rdd3 = rdd2.reduceByKey(_ + _)
    rdd1.setName("Un")
    rdd2.setName("Deux")
    rdd3.setName("Trois")

    rdd1.count()
    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be {1}
    val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get
    stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD
    stageInfo1.rddInfos.forall(_.numPartitions == 4) should be {true}
    stageInfo1.rddInfos.exists(_.name == "Un") should be {true}
    listener.stageInfos.clear()

    rdd2.count()
    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be {1}
    val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get
    stageInfo2.rddInfos.size should be {3} // ParallelCollectionRDD, FilteredRDD, MappedRDD
    stageInfo2.rddInfos.forall(_.numPartitions == 4) should be {true}
    stageInfo2.rddInfos.exists(_.name == "Deux") should be {true}
    listener.stageInfos.clear()

    rdd3.count()
    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be {2} // Shuffle map stage + result stage
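    // Stage IDs keep increasing across the jobs in this test, so this job's shuffle map stage
    // and result stage get IDs 2 and 3; the result stage is the one backed by the ShuffledRDD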
    val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get
    stageInfo3.rddInfos.size should be {1} // ShuffledRDD
    stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true}
    stageInfo3.rddInfos.exists(_.name == "Trois") should be {true}
  }

  test("StageInfo with fewer tasks than partitions") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveStageAndTaskInfo
    sc.addSparkListener(listener)
    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
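    // Run the job on only the first two of the four partitions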
    sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1))

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

    listener.stageInfos.size should be {1}
    val (stageInfo, _) = listener.stageInfos.head
    stageInfo.numTasks should be {2}
    stageInfo.rddInfos.size should be {2}
    stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
  }

  test("local metrics") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveStageAndTaskInfo
    sc.addSparkListener(listener)
    sc.addSparkListener(new StatsReportListener)
    // just to make sure some of the tasks take a noticeable amount of time
    val w = { i: Int =>
      if (i == 0) {
        Thread.sleep(100)
      }
      i
    }

    val numSlices = 16
    val d = sc.parallelize(0 to 1e3.toInt, numSlices).map(w)
    d.count()
    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be (1)

    val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1")
    val d3 = d.map { i => w(i) -> (0 to (i % 5)) }.setName("shuffle input 2")
    val d4 = d2.cogroup(d3, numSlices).map { case (k, (v1, v2)) =>
      w(k) -> (v1.size, v2.size)
    }
    d4.setName("A Cogroup")
    d4.collectAsMap()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
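    // One stage from the earlier count() plus three from the cogroup job
    // (two shuffle map stages and one result stage)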
    listener.stageInfos.size should be (4)
    listener.stageInfos.foreach { case (stageInfo, taskInfoMetrics) =>
      /**
       * Small test, so some tasks might take less than 1 millisecond, but average should be
       * greater than 0 ms.
       */
      checkNonZeroAvg(
        taskInfoMetrics.map(_._2.executorRunTime),
        stageInfo + " executorRunTime")
      checkNonZeroAvg(
        taskInfoMetrics.map(_._2.executorDeserializeTime),
        stageInfo + " executorDeserializeTime")

      /* Test is disabled (SEE SPARK-2208)
      if (stageInfo.rddInfos.exists(_.name == d4.name)) {
        checkNonZeroAvg(
          taskInfoMetrics.map(_._2.shuffleReadMetrics.get.fetchWaitTime),
          stageInfo + " fetchWaitTime")
      }
      */

      taskInfoMetrics.foreach { case (taskInfo, taskMetrics) =>
        taskMetrics.resultSize should be > (0L)
        if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) {
          assert(taskMetrics.shuffleWriteMetrics.bytesWritten > 0L)
        }
        if (stageInfo.rddInfos.exists(_.name == d4.name)) {
          assert(taskMetrics.shuffleReadMetrics.totalBlocksFetched == 2 * numSlices)
          assert(taskMetrics.shuffleReadMetrics.localBlocksFetched == 2 * numSlices)
          assert(taskMetrics.shuffleReadMetrics.remoteBlocksFetched == 0)
          assert(taskMetrics.shuffleReadMetrics.remoteBytesRead == 0L)
        }
      }
    }
  }

  test("onTaskGettingResult() called when result fetched remotely") {
    val conf = new SparkConf().set("spark.rpc.message.maxSize", "1")
    sc = new SparkContext("local", "SparkListenerSuite", conf)
    val listener = new SaveTaskEvents
    sc.addSparkListener(listener)

    // Make a task whose result is larger than the RPC message size
    val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
    assert(maxRpcMessageSize === 1024 * 1024)
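    // spark.rpc.message.maxSize is specified in MB, so "1" translates to 1024 * 1024 bytes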
    val result = sc.parallelize(Seq(1), 1)
      .map { x => 1.to(maxRpcMessageSize).toArray }
      .reduce { case (x, y) => x }
    assert(result === 1.to(maxRpcMessageSize).toArray)

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    val TASK_INDEX = 0
    assert(listener.startedTasks.contains(TASK_INDEX))
    assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
    assert(listener.endedTasks.contains(TASK_INDEX))
  }

  test("onTaskGettingResult() not called when result sent directly") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveTaskEvents
    sc.addSparkListener(listener)

    // Make a task whose result is small enough to be sent back directly
    val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x }
    assert(result === 2)

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    val TASK_INDEX = 0
    assert(listener.startedTasks.contains(TASK_INDEX))
    assert(listener.startedGettingResultTasks.isEmpty)
    assert(listener.endedTasks.contains(TASK_INDEX))
  }

  test("onTaskEnd() should be called for all started tasks, even after job has been killed") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val WAIT_TIMEOUT_MILLIS = 10000
    val listener = new SaveTaskEvents
    sc.addSparkListener(listener)

    val numTasks = 10
    val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync()
    // Wait until one task has started (because we want to make sure that any tasks that are
    // started have corresponding end events sent to the listener).
    var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
    listener.synchronized {
      var remainingWait = finishTime - System.currentTimeMillis
      while (listener.startedTasks.isEmpty && remainingWait > 0) {
        listener.wait(remainingWait)
        remainingWait = finishTime - System.currentTimeMillis
      }
      assert(!listener.startedTasks.isEmpty)
    }

    f.cancel()

    // Ensure that onTaskEnd is called for all started tasks.
    finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
    listener.synchronized {
      var remainingWait = finishTime - System.currentTimeMillis
      while (listener.endedTasks.size < listener.startedTasks.size && remainingWait > 0) {
        listener.wait(remainingWait)
        remainingWait = finishTime - System.currentTimeMillis
      }
      assert(listener.endedTasks.size === listener.startedTasks.size)
    }
  }

  test("SparkListener moves on if a listener throws an exception") {
    val badListener = new BadListener
    val jobCounter1 = new BasicJobCounter
    val jobCounter2 = new BasicJobCounter
    sc = new SparkContext("local", "SparkListenerSuite", new SparkConf())
    val bus = new LiveListenerBus(sc)

    // Propagate events to bad listener first
    bus.addListener(badListener)
    bus.addListener(jobCounter1)
    bus.addListener(jobCounter2)
    bus.start()

    // Post events to all listeners, and wait until the queue is drained
    (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
    bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

    // The exception should be caught, and the event should be propagated to other listeners
    assert(bus.listenerThreadIsAlive)
    assert(jobCounter1.count === 5)
    assert(jobCounter2.count === 5)
  }

  test("registering listeners via spark.extraListeners") {
    val listeners = Seq(
      classOf[ListenerThatAcceptsSparkConf],
      classOf[FirehoseListenerThatAcceptsSparkConf],
      classOf[BasicJobCounter])
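    // Each of the classes above can be constructed either with a SparkConf or with no arguments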
    val conf = new SparkConf().setMaster("local").setAppName("test")
      .set("spark.extraListeners", listeners.map(_.getName).mkString(","))
    sc = new SparkContext(conf)
    sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1)
    sc.listenerBus.listeners.asScala
      .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1)
    sc.listenerBus.listeners.asScala
      .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1)
  }

  /**
   * Assert that the given list of numbers has an average that is greater than zero.
   */
  private def checkNonZeroAvg(m: Traversable[Long], msg: String) {
    assert(m.sum / m.size.toDouble > 0.0, msg)
  }

  /**
   * A simple listener that saves all task infos and task metrics.
   */
  private class SaveStageAndTaskInfo extends SparkListener {
    val stageInfos = mutable.Map[StageInfo, Seq[(TaskInfo, TaskMetrics)]]()
    var taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()

    override def onTaskEnd(task: SparkListenerTaskEnd) {
      val info = task.taskInfo
      val metrics = task.taskMetrics
      if (info != null && metrics != null) {
        taskInfoMetrics += ((info, metrics))
      }
    }

    override def onStageCompleted(stage: SparkListenerStageCompleted) {
      stageInfos(stage.stageInfo) = taskInfoMetrics
      taskInfoMetrics = mutable.Buffer.empty
    }
  }

  /**
   * A simple listener that saves the task indices for all task events.
   */
  private class SaveTaskEvents extends SparkListener {
    val startedTasks = new mutable.HashSet[Int]()
    val startedGettingResultTasks = new mutable.HashSet[Int]()
    val endedTasks = new mutable.HashSet[Int]()

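    // notify() wakes up any test thread that is waiting on this listener's monitor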
    override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
      startedTasks += taskStart.taskInfo.index
      notify()
    }

    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
      endedTasks += taskEnd.taskInfo.index
      notify()
    }

    override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
      startedGettingResultTasks += taskGettingResult.taskInfo.index
    }
  }

  /**
   * A simple listener that throws an exception on job end.
   */
  private class BadListener extends SparkListener {
    override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception }
  }

}

// These classes can't be declared inside of the SparkListenerSuite class because we don't want
// their constructors to contain references to SparkListenerSuite:

/**
 * A simple listener that counts the number of jobs observed.
 */
private class BasicJobCounter extends SparkListener {
  var count = 0
  override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1
}

/**
 * A simple listener that tries to stop SparkContext.
 */
private class SparkContextStoppingListener(val sc: SparkContext) extends SparkListener {
  @volatile var sparkExSeen = false
  override def onJobEnd(job: SparkListenerJobEnd): Unit = {
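    // Stopping the context from inside a listener is expected to fail with a SparkException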
    try {
      sc.stop()
    } catch {
      case se: SparkException =>
        sparkExSeen = true
    }
  }
}

private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener {
  var count = 0
  override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1
}

private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkFirehoseListener {
  var count = 0
  override def onEvent(event: SparkListenerEvent): Unit = event match {
    case job: SparkListenerJobEnd => count += 1
    case _ =>
  }
}