/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.scheduler

import java.io.File
import java.net.URL
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.control.NonFatal

import com.google.common.util.concurrent.MoreExecutors
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, anyLong}
import org.mockito.Mockito.{spy, times, verify}
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

import org.apache.spark._
import org.apache.spark.storage.TaskResultBlockId
import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}


/**
 * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
 *
 * Used to test the case where a BlockManager evicts the task result (or dies) before the
 * TaskResult is retrieved.
 *
 * @param sparkEnv environment whose block manager master is asked to drop the result block
 * @param scheduler scheduler handed through to the underlying [[TaskResultGetter]]
 */
private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
  extends TaskResultGetter(sparkEnv, scheduler) {

  // Flipped to true after the first successful result has been intercepted, so that only that
  // first result is tampered with.
  var removedResult = false

  // True once the block of the first indirect result was observed to be gone from the block
  // manager master. Read by the test body; presumably written from a different thread, hence
  // @volatile.
  @volatile var removeBlockSuccessfully = false

  /**
   * Intercepts the first successful task result, removes its block from the block manager and
   * then forwards the (untouched) serialized result to the regular implementation.
   *
   * @param taskSetManager manager of the task set the finished task belongs to
   * @param tid id of the finished task
   * @param serializedData serialized [[TaskResult]]; its position is left unchanged here
   */
  override def enqueueSuccessfulTask(
      taskSetManager: TaskSetManager,
      tid: Long,
      serializedData: ByteBuffer): Unit = {
    if (!removedResult) {
      // Only remove the result once, since we'd like to test the case where the task eventually
      // succeeds.
      // Deserialize from a duplicate so that the position of the caller's buffer is never moved.
      // (Reading the buffer itself and calling rewind() afterwards would reset the position to 0,
      // which is only correct if the buffer happened to start at 0.)
      serializer.get().deserialize[TaskResult[_]](serializedData.duplicate()) match {
        case IndirectTaskResult(blockId, _) =>
          sparkEnv.blockManager.master.removeBlock(blockId)
          // removeBlock is asynchronous, so we need to wait until the block is actually gone.
          try {
            eventually(timeout(3.seconds), interval(200.milliseconds)) {
              assert(!sparkEnv.blockManager.master.contains(blockId))
            }
            removeBlockSuccessfully = true
          } catch {
            // The block was still present when the timeout expired; the test asserts on the flag.
            case NonFatal(_) => removeBlockSuccessfully = false
          }
        case _: DirectTaskResult[_] =>
          taskSetManager.abort("Internal error: expect only indirect results")
      }
      removedResult = true
    }
    super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
  }
}


/**
 * A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors
 * _before_ modifying the results in any way.
 *
 * @param env environment providing the closure serializer used to decode the results
 * @param scheduler scheduler handed through to the underlying [[TaskResultGetter]]
 */
private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl)
  extends TaskResultGetter(env, scheduler) {

  // Use the current thread so we can access its results synchronously
  protected override val getTaskResultExecutor = MoreExecutors.sameThreadExecutor()

  // DirectTaskResults that we receive from the executors
  private val _taskResults = new ArrayBuffer[DirectTaskResult[_]]

  /** The results recorded so far, in the order in which they were received. */
  def taskResults: Seq[DirectTaskResult[_]] = _taskResults

  /**
   * Records a deserialized copy of the incoming result and then delegates to the regular
   * implementation.
   */
  override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = {
    // work on a copy since the super class still needs to use the buffer
    val newBuffer = data.duplicate()
    _taskResults += env.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](newBuffer)
    super.enqueueSuccessfulTask(tsm, tid, data)
  }
}


/**
 * Tests related to handling task results (both direct and indirect).
 */
