1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.scheduler.cluster
19
20import scala.concurrent.{ExecutionContext, Future}
21import scala.util.{Failure, Success}
22import scala.util.control.NonFatal
23
24import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
25
26import org.apache.spark.SparkContext
27import org.apache.spark.internal.Logging
28import org.apache.spark.rpc._
29import org.apache.spark.scheduler._
30import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
31import org.apache.spark.ui.JettyUtils
32import org.apache.spark.util.{RpcUtils, ThreadUtils}
33
34/**
35 * Abstract Yarn scheduler backend that contains common logic
36 * between the client and cluster Yarn scheduler backends.
37 */
private[spark] abstract class YarnSchedulerBackend(
    scheduler: TaskSchedulerImpl,
    sc: SparkContext)
  extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {

  // When "spark.scheduler.minRegisteredResourcesRatio" is absent from the conf, use 0.8;
  // otherwise defer to whatever the superclass computed.
  override val minRegisteredRatio =
    if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
      0.8
    } else {
      super.minRegisteredRatio
    }

  /**
   * Number of executors the application expects. It is only read in this class (see
   * [[sufficientResourcesRegistered()]]); it is never assigned here, so subclasses are
   * presumably expected to set it.
   */
  protected var totalExpectedExecutors = 0

  private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv)

  private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint(
    YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint)

  private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf)

  /** Application ID. */
  protected var appId: Option[ApplicationId] = None

  /** Attempt ID. This is unset for client-mode schedulers. */
  private var attemptId: Option[ApplicationAttemptId] = None

  /** Scheduler extension services. */
  private val services: SchedulerExtensionServices = new SchedulerExtensionServices()

  // Set to true by the first AM registration; any later registration then triggers reset().
  private var shouldResetOnAmRegister = false

  /**
   * Bind to YARN. This *must* be done before calling [[start()]].
   *
   * @param appId YARN application ID
   * @param attemptId Optional YARN attempt ID
   */
  protected def bindToYarn(appId: ApplicationId, attemptId: Option[ApplicationAttemptId]): Unit = {
    this.appId = Some(appId)
    this.attemptId = attemptId
  }

  /**
   * Start the scheduler extension services, then the underlying backend.
   *
   * Fails with an `IllegalArgumentException` if [[bindToYarn]] has not set the application ID.
   */
  override def start(): Unit = {
    require(appId.isDefined, "application ID unset")
    // The require above guarantees the Option is defined, so this runs exactly once.
    appId.foreach { boundAppId =>
      services.start(SchedulerExtensionServiceBinding(sc, boundAppId, attemptId))
    }
    super.start()
  }

  /**
   * Stop the backend. The extension services are stopped even if stopping the backend throws.
   */
  override def stop(): Unit = {
    try {
      // SPARK-12009: To prevent the Yarn allocator from requesting replacements for the
      // executors that were stopped by the SchedulerBackend.
      requestTotalExecutors(0, 0, Map.empty)
      super.stop()
    } finally {
      services.stop()
    }
  }

  /**
   * Get the attempt ID for this run, if the cluster manager supports multiple
   * attempts. Applications run in client mode will not have attempt IDs.
   * This attempt ID only includes attempt counter, like "1", "2".
   *
   * @return The application attempt id, if available.
   */
  override def applicationAttemptId(): Option[String] = {
    attemptId.map(_.getAttemptId.toString)
  }

  /**
   * Get an application ID associated with the job.
   * This returns the string value of [[appId]] if set; otherwise it logs a warning and
   * returns the ID provided by the superclass.
   * @return The application ID
   */
  override def applicationId(): String = {
    appId.map(_.toString).getOrElse {
      logWarning("Application ID is not initialized yet.")
      super.applicationId
    }
  }

  /**
   * Request executors from the ApplicationMaster by specifying the total number desired.
   * This includes executors already pending or running.
   */
  override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = {
    yarnSchedulerEndpointRef.ask[Boolean](
      RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount))
  }

  /**
   * Request that the ApplicationMaster kill the specified executors.
   */
  override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = {
    yarnSchedulerEndpointRef.ask[Boolean](KillExecutors(executorIds))
  }

  /**
   * @return true once the number of registered executors reaches
   *         `totalExpectedExecutors * minRegisteredRatio`.
   */
  override def sufficientResourcesRegistered(): Boolean = {
    totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
  }

  /**
   * Add filters to the SparkUI.
   *
   * A non-empty `proxyBase` is published through the "spark.ui.proxyBase" system property.
   * The filter itself is only installed when both `filterName` and `filterParams` are
   * non-null and non-empty.
   */
  private def addWebUIFilter(
      filterName: String,
      filterParams: Map[String, String],
      proxyBase: String): Unit = {
    if (proxyBase != null && proxyBase.nonEmpty) {
      System.setProperty("spark.ui.proxyBase", proxyBase)
    }

    val hasFilter =
      filterName != null && filterName.nonEmpty &&
      filterParams != null && filterParams.nonEmpty
    if (hasFilter) {
      logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
      // NOTE(review): this replaces any value already stored under "spark.ui.filters" rather
      // than appending to it -- confirm that dropping previously configured filters is intended.
      conf.set("spark.ui.filters", filterName)
      filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
      scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
    }
  }

  /** Create the YARN-specific driver endpoint (see [[YarnDriverEndpoint]]). */
  override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
    new YarnDriverEndpoint(rpcEnv, properties)
  }

  /**
   * Reset the state of the SchedulerBackend to its initial state. This happens when the AM
   * fails and a new AM re-registers itself with the driver; the stale state held by the driver
   * should then be cleaned up. The executor allocation manager, if any, is reset as well.
   */
  override protected def reset(): Unit = {
    super.reset()
    sc.executorAllocationManager.foreach(_.reset())
  }

  /**
   * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected.
   * This endpoint communicates with the executors and queries the AM for an executor's exit
   * status when the executor is disconnected.
   */
  private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
      extends DriverEndpoint(rpcEnv, sparkProperties) {

    /**
     * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint
     * handles it by assuming the Executor was lost for a bad reason and removes the executor
     * immediately.
     *
     * In YARN's case however it is crucial to talk to the application master and ask why the
     * executor had exited. If the executor exited for some reason unrelated to the running tasks
     * (e.g., preemption), according to the application master, then we pass that information down
     * to the TaskSetManager to inform the TaskSetManager that tasks on that lost executor should
     * not count towards a job failure.
     */
    override def onDisconnected(rpcAddress: RpcAddress): Unit = {
      addressToExecutorId.get(rpcAddress).foreach { executorId =>
        // Only query the AM when disableExecutor reports success for this executor.
        if (disableExecutor(executorId)) {
          yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress)
        }
      }
    }
  }

  /**
   * An [[RpcEndpoint]] that communicates with the ApplicationMaster.
   */
  private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv)
    extends ThreadSafeRpcEndpoint with Logging {
    // Reference to the currently registered AM; cleared again when that AM disconnects.
    private var amEndpoint: Option[RpcEndpointRef] = None

    /**
     * Ask the AM why `executorId` was lost and then ask the driver endpoint to remove the
     * executor with that reason. If the AM does not answer, or no AM is registered, the
     * executor is removed with a `SlaveLost` reason instead.
     *
     * @param executorId ID of the executor that disconnected from the driver
     * @param executorRpcAddress RPC address of that executor (used for logging only)
     */
    private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver(
        executorId: String,
        executorRpcAddress: RpcAddress): Unit = {
      val removeExecutorMessage = amEndpoint match {
        case Some(am) =>
          val lossReasonRequest = GetExecutorLossReason(executorId)
          am.ask[ExecutorLossReason](lossReasonRequest, askTimeout)
            .map { reason => RemoveExecutor(executorId, reason) }(ThreadUtils.sameThread)
            .recover {
              case NonFatal(e) =>
                logWarning(s"Attempted to get executor loss reason" +
                  s" for executor id ${executorId} at RPC address ${executorRpcAddress}," +
                  s" but got no response. Marking as slave lost.", e)
                RemoveExecutor(executorId, SlaveLost())
            }(ThreadUtils.sameThread)
        case None =>
          logWarning("Attempted to check for an executor loss reason" +
            " before the AM has registered!")
          Future.successful(RemoveExecutor(executorId, SlaveLost("AM is not yet registered.")))
      }

      removeExecutorMessage
        .flatMap { message =>
          driverEndpoint.ask[Boolean](message)
        }(ThreadUtils.sameThread)
        .onFailure {
          case NonFatal(e) => logError(
            s"Error requesting driver to remove executor $executorId after disconnection.", e)
        }(ThreadUtils.sameThread)
    }

    /**
     * Forward `request` to the registered AM and relay its Boolean answer, or its failure,
     * to `context`. When no AM is registered, log `notRegisteredWarning` and reply `false`.
     */
    private def forwardToAm(
        request: Any,
        context: RpcCallContext,
        notRegisteredWarning: String): Unit = {
      amEndpoint match {
        case Some(am) =>
          am.ask[Boolean](request).andThen {
            case Success(b) => context.reply(b)
            case Failure(NonFatal(e)) =>
              logError(s"Sending $request to AM was unsuccessful", e)
              context.sendFailure(e)
          }(ThreadUtils.sameThread)
        case None =>
          logWarning(notRegisteredWarning)
          context.reply(false)
      }
    }

    override def receive: PartialFunction[Any, Unit] = {
      case RegisterClusterManager(am) =>
        logInfo(s"ApplicationMaster registered as $am")
        amEndpoint = Option(am)
        if (!shouldResetOnAmRegister) {
          shouldResetOnAmRegister = true
        } else {
          // An AM was already registered before; this potentially means that the AM failed and
          // a new one registered after the failure. This will only happen in yarn-client mode.
          reset()
        }

      case AddWebUIFilter(filterName, filterParams, proxyBase) =>
        addWebUIFilter(filterName, filterParams, proxyBase)

      case r @ RemoveExecutor(executorId, reason) =>
        logWarning(reason.toString)
        driverEndpoint.ask[Boolean](r).onFailure {
          case e =>
            logError("Error requesting driver to remove executor" +
              s" $executorId for reason $reason", e)
        }(ThreadUtils.sameThread)
    }


    override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
      case r: RequestExecutors =>
        forwardToAm(r, context, "Attempted to request executors before the AM has registered!")

      case k: KillExecutors =>
        forwardToAm(k, context, "Attempted to kill executors before the AM has registered!")

      case RetrieveLastAllocatedExecutorId =>
        context.reply(currentExecutorIdCounter)
    }

    override def onDisconnected(remoteAddress: RpcAddress): Unit = {
      // Only forget the AM if the disconnected peer is the AM we currently know about.
      if (amEndpoint.exists(_.address == remoteAddress)) {
        logWarning(s"ApplicationMaster has disassociated: $remoteAddress")
        amEndpoint = None
      }
    }
  }
}
312
/** Companion holding constants for the YARN scheduler backend. */
private[spark] object YarnSchedulerBackend {

  /** Name of the YARN scheduler RPC endpoint. */
  val ENDPOINT_NAME: String = "YarnScheduler"
}
316