/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.scheduler

import java.util.concurrent.Semaphore

import scala.collection.mutable
import scala.collection.JavaConverters._

import org.scalatest.Matchers

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{ResetSystemProperties, RpcUtils}

/**
 * Tests for the listener bus (LiveListenerBus) and for the events delivered to SparkListeners
 * while jobs run on a local SparkContext.
 */
class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers
  with ResetSystemProperties {

  /** Length of time to wait while draining listener events. */
  val WAIT_TIMEOUT_MILLIS = 10000

  /** Fixed timestamp used as the completion time of every hand-posted SparkListenerJobEnd. */
  val jobCompletionTime = 1421191296660L

  test("don't call sc.stop in listener") {
    sc = new SparkContext("local", "SparkListenerSuite", new SparkConf())
    val listener = new SparkContextStoppingListener(sc)
    val bus = new LiveListenerBus(sc)
    bus.addListener(listener)

    // Deliver a single job-end event. The listener calls sc.stop() from inside its callback,
    // which this test expects to throw a SparkException (recorded in sparkExSeen).
    bus.start()
    bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
    bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

    bus.stop()
    assert(listener.sparkExSeen)
  }

  test("basic creation and shutdown of LiveListenerBus") {
    sc = new SparkContext("local", "SparkListenerSuite", new SparkConf())
    val counter = new BasicJobCounter
    val bus = new LiveListenerBus(sc)
    bus.addListener(counter)

    // Listener bus hasn't started yet, so posting events should not increment counter
    (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
    assert(counter.count === 0)

    // Starting listener bus should flush all buffered events
    bus.start()
    bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    assert(counter.count === 5)

    // After listener bus has stopped, posting events should not increment counter
    bus.stop()
    (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
    assert(counter.count === 5)

    // Listener bus must not be started twice
    intercept[IllegalStateException] {
      val bus = new LiveListenerBus(sc)
      bus.start()
      bus.start()
    }

    // ... or stopped before starting
    intercept[IllegalStateException] {
      val bus = new LiveListenerBus(sc)
      bus.stop()
    }
  }

  test("bus.stop() waits for the event queue to completely drain") {
    @volatile var drained = false

    // When Listener has started
    val listenerStarted = new Semaphore(0)

    // Tells the listener to stop blocking
    val listenerWait = new Semaphore(0)

    // When stopper has started
    val stopperStarted = new Semaphore(0)

    // When stopper has returned
    val stopperReturned = new Semaphore(0)

    class BlockingListener extends SparkListener {
      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
        listenerStarted.release()
        listenerWait.acquire()
        drained = true
      }
    }
    sc = new SparkContext("local", "SparkListenerSuite", new SparkConf())
    val bus = new LiveListenerBus(sc)
    val blockingListener = new BlockingListener

    bus.addListener(blockingListener)
    bus.start()
    bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))

    listenerStarted.acquire()
    // Listener should be blocked after start
    assert(!drained)

    new Thread("ListenerBusStopper") {
      override def run(): Unit = {
        stopperStarted.release()
        // stop() is expected to block until the listener is unblocked by
        // listenerWait.release() below and the queue has drained.
        bus.stop()
        stopperReturned.release()
      }
    }.start()

    stopperStarted.acquire()
    // Listener should remain blocked after stopper started
    assert(!drained)

    // unblock Listener to let queue drain
    listenerWait.release()
    stopperReturned.acquire()
    assert(drained)
  }

  test("basic creation of StageInfo") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveStageAndTaskInfo
    sc.addSparkListener(listener)
    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
    rdd2.setName("Target RDD")
    rdd2.count()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

    listener.stageInfos.size should be {1}
    val (stageInfo, taskInfoMetrics) = listener.stageInfos.head
    stageInfo.rddInfos.size should be {2}
    stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
    stageInfo.rddInfos.exists(_.name == "Target RDD") should be {true}
    stageInfo.numTasks should be {4}
    stageInfo.submissionTime should be ('defined)
    stageInfo.completionTime should be ('defined)
    taskInfoMetrics.length should be {4}
  }

  test("basic creation of StageInfo with shuffle") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveStageAndTaskInfo
    sc.addSparkListener(listener)
    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.filter(_ % 2 == 0).map(i => (i, i))
    val rdd3 = rdd2.reduceByKey(_ + _)
    rdd1.setName("Un")
    rdd2.setName("Deux")
    rdd3.setName("Trois")

    rdd1.count()
    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be {1}
    val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get
    stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD
    stageInfo1.rddInfos.forall(_.numPartitions == 4) should be {true}
    stageInfo1.rddInfos.exists(_.name == "Un") should be {true}
    listener.stageInfos.clear()

    rdd2.count()
    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be {1}
    val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get
    stageInfo2.rddInfos.size should be {3} // ParallelCollectionRDD, the filter RDD, the map RDD
    stageInfo2.rddInfos.forall(_.numPartitions == 4) should be {true}
    stageInfo2.rddInfos.exists(_.name == "Deux") should be {true}
    listener.stageInfos.clear()

    rdd3.count()
    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be {2} // Shuffle map stage + result stage
    val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get
    stageInfo3.rddInfos.size should be {1} // ShuffledRDD
    stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true}
    stageInfo3.rddInfos.exists(_.name == "Trois") should be {true}
  }

  test("StageInfo with fewer tasks than partitions") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveStageAndTaskInfo
    sc.addSparkListener(listener)
    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
    sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1))

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

    listener.stageInfos.size should be {1}
    val (stageInfo, _) = listener.stageInfos.head
    stageInfo.numTasks should be {2}
    stageInfo.rddInfos.size should be {2}
    stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
  }

  test("local metrics") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveStageAndTaskInfo
    sc.addSparkListener(listener)
    sc.addSparkListener(new StatsReportListener)
    // just to make sure some of the tasks take a noticeable amount of time
    val w = { i: Int =>
      if (i == 0) {
        Thread.sleep(100)
      }
      i
    }

    val numSlices = 16
    val d = sc.parallelize(0 to 1e3.toInt, numSlices).map(w)
    d.count()
    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be (1)

    val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1")
    val d3 = d.map { i => w(i) -> (0 to (i % 5)) }.setName("shuffle input 2")
    val d4 = d2.cogroup(d3, numSlices).map { case (k, (v1, v2)) =>
      w(k) -> (v1.size, v2.size)
    }
    d4.setName("A Cogroup")
    d4.collectAsMap()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.stageInfos.size should be (4)
    listener.stageInfos.foreach { case (stageInfo, taskInfoMetrics) =>
      // Small test, so some tasks might take less than 1 millisecond, but the average should be
      // greater than 0 ms.
      checkNonZeroAvg(
        taskInfoMetrics.map(_._2.executorRunTime),
        stageInfo + " executorRunTime")
      checkNonZeroAvg(
        taskInfoMetrics.map(_._2.executorDeserializeTime),
        stageInfo + " executorDeserializeTime")

      /* Test is disabled (SEE SPARK-2208)
      if (stageInfo.rddInfos.exists(_.name == d4.name)) {
        checkNonZeroAvg(
          taskInfoMetrics.map(_._2.shuffleReadMetrics.get.fetchWaitTime),
          stageInfo + " fetchWaitTime")
      }
      */

      taskInfoMetrics.foreach { case (_, taskMetrics) =>
        taskMetrics.resultSize should be > (0L)
        if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) {
          assert(taskMetrics.shuffleWriteMetrics.bytesWritten > 0L)
        }
        if (stageInfo.rddInfos.exists(_.name == d4.name)) {
          assert(taskMetrics.shuffleReadMetrics.totalBlocksFetched == 2 * numSlices)
          assert(taskMetrics.shuffleReadMetrics.localBlocksFetched == 2 * numSlices)
          assert(taskMetrics.shuffleReadMetrics.remoteBlocksFetched == 0)
          assert(taskMetrics.shuffleReadMetrics.remoteBytesRead == 0L)
        }
      }
    }
  }

  test("onTaskGettingResult() called when result fetched remotely") {
    val conf = new SparkConf().set("spark.rpc.message.maxSize", "1")
    sc = new SparkContext("local", "SparkListenerSuite", conf)
    val listener = new SaveTaskEvents
    sc.addSparkListener(listener)

    // Make a task whose result is larger than the RPC message size
    val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
    assert(maxRpcMessageSize === 1024 * 1024)
    val result = sc.parallelize(Seq(1), 1)
      .map { _ => 1.to(maxRpcMessageSize).toArray }
      .reduce { case (x, _) => x }
    assert(result === 1.to(maxRpcMessageSize).toArray)

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    val TASK_INDEX = 0
    assert(listener.startedTasks.contains(TASK_INDEX))
    assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
    assert(listener.endedTasks.contains(TASK_INDEX))
  }

  test("onTaskGettingResult() not called when result sent directly") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveTaskEvents
    sc.addSparkListener(listener)

    // Make a task whose result is small enough to be sent back directly
    val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, _) => x }
    assert(result === 2)

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    val TASK_INDEX = 0
    assert(listener.startedTasks.contains(TASK_INDEX))
    assert(listener.startedGettingResultTasks.isEmpty)
    assert(listener.endedTasks.contains(TASK_INDEX))
  }

  test("onTaskEnd() should be called for all started tasks, even after job has been killed") {
    sc = new SparkContext("local", "SparkListenerSuite")
    val listener = new SaveTaskEvents
    sc.addSparkListener(listener)

    val numTasks = 10
    val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync()
    // Wait until one task has started (because we want to make sure that any tasks that are started
    // have corresponding end events sent to the listener).
    var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
    listener.synchronized {
      var remainingWait = finishTime - System.currentTimeMillis
      while (listener.startedTasks.isEmpty && remainingWait > 0) {
        listener.wait(remainingWait)
        remainingWait = finishTime - System.currentTimeMillis
      }
      assert(!listener.startedTasks.isEmpty)
    }

    f.cancel()

    // Ensure that onTaskEnd is called for all started tasks.
    finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
    listener.synchronized {
      var remainingWait = finishTime - System.currentTimeMillis
      while (listener.endedTasks.size < listener.startedTasks.size && remainingWait > 0) {
        // Wait for the already-checked, strictly positive remainingWait. Recomputing the
        // deadline here could produce 0 (Object.wait(0) has no timeout) or a negative value
        // (Object.wait throws IllegalArgumentException).
        listener.wait(remainingWait)
        remainingWait = finishTime - System.currentTimeMillis
      }
      assert(listener.endedTasks.size === listener.startedTasks.size)
    }
  }

  test("SparkListener moves on if a listener throws an exception") {
    val badListener = new BadListener
    val jobCounter1 = new BasicJobCounter
    val jobCounter2 = new BasicJobCounter
    sc = new SparkContext("local", "SparkListenerSuite", new SparkConf())
    val bus = new LiveListenerBus(sc)

    // Propagate events to bad listener first
    bus.addListener(badListener)
    bus.addListener(jobCounter1)
    bus.addListener(jobCounter2)
    bus.start()

    // Post events to all listeners, and wait until the queue is drained
    (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
    bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

    // The exception should be caught, and the event should be propagated to other listeners
    assert(bus.listenerThreadIsAlive)
    assert(jobCounter1.count === 5)
    assert(jobCounter2.count === 5)
  }

  test("registering listeners via spark.extraListeners") {
    val listeners = Seq(
      classOf[ListenerThatAcceptsSparkConf],
      classOf[FirehoseListenerThatAcceptsSparkConf],
      classOf[BasicJobCounter])
    val conf = new SparkConf().setMaster("local").setAppName("test")
      .set("spark.extraListeners", listeners.map(_.getName).mkString(","))
    sc = new SparkContext(conf)
    sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1)
    sc.listenerBus.listeners.asScala
      .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1)
    sc.listenerBus.listeners.asScala
      .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1)
  }

  /**
   * Assert that the given list of numbers has an average that is greater than zero.
   * An empty input averages to NaN and therefore fails the assertion.
   *
   * @param m   the values to average
   * @param msg the assertion message reported on failure
   */
  private def checkNonZeroAvg(m: Traversable[Long], msg: String): Unit = {
    assert(m.sum / m.size.toDouble > 0.0, msg)
  }

  /**
   * A simple listener that saves all task infos and task metrics.
   *
   * Task-end events are buffered until the next stage-completed event, and that buffer is then
   * recorded in `stageInfos` under the completed stage's StageInfo.
   */
  private class SaveStageAndTaskInfo extends SparkListener {
    val stageInfos = mutable.Map[StageInfo, Seq[(TaskInfo, TaskMetrics)]]()
    var taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()

    override def onTaskEnd(task: SparkListenerTaskEnd): Unit = {
      val info = task.taskInfo
      val metrics = task.taskMetrics
      // Only record events that carry both the task info and its metrics.
      if (info != null && metrics != null) {
        taskInfoMetrics += ((info, metrics))
      }
    }

    override def onStageCompleted(stage: SparkListenerStageCompleted): Unit = {
      // Hand the current buffer over to this stage and start a fresh one; the old buffer
      // remains referenced from stageInfos.
      stageInfos(stage.stageInfo) = taskInfoMetrics
      taskInfoMetrics = mutable.Buffer.empty
    }
  }

  /**
   * A simple listener that saves the task indices for all task events.
   *
   * Every callback synchronizes on the listener. onTaskStart and onTaskEnd also call notify(),
   * so a test can wait on the listener's monitor until those events arrive.
   */
  private class SaveTaskEvents extends SparkListener {
    val startedTasks = new mutable.HashSet[Int]()
    val startedGettingResultTasks = new mutable.HashSet[Int]()
    val endedTasks = new mutable.HashSet[Int]()

    override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
      startedTasks += taskStart.taskInfo.index
      notify()
    }

    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
      endedTasks += taskEnd.taskInfo.index
      notify()
    }

    override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit =
      synchronized {
        // Guarded by the same monitor as the other callbacks, since the set is mutable and
        // shared with the test thread.
        startedGettingResultTasks += taskGettingResult.taskInfo.index
      }
  }

  /**
   * A simple listener that throws an exception on job end.
   */
  private class BadListener extends SparkListener {
    override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception }
  }

}

