/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}


class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
  import InternalAccumulator._

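  // Clear the global accumulator registry after each test so that accumulators
  // registered by one test do not leak into the next.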
  override def afterEach(): Unit = {
    try {
      AccumulatorContext.clear()
    } finally {
      super.afterEach()
    }
  }

  test("internal accumulators in TaskContext") {
    val taskContext = TaskContext.empty()
    val accumUpdates = taskContext.taskMetrics.accumulators()
    assert(accumUpdates.nonEmpty)
    val testAccum = taskContext.taskMetrics.testAccum.get
    assert(accumUpdates.exists(_.id == testAccum.id))
  }

  test("internal accumulators in a stage") {
    val listener = new SaveInfoListener
    val numPartitions = 10
    sc = new SparkContext("local", "test")
    sc.addSparkListener(listener)
    // Have each task add 1 to the internal accumulator
    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
      TaskContext.get().taskMetrics().testAccum.get.add(1)
      iter
    }
    // Register asserts in job completion callback to avoid flakiness
    listener.registerJobCompletionCallback { () =>
      val stageInfos = listener.getCompletedStageInfos
      val taskInfos = listener.getCompletedTaskInfos
      assert(stageInfos.size === 1)
      assert(taskInfos.size === numPartitions)
      // The accumulator values should be merged in the stage
      val stageAccum = findTestAccum(stageInfos.head.accumulables.values)
      assert(stageAccum.value.get.asInstanceOf[Long] === numPartitions)
      // The accumulator should be updated locally on each task
      val taskAccumValues = taskInfos.map { taskInfo =>
        val taskAccum = findTestAccum(taskInfo.accumulables)
        assert(taskAccum.update.isDefined)
        assert(taskAccum.update.get.asInstanceOf[Long] === 1L)
        taskAccum.value.get.asInstanceOf[Long]
      }
      // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions
      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
    }
    rdd.count()
    listener.awaitNextJobCompletion()
  }

  test("internal accumulators in multiple stages") {
    val listener = new SaveInfoListener
    val numPartitions = 10
    sc = new SparkContext("local", "test")
    sc.addSparkListener(listener)
    // Each stage creates its own set of internal accumulators, so the
    // values for the same metric should not be mixed up across stages
    val rdd = sc.parallelize(1 to 100, numPartitions)
      .map { i => (i, i) }
      .mapPartitions { iter =>
        TaskContext.get().taskMetrics().testAccum.get.add(1)
        iter
      }
      .reduceByKey { case (x, y) => x + y }
      .mapPartitions { iter =>
        TaskContext.get().taskMetrics().testAccum.get.add(10)
        iter
      }
      .repartition(numPartitions * 2)
      .mapPartitions { iter =>
        TaskContext.get().taskMetrics().testAccum.get.add(100)
        iter
      }
    // Register asserts in job completion callback to avoid flakiness
    listener.registerJobCompletionCallback { () =>
      // We ran 3 stages, and the accumulator values should be distinct
      val stageInfos = listener.getCompletedStageInfos
      assert(stageInfos.size === 3)
      val (firstStageAccum, secondStageAccum, thirdStageAccum) =
        (findTestAccum(stageInfos(0).accumulables.values),
        findTestAccum(stageInfos(1).accumulables.values),
        findTestAccum(stageInfos(2).accumulables.values))
      assert(firstStageAccum.value.get.asInstanceOf[Long] === numPartitions)
      assert(secondStageAccum.value.get.asInstanceOf[Long] === numPartitions * 10)
      assert(thirdStageAccum.value.get.asInstanceOf[Long] === numPartitions * 2 * 100)
    }
    rdd.count()
    listener.awaitNextJobCompletion()
  }

  test("internal accumulators in resubmitted stages") {
    val listener = new SaveInfoListener
    val numPartitions = 10
    sc = new SparkContext("local", "test")
    sc.addSparkListener(listener)

    // Simulate fetch failures in order to trigger a stage retry. Here we run 1 job with
    // 2 stages. On the second stage, we trigger a fetch failure on the first stage attempt.
    // This should retry both stages in the scheduler. Note that we only want to fail the
    // first stage attempt because we want the stage to eventually succeed.
    val x = sc.parallelize(1 to 100, numPartitions)
      .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get.add(1); iter }
      .groupBy(identity)
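    // Capture the shuffle ID of the groupBy dependency so that the simulated fetch
    // failure below can reference the correct shuffle.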
    val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId
    val rdd = x.mapPartitionsWithIndex { case (i, iter) =>
      // Fail the first stage attempt. Here we use the task attempt ID to determine this.
      // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt
      // ID that's < 2 * numPartitions belongs to the first attempt of this stage.
      val taskContext = TaskContext.get()
      val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2
      if (isFirstStageAttempt) {
        throw new FetchFailedException(
          SparkEnv.get.blockManager.blockManagerId,
          sid,
          taskContext.partitionId(),
          taskContext.partitionId(),
          "simulated fetch failure")
      } else {
        iter
      }
    }

    // Register asserts in job completion callback to avoid flakiness
    listener.registerJobCompletionCallback { () =>
      val stageInfos = listener.getCompletedStageInfos
      assert(stageInfos.size === 4) // 1 shuffle map stage + 1 result stage, both are retried
      val mapStageId = stageInfos.head.stageId
      val mapStageInfo1stAttempt = stageInfos.head
      val mapStageInfo2ndAttempt = {
        stageInfos.tail.find(_.stageId == mapStageId).getOrElse {
          fail("expected two attempts of the same shuffle map stage.")
        }
      }
      val stageAccum1stAttempt = findTestAccum(mapStageInfo1stAttempt.accumulables.values)
      val stageAccum2ndAttempt = findTestAccum(mapStageInfo2ndAttempt.accumulables.values)
      // Both map stages should have succeeded, since the fetch failure happened in the
      // result stage, not the map stage. This means we should get the accumulator updates
      // from all partitions.
      assert(stageAccum1stAttempt.value.get.asInstanceOf[Long] === numPartitions)
      assert(stageAccum2ndAttempt.value.get.asInstanceOf[Long] === numPartitions)
      // Because this test resubmitted the map stage with all missing partitions, we should have
      // created a fresh set of internal accumulators in the 2nd stage attempt. Assert this is
      // the case by comparing the accumulator IDs between the two attempts.
      // Note: it would be good to also test the case where the map stage is resubmitted when
      // only a subset of the original partitions are missing. However, this scenario is very
      // difficult to construct without potentially introducing flakiness.
      assert(stageAccum1stAttempt.id != stageAccum2ndAttempt.id)
    }
    rdd.count()
    listener.awaitNextJobCompletion()
  }

  test("internal accumulators are registered for cleanups") {
    sc = new SparkContext("local", "test") {
      private val myCleaner = new SaveAccumContextCleaner(this)
      override def cleaner: Option[ContextCleaner] = Some(myCleaner)
    }
    assert(AccumulatorContext.numAccums === 0)
    sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count()
    val numInternalAccums = TaskMetrics.empty.internalAccums.length
    // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage
    assert(AccumulatorContext.numAccums === numInternalAccums * 2)
    val accumsRegistered = sc.cleaner match {
      case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup
      case _ => Seq.empty[Long]
    }
    // Make sure the same set of accumulators is registered for cleanup
    assert(accumsRegistered.size === numInternalAccums * 2)
    assert(accumsRegistered.toSet.size === AccumulatorContext.numAccums)
    accumsRegistered.foreach(id => assert(AccumulatorContext.get(id).isDefined))
  }

  /**
   * Return the accumulable info for the internal test accumulator, [[TEST_ACCUM]].
   */
  private def findTestAccum(accums: Iterable[AccumulableInfo]): AccumulableInfo = {
    accums.find { a => a.name == Some(TEST_ACCUM) }.getOrElse {
      fail(s"unable to find internal accumulator called $TEST_ACCUM")
    }
  }

  /**
   * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
   */
  private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
    private val accumsRegistered = new ArrayBuffer[Long]

    override def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = {
      accumsRegistered += a.id
      super.registerAccumulatorForCleanup(a)
    }

    def accumsRegisteredForCleanup: Seq[Long] = accumsRegistered.toArray
  }
}