/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.util.concurrent.Semaphore
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.ref.WeakReference
import scala.util.control.NonFatal

import org.scalatest.Matchers
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.AccumulatorParam.StringAccumulatorParam
import org.apache.spark.scheduler._
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, AccumulatorV2, LongAccumulator}

class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
  import AccumulatorSuite.createLongAccum

  override def afterEach(): Unit = {
    try {
      AccumulatorContext.clear()
    } finally {
      super.afterEach()
    }
  }

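  // Implicit AccumulableParam for mutable Sets, picked up by the `sc.accumulable(...)` calls in
  // the tests below: elements are added in place with `+=` and partial sets are merged with `++=`.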
  implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
    new AccumulableParam[mutable.Set[A], A] {
      def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]): mutable.Set[A] = {
        t1 ++= t2
        t1
      }
      def addAccumulator(t1: mutable.Set[A], t2: A): mutable.Set[A] = {
        t1 += t2
        t1
      }
      def zero(t: mutable.Set[A]): mutable.Set[A] = {
        new mutable.HashSet[A]()
      }
    }

  test("accumulator serialization") {
    val ser = new JavaSerializer(new SparkConf).newInstance()
    val acc = createLongAccum("x")
    acc.add(5)
    assert(acc.value == 5)
    assert(acc.isAtDriverSide)

    // serialize and de-serialize it, to simulate sending accumulator to executor.
    val acc2 = ser.deserialize[LongAccumulator](ser.serialize(acc))
    // value is reset on the executors
    assert(acc2.value == 0)
    assert(!acc2.isAtDriverSide)

    acc2.add(10)
    // serialize and de-serialize it again, to simulate sending accumulator back to driver.
    val acc3 = ser.deserialize[LongAccumulator](ser.serialize(acc2))
    // value is not reset on the driver
    assert(acc3.value == 10)
    assert(acc3.isAtDriverSide)
  }

  test("basic accumulation") {
    sc = new SparkContext("local", "test")
    val acc: Accumulator[Int] = sc.accumulator(0)

    val d = sc.parallelize(1 to 20)
    d.foreach { x => acc += x }
    acc.value should be (210)

    val longAcc = sc.accumulator(0L)
    val maxInt = Integer.MAX_VALUE.toLong
    d.foreach { x => longAcc += maxInt + x }
    longAcc.value should be (210L + maxInt * 20)
  }

  test("value not assignable from tasks") {
    sc = new SparkContext("local", "test")
    val acc: Accumulator[Int] = sc.accumulator(0)

    val d = sc.parallelize(1 to 20)
    intercept[SparkException] {
      d.foreach(x => acc.value = x)
    }
  }

  test("add value to collection accumulators") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) { // test single & multi-threaded
      sc = new SparkContext("local[" + nThreads + "]", "test")
      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
      val d = sc.parallelize(1 to maxI)
      d.foreach { x => acc += x }
      val v = acc.value.asInstanceOf[mutable.Set[Int]]
      for (i <- 1 to maxI) {
        v should contain(i)
      }
      resetSparkContext()
    }
  }

  test("value not readable in tasks") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) { // test single & multi-threaded
      sc = new SparkContext("local[" + nThreads + "]", "test")
      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
      val d = sc.parallelize(1 to maxI)
      an [SparkException] should be thrownBy {
        d.foreach { x => acc.value += x }
      }
      resetSparkContext()
    }
  }

  test("collection accumulators") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) { // test single & multi-threaded
      sc = new SparkContext("local[" + nThreads + "]", "test")
      val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
      val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
      val mapAcc = sc.accumulableCollection(mutable.HashMap[Int, String]())
      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
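      // Each element appears twice, so the set and map de-duplicate (size maxI) while the buffer
      // keeps both copies (size 2 * maxI).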
      d.foreach { x => setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString) }

      // Note that this is typed correctly -- no casts necessary
      setAcc.value.size should be (maxI)
      bufferAcc.value.size should be (2 * maxI)
      mapAcc.value.size should be (maxI)
      for (i <- 1 to maxI) {
        setAcc.value should contain(i)
        bufferAcc.value should contain(i)
        mapAcc.value should contain (i -> i.toString)
      }
      resetSparkContext()
    }
  }

  test("localValue readable in tasks") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) { // test single & multi-threaded
      sc = new SparkContext("local[" + nThreads + "]", "test")
      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
      val groupedInts = (1 to (maxI / 20)).map { x => (20 * (x - 1) to 20 * x).toSet }
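      // The ranges above are inclusive and overlap at their endpoints, so together they cover
      // 0 to maxI; duplicates are simply absorbed by the set accumulator.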
      val d = sc.parallelize(groupedInts)
      d.foreach { x => acc.localValue ++= x }
      acc.value should be ((0 to maxI).toSet)
      resetSparkContext()
    }
  }

  test("garbage collection") {
    // Create an accumulator and let it go out of scope to test that it's properly
    // garbage collected
    sc = new SparkContext("local", "test")
    var acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
    val accId = acc.id
    val ref = WeakReference(acc)

    // Ensure the accumulator is present
    assert(ref.get.isDefined)

    // Remove the explicit reference to it and allow weak reference to get garbage collected
    acc = null
    System.gc()
    assert(ref.get.isEmpty)

    AccumulatorContext.remove(accId)
    assert(!AccumulatorContext.get(accId).isDefined)
  }

  test("get accum") {
    // Don't register with SparkContext for cleanup
    var acc = createLongAccum("a")
    val accId = acc.id
    val ref = WeakReference(acc)
    assert(ref.get.isDefined)

    // Remove the explicit reference to it and allow weak reference to get garbage collected
    acc = null
    System.gc()
    assert(ref.get.isEmpty)

    // Getting a garbage collected accum should throw error
    intercept[IllegalAccessError] {
      AccumulatorContext.get(accId)
    }

    // Getting a normal accumulator. Note: this has to be separate because referencing an
    // accumulator above in an `assert` would keep it from being garbage collected.
    val acc2 = createLongAccum("b")
    assert(AccumulatorContext.get(acc2.id) === Some(acc2))

    // Getting an accumulator that does not exist should return None
    assert(AccumulatorContext.get(100000).isEmpty)
  }

  test("string accumulator param") {
    val acc = new Accumulator("", StringAccumulatorParam, Some("darkness"))
    assert(acc.value === "")
    acc.setValue("feeds")
    assert(acc.value === "feeds")
    acc.add("your")
    assert(acc.value === "your") // value is overwritten, not concatenated
    acc += "soul"
    assert(acc.value === "soul")
    acc ++= "with"
    assert(acc.value === "with")
    acc.merge("kindness")
    assert(acc.value === "kindness")
  }
}

private[spark] object AccumulatorSuite {
  import InternalAccumulator._

  /**
   * Create a long accumulator and register it with [[AccumulatorContext]].
   */
  def createLongAccum(
      name: String,
      countFailedValues: Boolean = false,
      initValue: Long = 0,
      id: Long = AccumulatorContext.newId()): LongAccumulator = {
    val acc = new LongAccumulator
    acc.setValue(initValue)
    acc.metadata = AccumulatorMetadata(id, Some(name), countFailedValues)
    AccumulatorContext.register(acc)
    acc
  }

  /**
   * Make an [[AccumulableInfo]] out of an [[AccumulatorV2]] with the intent to use the
   * info as an accumulator update.
   */
  def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None)

  /**
   * Run one or more Spark jobs and verify that in at least one job the peak execution memory
   * accumulator is updated afterwards.
   */
  def verifyPeakExecutionMemorySet(
      sc: SparkContext,
      testName: String)(testBody: => Unit): Unit = {
    val listener = new SaveInfoListener
    sc.addSparkListener(listener)
    testBody
    // wait until all events have been processed before proceeding to assert things
    sc.listenerBus.waitUntilEmpty(10 * 1000)
    val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
    val isSet = accums.exists { a =>
      a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L)
    }
    if (!isSet) {
      throw new TestFailedException(s"peak execution memory accumulator not set in '$testName'", 0)
    }
  }
}

/**
 * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs.
 */
private class SaveInfoListener extends SparkListener {
  type StageId = Int
  type StageAttemptId = Int

  private val completedStageInfos = new ArrayBuffer[StageInfo]
  private val completedTaskInfos =
    new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]]

  // Callback to call when a job completes.
  @GuardedBy("this")
  private var jobCompletionCallback: () => Unit = null
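  // Released in `onJobEnd` when a completion callback is registered; acquired by
  // `awaitNextJobCompletion` on the test thread.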
  private val jobCompletionSem = new Semaphore(0)
  private var exception: Throwable = null

  def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
  def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq
  def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] =
    completedTaskInfos.getOrElse((stageId, stageAttemptId), Seq.empty[TaskInfo])

  /**
   * If `jobCompletionCallback` is set, block until the next call has finished.
   * If the callback failed with an exception, throw it.
   */
  def awaitNextJobCompletion(): Unit = {
    if (jobCompletionCallback != null) {
      jobCompletionSem.acquire()
      if (exception != null) {
        throw exception
      }
    }
  }

  /**
   * Register a callback to be called on job end.
   * A call to this should be followed by [[awaitNextJobCompletion]].
   */
  def registerJobCompletionCallback(callback: () => Unit): Unit = {
    jobCompletionCallback = callback
  }

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
    if (jobCompletionCallback != null) {
      try {
        jobCompletionCallback()
      } catch {
        // Store any exception thrown here so we can throw it later from the main thread.
        // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test.
        case NonFatal(e) => exception = e
      } finally {
        jobCompletionSem.release()
      }
    }
  }

  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    completedStageInfos += stageCompleted.stageInfo
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    completedTaskInfos.getOrElseUpdate(
      (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo
  }
}