class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  // Set the RPC message size to be as small as possible (it must be an integer, so 1 is as small
  // as we can make it) so the tests don't take too long.
  def conf: SparkConf = new SparkConf().set("spark.rpc.message.maxSize", "1")

  test("handling results smaller than max RPC message size") {
    sc = new SparkContext("local", "test", conf)
    val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
    assert(result === 2)
  }

  test("handling results larger than max RPC message size") {
    sc = new SparkContext("local", "test", conf)
    val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
    val result =
      sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x)
    assert(result === 1.to(maxRpcMessageSize).toArray)

    val resultBlockId = TaskResultBlockId(0)
    assert(sc.env.blockManager.master.getLocations(resultBlockId).size === 0,
      "Expect result to be removed from the block manager.")
  }

  test("task retried if result missing from block manager") {
    // Set the maximum number of task failures to > 0, so that the task set isn't aborted
    // after the result is missing.
    sc = new SparkContext("local[1,2]", "test", conf)
    // If this test hangs, it's probably because no resource offers were made after the task
    // failed.
    val scheduler: TaskSchedulerImpl = sc.taskScheduler match {
      case taskScheduler: TaskSchedulerImpl =>
        taskScheduler
      case _ =>
        // fail() has type Nothing, so no unreachable dummy throw is needed to satisfy the type.
        fail("Expect local cluster to use TaskSchedulerImpl")
    }
    val resultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
    scheduler.taskResultGetter = resultGetter
    val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
    val result =
      sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x)
    assert(resultGetter.removeBlockSuccessfully)
    assert(result === 1.to(maxRpcMessageSize).toArray)

    // Make sure two tasks were run (one failed one, and a second retried one).
    assert(scheduler.nextTaskId.get() === 2)
  }

  /**
   * Make sure we are using the context classloader when deserializing failed TaskResults instead
   * of the Spark classloader.
   *
   * This test compiles a jar containing an exception and tests that when it is thrown on the
   * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown
   * exception as the cause.
   *
   * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing
   * the exception, resulting in an UnknownReason for the TaskEndResult.
   */
  test("failed task deserialized with the correct classloader (SPARK-11195)") {
    // compile a small jar containing an exception that will be thrown on an executor.
    val tempDir = Utils.createTempDir()
    val srcDir = new File(tempDir, "repro/")
    srcDir.mkdirs()
    val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath,
      """package repro;
        |
        |public class MyException extends Exception {
        |}
      """.stripMargin)
    val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty)
    val jarFile = new File(tempDir, s"testJar-${System.currentTimeMillis()}.jar")
    TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro"))

    // ensure we reset the classloader after the test completes
    val originalClassLoader = Thread.currentThread.getContextClassLoader
    try {
      // load the exception from the jar
      val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
      loader.addURL(jarFile.toURI.toURL)
      Thread.currentThread().setContextClassLoader(loader)
      val excClass: Class[_] = Utils.classForName("repro.MyException")

      // NOTE: we must run the cluster with "local" so that the executor can load the compiled
      // jar.
      sc = new SparkContext("local", "test", conf)
      val rdd = sc.parallelize(Seq(1), 1).map { _ =>
        val exc = excClass.newInstance().asInstanceOf[Exception]
        throw exc
      }

      // the driver should not have any problems resolving the exception class and determining
      // why the task failed.
      val exceptionMessage = intercept[SparkException] {
        rdd.collect()
      }.getMessage

      val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r
      val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r

      assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined)
      assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty)
    } finally {
      Thread.currentThread.setContextClassLoader(originalClassLoader)
    }
  }

  test("task result size is set on the driver, not the executors") {
    import InternalAccumulator._

    // Set up custom TaskResultGetter and TaskSchedulerImpl spy
    sc = new SparkContext("local", "test", conf)
    val scheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
    val spyScheduler = spy(scheduler)
    val resultGetter = new MyTaskResultGetter(sc.env, spyScheduler)
    val newDAGScheduler = new DAGScheduler(sc, spyScheduler)
    scheduler.taskResultGetter = resultGetter
    sc.dagScheduler = newDAGScheduler
    sc.taskScheduler = spyScheduler
    sc.taskScheduler.setDAGScheduler(newDAGScheduler)

    // Just run 1 task and capture the corresponding DirectTaskResult
    sc.parallelize(1 to 1, 1).count()
    val captor = ArgumentCaptor.forClass(classOf[DirectTaskResult[_]])
    verify(spyScheduler, times(1)).handleSuccessfulTask(any(), anyLong(), captor.capture())

    // When a task finishes, the executor sends a serialized DirectTaskResult to the driver
    // without setting the result size so as to avoid serializing the result again. Instead,
    // the result size is set later in TaskResultGetter on the driver before passing the
    // DirectTaskResult on to TaskSchedulerImpl. In this test, we capture the DirectTaskResult
    // before and after the result size is set.
    assert(resultGetter.taskResults.size === 1)
    val resBefore = resultGetter.taskResults.head
    val resAfter = captor.getValue
    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
    assert(resSizeBefore.exists(_ == 0L))
    assert(resSizeAfter.exists(_.toString.toLong > 0L))
  }

}