/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.util.concurrent.Semaphore
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.ref.WeakReference
import scala.util.control.NonFatal

import org.scalatest.Matchers
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.AccumulatorParam.StringAccumulatorParam
import org.apache.spark.scheduler._
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, AccumulatorV2, LongAccumulator}


class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
  import AccumulatorSuite.createLongAccum

  override def afterEach(): Unit = {
    try {
      AccumulatorContext.clear()
    } finally {
      super.afterEach()
    }
  }

  implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
    new AccumulableParam[mutable.Set[A], A] {
      def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]): mutable.Set[A] = {
        t1 ++= t2
        t1
      }
      def addAccumulator(t1: mutable.Set[A], t2: A): mutable.Set[A] = {
        t1 += t2
        t1
      }
      def zero(t: mutable.Set[A]): mutable.Set[A] = {
        new mutable.HashSet[A]()
      }
    }

  test("accumulator serialization") {
    val ser = new JavaSerializer(new SparkConf).newInstance()
    val acc = createLongAccum("x")
    acc.add(5)
    assert(acc.value == 5)
    assert(acc.isAtDriverSide)

    // serialize and de-serialize it, to simulate sending accumulator to executor.
    val acc2 = ser.deserialize[LongAccumulator](ser.serialize(acc))
    // value is reset on the executors
    assert(acc2.value == 0)
    assert(!acc2.isAtDriverSide)

    acc2.add(10)
    // serialize and de-serialize it again, to simulate sending accumulator back to driver.
    val acc3 = ser.deserialize[LongAccumulator](ser.serialize(acc2))
    // value is not reset on the driver
    assert(acc3.value == 10)
    assert(acc3.isAtDriverSide)
  }

  test ("basic accumulation") {
    sc = new SparkContext("local", "test")
    val acc: Accumulator[Int] = sc.accumulator(0)

    val d = sc.parallelize(1 to 20)
    d.foreach{x => acc += x}
    acc.value should be (210)

    val longAcc = sc.accumulator(0L)
    val maxInt = Integer.MAX_VALUE.toLong
    d.foreach{x => longAcc += maxInt + x}
    longAcc.value should be (210L + maxInt * 20)
  }

  test("value not assignable from tasks") {
    sc = new SparkContext("local", "test")
    val acc: Accumulator[Int] = sc.accumulator(0)

    val d = sc.parallelize(1 to 20)
    intercept[SparkException] {
      d.foreach(x => acc.value = x)
    }
  }

  test ("add value to collection accumulators") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) { // test single & multi-threaded
      sc = new SparkContext("local[" + nThreads + "]", "test")
      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
      val d = sc.parallelize(1 to maxI)
      d.foreach {
        x => acc += x
      }
      val v = acc.value.asInstanceOf[mutable.Set[Int]]
      for (i <- 1 to maxI) {
        v should contain(i)
      }
      resetSparkContext()
    }
  }

  test("value not readable in tasks") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) { // test single & multi-threaded
      sc = new SparkContext("local[" + nThreads + "]", "test")
      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
      val d = sc.parallelize(1 to maxI)
      an [SparkException] should be thrownBy {
        d.foreach {
          x => acc.value += x
        }
      }
      resetSparkContext()
    }
  }

  test ("collection accumulators") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) {
      // test single & multi-threaded
      sc = new SparkContext("local[" + nThreads + "]", "test")
      val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
      val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
      val mapAcc = sc.accumulableCollection(mutable.HashMap[Int, String]())
      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
      d.foreach {
        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
      }

      // Note that this is typed correctly -- no casts necessary
      setAcc.value.size should be (maxI)
      bufferAcc.value.size should be (2 * maxI)
      mapAcc.value.size should be (maxI)
      for (i <- 1 to maxI) {
        setAcc.value should contain(i)
        bufferAcc.value should contain(i)
        mapAcc.value should contain (i -> i.toString)
      }
      resetSparkContext()
    }
  }

  test ("localValue readable in tasks") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) { // test single & multi-threaded
      sc = new SparkContext("local[" + nThreads + "]", "test")
      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
      val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * x).toSet}
      val d = sc.parallelize(groupedInts)
      d.foreach {
        x => acc.localValue ++= x
      }
      acc.value should be ((0 to maxI).toSet)
      resetSparkContext()
    }
  }

  test ("garbage collection") {
    // Create an accumulator and let it go out of scope to test that it's properly garbage collected
    sc = new SparkContext("local", "test")
    var acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
    val accId = acc.id
    val ref = WeakReference(acc)

    // Ensure the accumulator is present
    assert(ref.get.isDefined)

    // Remove the explicit reference to it and allow weak reference to get garbage collected
    acc = null
    System.gc()
    assert(ref.get.isEmpty)

    AccumulatorContext.remove(accId)
    assert(!AccumulatorContext.get(accId).isDefined)
  }

  test("get accum") {
    // Don't register with SparkContext for cleanup
    var acc = createLongAccum("a")
    val accId = acc.id
    val ref = WeakReference(acc)
    assert(ref.get.isDefined)

    // Remove the explicit reference to it and allow weak reference to get garbage collected
    acc = null
    System.gc()
    assert(ref.get.isEmpty)

    // Getting a garbage collected accum should throw error
    intercept[IllegalAccessError] {
      AccumulatorContext.get(accId)
    }

    // Getting a normal accumulator. Note: this has to be separate because referencing an
    // accumulator above in an `assert` would keep it from being garbage collected.
    val acc2 = createLongAccum("b")
    assert(AccumulatorContext.get(acc2.id) === Some(acc2))

    // Getting an accumulator that does not exist should return None
    assert(AccumulatorContext.get(100000).isEmpty)
  }

  test("string accumulator param") {
    val acc = new Accumulator("", StringAccumulatorParam, Some("darkness"))
    assert(acc.value === "")
    acc.setValue("feeds")
    assert(acc.value === "feeds")
    acc.add("your")
    assert(acc.value === "your") // value is overwritten, not concatenated
    acc += "soul"
    assert(acc.value === "soul")
    acc ++= "with"
    assert(acc.value === "with")
    acc.merge("kindness")
    assert(acc.value === "kindness")
  }
}

private[spark] object AccumulatorSuite {
  import InternalAccumulator._

  /**
   * Create a long accumulator and register it to [[AccumulatorContext]].
   */
  def createLongAccum(
      name: String,
      countFailedValues: Boolean = false,
      initValue: Long = 0,
      id: Long = AccumulatorContext.newId()): LongAccumulator = {
    val acc = new LongAccumulator
    acc.setValue(initValue)
    acc.metadata = AccumulatorMetadata(id, Some(name), countFailedValues)
    AccumulatorContext.register(acc)
    acc
  }

  /**
   * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
   * info as an accumulator update.
   */
  def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None)

  /**
   * Run one or more Spark jobs and verify that in at least one job the peak execution memory
   * accumulator is updated afterwards.
   */
  def verifyPeakExecutionMemorySet(
      sc: SparkContext,
      testName: String)(testBody: => Unit): Unit = {
    val listener = new SaveInfoListener
    sc.addSparkListener(listener)
    testBody
    // wait until all events have been processed before proceeding to assert things
    sc.listenerBus.waitUntilEmpty(10 * 1000)
    val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
    val isSet = accums.exists { a =>
      a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L)
    }
    if (!isSet) {
      throw new TestFailedException(s"peak execution memory accumulator not set in '$testName'", 0)
    }
  }
}

/**
 * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs.
 */
private class SaveInfoListener extends SparkListener {
  type StageId = Int
  type StageAttemptId = Int

  private val completedStageInfos = new ArrayBuffer[StageInfo]
  private val completedTaskInfos =
    new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]]

  // Callback to call when a job completes.
  @GuardedBy("this")
  private var jobCompletionCallback: () => Unit = null
  private val jobCompletionSem = new Semaphore(0)
  private var exception: Throwable = null

  def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
  def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq
  def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] =
    completedTaskInfos.getOrElse((stageId, stageAttemptId), Seq.empty[TaskInfo])

  /**
   * If `jobCompletionCallback` is set, block until the next call has finished.
   * If the callback failed with an exception, throw it.
   */
  def awaitNextJobCompletion(): Unit = {
    if (jobCompletionCallback != null) {
      jobCompletionSem.acquire()
      if (exception != null) {
        throw exception
      }
    }
  }

  /**
   * Register a callback to be called on job end.
   * A call to this should be followed by [[awaitNextJobCompletion]].
   */
  def registerJobCompletionCallback(callback: () => Unit): Unit = {
    jobCompletionCallback = callback
  }

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
    if (jobCompletionCallback != null) {
      try {
        jobCompletionCallback()
      } catch {
        // Store any exception thrown here so we can throw them later in the main thread.
        // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test.
        case NonFatal(e) => exception = e
      } finally {
        jobCompletionSem.release()
      }
    }
  }

  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    completedStageInfos += stageCompleted.stageInfo
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    completedTaskInfos.getOrElseUpdate(
      (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo
  }
}
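
// Illustrative sketch only, not part of the original suite: a minimal example of how another
// test might use `AccumulatorSuite.verifyPeakExecutionMemorySet` defined above. The job body
// here is a hypothetical shuffle with map-side aggregation, chosen because that kind of work
// typically acquires execution memory; any job known to update the peak execution memory
// metric would do.
//
//   AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "hypothetical shuffle job") {
//     sc.parallelize(1 to 10000, 4)
//       .map(i => (i % 100, i.toLong))
//       .reduceByKey(_ + _)
//       .count()
//   }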