// These classes can't be declared inside of the SparkListenerSuite class because we don't want
// their constructors to contain references to SparkListenerSuite (several of them are
// instantiated by class name through spark.extraListeners in the suite above):

/**
 * A simple listener that counts the number of jobs observed.
 */
private class BasicJobCounter extends SparkListener {
  var count = 0
  override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1
}

/**
 * A simple listener that tries to stop SparkContext.
 *
 * Sets `sparkExSeen` when the call to `sc.stop()` throws a SparkException.
 */
private class SparkContextStoppingListener(val sc: SparkContext) extends SparkListener {
  @volatile var sparkExSeen = false
  override def onJobEnd(job: SparkListenerJobEnd): Unit = {
    try {
      sc.stop()
    } catch {
      case _: SparkException =>
        sparkExSeen = true
    }
  }
}

/**
 * A job-counting listener whose only constructor takes a SparkConf. It is used to check that
 * spark.extraListeners can instantiate such classes; the conf itself is not used.
 */
private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener {
  var count = 0
  override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1
}

/**
 * Firehose counterpart of ListenerThatAcceptsSparkConf: it counts SparkListenerJobEnd events
 * seen through onEvent and ignores every other event.
 */
private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkFirehoseListener {
  var count = 0
  override def onEvent(event: SparkListenerEvent): Unit = event match {
    case _: SparkListenerJobEnd => count += 1
    case _ =>
  }
}