/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.scheduler

import java.io.File
import java.net.URL
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.control.NonFatal

import com.google.common.util.concurrent.MoreExecutors
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, anyLong}
import org.mockito.Mockito.{spy, times, verify}
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

import org.apache.spark._
import org.apache.spark.storage.TaskResultBlockId
import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}


/**
 * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
 *
 * Used to test the case where a BlockManager evicts the task result (or dies) before the
 * TaskResult is retrieved.
 */
private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
  extends TaskResultGetter(sparkEnv, scheduler) {
  var removedResult = false

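  // Set to true once the eventually() check below has confirmed that the result block is gone;
  // the test asserts on this flag to make sure the "missing result" path was actually exercised.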
  @volatile var removeBlockSuccessfully = false

  override def enqueueSuccessfulTask(
    taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
    if (!removedResult) {
      // Only remove the result once, since we'd like to test the case where the task eventually
      // succeeds.
      serializer.get().deserialize[TaskResult[_]](serializedData) match {
        case IndirectTaskResult(blockId, size) =>
          sparkEnv.blockManager.master.removeBlock(blockId)
          // removeBlock is asynchronous, so wait until the block has actually been removed.
          try {
            eventually(timeout(3 seconds), interval(200 milliseconds)) {
              assert(!sparkEnv.blockManager.master.contains(blockId))
            }
            removeBlockSuccessfully = true
          } catch {
            case NonFatal(e) => removeBlockSuccessfully = false
          }
        case directResult: DirectTaskResult[_] =>
          taskSetManager.abort("Internal error: expect only indirect results")
      }
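      // deserialize() above advanced the buffer's position, so rewind it before handing it to
      // the superclass, which will deserialize it again.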
      serializedData.rewind()
      removedResult = true
    }
    super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
  }
}


/**
 * A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors
 * _before_ modifying the results in any way.
 */
private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl)
  extends TaskResultGetter(env, scheduler) {

  // Use the current thread so we can access its results synchronously
  protected override val getTaskResultExecutor = MoreExecutors.sameThreadExecutor()

  // DirectTaskResults that we receive from the executors
  private val _taskResults = new ArrayBuffer[DirectTaskResult[_]]

  def taskResults: Seq[DirectTaskResult[_]] = _taskResults

  override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = {
    // work on a copy since the super class still needs to use the buffer
    val newBuffer = data.duplicate()
    _taskResults += env.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](newBuffer)
    super.enqueueSuccessfulTask(tsm, tid, data)
  }
}


/**
 * Tests related to handling task results (both direct and indirect).
 */
class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  // Set the RPC message size to be as small as possible (it must be an integer, so 1 is as small
  // as we can make it) so the tests don't take too long.
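  // The value is in MB, so "1" still allows messages of up to 1 MB; results larger than that
  // are expected to go through the block manager instead of the RPC layer.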
  def conf: SparkConf = new SparkConf().set("spark.rpc.message.maxSize", "1")

  test("handling results smaller than max RPC message size") {
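    // A single small Int fits well within the 1 MB limit, so the executor sends the result
    // back directly over RPC as a DirectTaskResult.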
    sc = new SparkContext("local", "test", conf)
    val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
    assert(result === 2)
  }

  test("handling results larger than max RPC message size") {
    sc = new SparkContext("local", "test", conf)
    val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
    val result =
      sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x)
    assert(result === 1.to(maxRpcMessageSize).toArray)

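    // Because the serialized result exceeded the max RPC message size, the executor stored it in
    // the block manager and sent back an IndirectTaskResult; the driver should have fetched the
    // block and then removed it once the result was delivered.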
    val RESULT_BLOCK_ID = TaskResultBlockId(0)
    assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0,
      "Expect result to be removed from the block manager.")
  }

  test("task retried if result missing from block manager") {
    // Use "local[1,2]" so that the maximum number of task failures is 2: the first attempt fails
    // because the result block has been removed, and the retry succeeds, so the task set isn't
    // aborted.
    sc = new SparkContext("local[1,2]", "test", conf)
    // If this test hangs, it's probably because no resource offers were made after the task
    // failed.
    val scheduler: TaskSchedulerImpl = sc.taskScheduler match {
      case taskScheduler: TaskSchedulerImpl =>
        taskScheduler
      case _ =>
        assert(false, "Expect local cluster to use TaskSchedulerImpl")
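        // The assert above always fails at runtime, but the compiler still needs this branch to
        // produce a TaskSchedulerImpl, so throw explicitly to give it type Nothing.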
        throw new ClassCastException
    }
    val resultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
    scheduler.taskResultGetter = resultGetter
    val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
    val result =
      sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x)
    assert(resultGetter.removeBlockSuccessfully)
    assert(result === 1.to(maxRpcMessageSize).toArray)

    // Make sure two tasks were run (one failed one, and a second retried one).
    assert(scheduler.nextTaskId.get() === 2)
  }

  /**
   * Make sure we are using the context classloader when deserializing failed TaskResults instead
   * of the Spark classloader.
   *
   * This test compiles a jar containing an exception and tests that when it is thrown on the
   * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown
   * exception as the cause.
   *
   * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing
   * the exception, resulting in an UnknownReason for the TaskEndResult.
   */
  test("failed task deserialized with the correct classloader (SPARK-11195)") {
    // compile a small jar containing an exception that will be thrown on an executor.
    val tempDir = Utils.createTempDir()
    val srcDir = new File(tempDir, "repro/")
    srcDir.mkdirs()
    val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath,
      """package repro;
        |
        |public class MyException extends Exception {
        |}
      """.stripMargin)
    val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty)
    val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
    TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro"))

    // ensure we reset the classloader after the test completes
    val originalClassLoader = Thread.currentThread.getContextClassLoader
    try {
      // load the exception from the jar
      val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
      loader.addURL(jarFile.toURI.toURL)
      Thread.currentThread().setContextClassLoader(loader)
      val excClass: Class[_] = Utils.classForName("repro.MyException")

      // NOTE: we must run the cluster with "local" so that the executor can load the compiled
      // jar.
      sc = new SparkContext("local", "test", conf)
      val rdd = sc.parallelize(Seq(1), 1).map { _ =>
        val exc = excClass.newInstance().asInstanceOf[Exception]
        throw exc
      }

      // the driver should not have any problems resolving the exception class and determining
      // why the task failed.
      val exceptionMessage = intercept[SparkException] {
        rdd.collect()
      }.getMessage

      val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r
      val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r

      assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined)
      assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty)
    } finally {
      Thread.currentThread.setContextClassLoader(originalClassLoader)
    }
  }

  test("task result size is set on the driver, not the executors") {
    import InternalAccumulator._

    // Set up custom TaskResultGetter and TaskSchedulerImpl spy
    sc = new SparkContext("local", "test", conf)
    val scheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
    val spyScheduler = spy(scheduler)
    val resultGetter = new MyTaskResultGetter(sc.env, spyScheduler)
    val newDAGScheduler = new DAGScheduler(sc, spyScheduler)
    scheduler.taskResultGetter = resultGetter
    sc.dagScheduler = newDAGScheduler
    sc.taskScheduler = spyScheduler
    sc.taskScheduler.setDAGScheduler(newDAGScheduler)
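    // The result getter and the fresh DAGScheduler above are wired to the spied scheduler so that
    // the call to handleSuccessfulTask, made by TaskResultGetter once the result arrives, can be
    // verified and its argument captured below.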

    // Just run 1 task and capture the corresponding DirectTaskResult
    sc.parallelize(1 to 1, 1).count()
    val captor = ArgumentCaptor.forClass(classOf[DirectTaskResult[_]])
    verify(spyScheduler, times(1)).handleSuccessfulTask(any(), anyLong(), captor.capture())

    // When a task finishes, the executor sends a serialized DirectTaskResult to the driver
    // without setting the result size so as to avoid serializing the result again. Instead,
    // the result size is set later in TaskResultGetter on the driver before passing the
    // DirectTaskResult on to TaskSchedulerImpl. In this test, we capture the DirectTaskResult
    // before and after the result size is set.
    assert(resultGetter.taskResults.size === 1)
    val resBefore = resultGetter.taskResults.head
    val resAfter = captor.getValue
    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
    assert(resSizeBefore.exists(_ == 0L))
    assert(resSizeAfter.exists(_.toString.toLong > 0L))
  }

}