1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.util
19
20import java.util.EventListener
21
22import org.apache.spark.TaskContext
23import org.apache.spark.annotation.DeveloperApi
24
/**
 * :: DeveloperApi ::
 *
 * Callback hook that is invoked when a task's execution completes.
 */
@DeveloperApi
trait TaskCompletionListener extends EventListener {

  /**
   * Called when the task's execution completes.
   *
   * @param context the [[TaskContext]] of the task that completed
   */
  def onTaskCompletion(context: TaskContext): Unit
}
34
35
/**
 * :: DeveloperApi ::
 *
 * Callback hook that is invoked when a task's execution encounters an error.
 *
 * `onTaskFailure` may be called more than once, so every operation an
 * implementation performs here must be idempotent.
 */
@DeveloperApi
trait TaskFailureListener extends EventListener {

  /**
   * Called when the task's execution encounters an error.
   *
   * @param context the [[TaskContext]] of the failing task
   * @param error   the error that was encountered
   */
  def onTaskFailure(context: TaskContext, error: Throwable): Unit
}
46
47
/**
 * Exception thrown when there is an exception in executing the callback in TaskCompletionListener.
 *
 * @param errorMessages messages from the failing listener callbacks; a single message is
 *                      reported verbatim, several are numbered as "Exception i: ..."
 * @param previousError an earlier exception raised in the task, if any; when present, its
 *                      message and stack trace are appended to [[getMessage]]
 */
private[spark]
class TaskCompletionListenerException(
    errorMessages: Seq[String],
    val previousError: Option[Throwable] = None)
  extends RuntimeException {

  override def getMessage: String = {
    // Each part is bound to its own val on purpose. The form
    // `if (...) {...} else {...} + suffix` parses as `if (...) {...} else ({...} + suffix)`,
    // which dropped the previous error whenever exactly one listener had failed.
    val listenerErrorMessage =
      if (errorMessages.size == 1) {
        errorMessages.head
      } else {
        errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
      }
    val previousErrorMessage = previousError.map { e =>
      "\n\nPrevious exception in task: " + e.getMessage + "\n" +
        e.getStackTrace.mkString("\t", "\n\t", "")
    }.getOrElse("")
    listenerErrorMessage + previousErrorMessage
  }
}
69