/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy

import scala.collection.mutable.HashSet
import scala.concurrent.ExecutionContext
import scala.reflect.ClassTag
import scala.util.{Failure, Success}

import org.apache.log4j.Logger

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.util.{SparkExitCode, ThreadUtils, Utils}
/**
 * Proxy endpoint that relays driver submission and kill requests to the master.
 *
 * We currently don't support retrying a failed submission. In HA mode, the client submits the
 * request to all masters and lets whichever one is active handle it.
 */
private class ClientEndpoint(
    override val rpcEnv: RpcEnv,
    driverArgs: ClientArguments,
    masterEndpoints: Seq[RpcEndpointRef],
    conf: SparkConf)
  extends ThreadSafeRpcEndpoint with Logging {

  // A scheduled executor used to send messages at the specified time.
  private val forwardMessageThread =
    ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message")
  // Used to provide the implicit parameter of `Future` methods.
  private val forwardMessageExecutionContext =
    ExecutionContext.fromExecutor(forwardMessageThread, {
      case ie: InterruptedException => // Exit normally
      case e: Throwable =>
        logError(e.getMessage, e)
        System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
    })

  private val lostMasters = new HashSet[RpcAddress]
  private var activeMasterEndpoint: RpcEndpointRef = null

  override def onStart(): Unit = {
    driverArgs.cmd match {
      case "launch" =>
        // TODO: We could add an env variable here and intercept it in `sc.addJar`, truncating
        //       filesystem paths similarly to what YARN does. For now, we just require that
        //       people call `addJar` assuming the jar is in the same directory.
        val mainClass = "org.apache.spark.deploy.worker.DriverWrapper"

        val classPathConf = "spark.driver.extraClassPath"
        val classPathEntries = sys.props.get(classPathConf).toSeq.flatMap { cp =>
          cp.split(java.io.File.pathSeparator)
        }

        val libraryPathConf = "spark.driver.extraLibraryPath"
        val libraryPathEntries = sys.props.get(libraryPathConf).toSeq.flatMap { cp =>
          cp.split(java.io.File.pathSeparator)
        }

        val extraJavaOptsConf = "spark.driver.extraJavaOptions"
        val extraJavaOpts = sys.props.get(extraJavaOptsConf)
          .map(Utils.splitCommandString).getOrElse(Seq.empty)
        val sparkJavaOpts = Utils.sparkJavaOpts(conf)
        val javaOpts = sparkJavaOpts ++ extraJavaOpts
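        // The {{WORKER_URL}} and {{USER_JAR}} placeholders below are substituted by the worker
        // (see DriverRunner) with the actual worker URL and the localized jar path before the
        // driver process is launched.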
        val command = new Command(mainClass,
          Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions,
          sys.env, classPathEntries, libraryPathEntries, javaOpts)

        val driverDescription = new DriverDescription(
          driverArgs.jarUrl,
          driverArgs.memory,
          driverArgs.cores,
          driverArgs.supervise,
          command)
        asyncSendToMasterAndForwardReply[SubmitDriverResponse](
          RequestSubmitDriver(driverDescription))

      case "kill" =>
        val driverId = driverArgs.driverId
        asyncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId))
    }
  }

  /**
   * Send the message to all masters and forward each reply to this endpoint asynchronously.
   */
  private def asyncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = {
    for (masterEndpoint <- masterEndpoints) {
      masterEndpoint.ask[T](message).onComplete {
        case Success(v) => self.send(v)
        case Failure(e) =>
          logWarning(s"Error sending message to master $masterEndpoint", e)
      }(forwardMessageExecutionContext)
    }
  }
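
  // Illustrative example (assuming an HA setup with one ALIVE and one STANDBY master): both
  // masters receive the request, each reply is forwarded to `self`, and `receive` below ignores
  // the standby's rejection via Utils.responseFromBackup.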

  /* Find out the driver status, then exit the JVM. */
  def pollAndReportStatus(driverId: String): Unit = {
    // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread
    // is fine.
    logInfo("... waiting before polling master for driver state")
    Thread.sleep(5000)
    logInfo("... polling master for driver state")
    val statusResponse =
      activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId))
    if (statusResponse.found) {
      logInfo(s"State of $driverId is ${statusResponse.state.get}")
      // Worker node, if present
      (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match {
        case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) =>
          logInfo(s"Driver running on $hostPort ($id)")
        case _ =>
      }
      // Exception, if present
      statusResponse.exception match {
        case Some(e) =>
          logError(s"Exception from cluster was: $e")
          e.printStackTrace()
          System.exit(-1)
        case _ =>
          System.exit(0)
      }
    } else {
      logError(s"Cluster master did not recognize $driverId")
      System.exit(-1)
    }
  }

  override def receive: PartialFunction[Any, Unit] = {

    case SubmitDriverResponse(master, success, driverId, message) =>
      logInfo(message)
      if (success) {
        activeMasterEndpoint = master
        pollAndReportStatus(driverId.get)
      } else if (!Utils.responseFromBackup(message)) {
        // Only exit on a rejection from the active master; standby masters also reply.
        System.exit(-1)
      }

    case KillDriverResponse(master, driverId, success, message) =>
      logInfo(message)
      if (success) {
        activeMasterEndpoint = master
        pollAndReportStatus(driverId)
      } else if (!Utils.responseFromBackup(message)) {
        System.exit(-1)
      }
  }

  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
    if (!lostMasters.contains(remoteAddress)) {
      logError(s"Error connecting to master $remoteAddress.")
      lostMasters += remoteAddress
      // Note that this heuristic does not account for the fact that a Master can recover within
      // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This
      // is not currently a concern, however, because this client does not retry submissions.
      if (lostMasters.size >= masterEndpoints.size) {
        logError("No master is available, exiting.")
        System.exit(-1)
      }
    }
  }

  override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
    if (!lostMasters.contains(remoteAddress)) {
      logError(s"Error connecting to master ($remoteAddress).")
      logError(s"Cause was: $cause")
      lostMasters += remoteAddress
      if (lostMasters.size >= masterEndpoints.size) {
        logError("No master is available, exiting.")
        System.exit(-1)
      }
    }
  }

  override def onError(cause: Throwable): Unit = {
    logError("Error processing messages, exiting.")
    cause.printStackTrace()
    System.exit(-1)
  }

  override def onStop(): Unit = {
    forwardMessageThread.shutdownNow()
  }
}

/**
 * Executable utility for starting and terminating drivers inside of a standalone cluster.
 */
object Client {
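  // Illustrative usage (the exact argument set is defined by ClientArguments; hosts, paths, and
  // ids below are placeholders):
  //   spark-class org.apache.spark.deploy.Client launch spark://<host>:7077 hdfs://<path>/app.jar com.example.Main
  //   spark-class org.apache.spark.deploy.Client kill spark://<host>:7077 <driver-id>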
  def main(args: Array[String]): Unit = {
    // scalastyle:off println
    if (!sys.props.contains("SPARK_SUBMIT")) {
      println("WARNING: This client is deprecated and will be removed in a future version of Spark")
      println("Use ./bin/spark-submit with \"--master spark://host:port\"")
    }
    // scalastyle:on println

    val conf = new SparkConf()
    val driverArgs = new ClientArguments(args)

    if (!conf.contains("spark.rpc.askTimeout")) {
      conf.set("spark.rpc.askTimeout", "10s")
    }
    Logger.getRootLogger.setLevel(driverArgs.logLevel)

    val rpcEnv =
      RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))

    val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL).
      map(rpcEnv.setupEndpointRef(_, Master.ENDPOINT_NAME))
    rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf))

    rpcEnv.awaitTermination()
  }
}