/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy

import scala.collection.mutable.HashSet
import scala.concurrent.ExecutionContext
import scala.reflect.ClassTag
import scala.util.{Failure, Success}

import org.apache.log4j.Logger

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.util.{SparkExitCode, ThreadUtils, Utils}

/**
 * Proxy that relays messages to the driver.
 *
 * We currently don't support retry if submission fails. In HA mode, the client submits the
 * request to all masters and uses whichever one can handle it.
 */
private class ClientEndpoint(
    override val rpcEnv: RpcEnv,
    driverArgs: ClientArguments,
    masterEndpoints: Seq[RpcEndpointRef],
    conf: SparkConf)
  extends ThreadSafeRpcEndpoint with Logging {

  // A scheduled executor used to send messages at the specified time.
  private val forwardMessageThread =
    ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message")
  // Used to provide the implicit parameter of `Future` methods.
  private val forwardMessageExecutionContext =
    ExecutionContext.fromExecutor(forwardMessageThread,
      t => t match {
        case ie: InterruptedException => // Exit normally
        case e: Throwable =>
          logError(e.getMessage, e)
          System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
      })

  private val lostMasters = new HashSet[RpcAddress]
  private var activeMasterEndpoint: RpcEndpointRef = null

  override def onStart(): Unit = {
    driverArgs.cmd match {
      case "launch" =>
        // TODO: We could add an env variable here and intercept it in `sc.addJar` that would
        //       truncate filesystem paths similar to what YARN does. For now, we just require
        //       people call `addJar` assuming the jar is in the same directory.
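        // Launch the user class indirectly through DriverWrapper, which runs on a worker,
        // sets up the driver-side environment, and then invokes the user's main class.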
        val mainClass = "org.apache.spark.deploy.worker.DriverWrapper"

        val classPathConf = "spark.driver.extraClassPath"
        val classPathEntries = sys.props.get(classPathConf).toSeq.flatMap { cp =>
          cp.split(java.io.File.pathSeparator)
        }

        val libraryPathConf = "spark.driver.extraLibraryPath"
        val libraryPathEntries = sys.props.get(libraryPathConf).toSeq.flatMap { cp =>
          cp.split(java.io.File.pathSeparator)
        }

        val extraJavaOptsConf = "spark.driver.extraJavaOptions"
        val extraJavaOpts = sys.props.get(extraJavaOptsConf)
          .map(Utils.splitCommandString).getOrElse(Seq.empty)
        val sparkJavaOpts = Utils.sparkJavaOpts(conf)
        val javaOpts = sparkJavaOpts ++ extraJavaOpts
        val command = new Command(mainClass,
          Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions,
          sys.env, classPathEntries, libraryPathEntries, javaOpts)

        val driverDescription = new DriverDescription(
          driverArgs.jarUrl,
          driverArgs.memory,
          driverArgs.cores,
          driverArgs.supervise,
          command)
        asyncSendToMasterAndForwardReply[SubmitDriverResponse](
          RequestSubmitDriver(driverDescription))

      case "kill" =>
        val driverId = driverArgs.driverId
        asyncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId))
    }
  }

  /**
   * Send the message to the master and forward the reply to self asynchronously.
   */
  private def asyncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = {
    for (masterEndpoint <- masterEndpoints) {
      masterEndpoint.ask[T](message).onComplete {
        case Success(v) => self.send(v)
        case Failure(e) =>
          logWarning(s"Error sending message to master $masterEndpoint", e)
      }(forwardMessageExecutionContext)
    }
  }

  /* Find out the driver's status, then exit the JVM. */
  def pollAndReportStatus(driverId: String): Unit = {
    // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread
    // is fine.
    logInfo("... waiting before polling master for driver state")
    Thread.sleep(5000)
    logInfo("... polling master for driver state")
    val statusResponse =
      activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId))
    if (statusResponse.found) {
      logInfo(s"State of $driverId is ${statusResponse.state.get}")
      // Worker node, if present
      (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match {
        case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) =>
          logInfo(s"Driver running on $hostPort ($id)")
        case _ =>
      }
      // Exception, if present
      statusResponse.exception match {
        case Some(e) =>
          logError(s"Exception from cluster was: $e")
          e.printStackTrace()
          System.exit(-1)
        case _ =>
          System.exit(0)
      }
    } else {
      logError(s"ERROR: Cluster master did not recognize $driverId")
      System.exit(-1)
    }
  }

  override def receive: PartialFunction[Any, Unit] = {

    case SubmitDriverResponse(master, success, driverId, message) =>
      logInfo(message)
      if (success) {
        activeMasterEndpoint = master
        pollAndReportStatus(driverId.get)
      } else if (!Utils.responseFromBackup(message)) {
        System.exit(-1)
      }

    case KillDriverResponse(master, driverId, success, message) =>
      logInfo(message)
      if (success) {
        activeMasterEndpoint = master
        pollAndReportStatus(driverId)
      } else if (!Utils.responseFromBackup(message)) {
        System.exit(-1)
      }
  }

  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
    if (!lostMasters.contains(remoteAddress)) {
      logError(s"Error connecting to master $remoteAddress.")
      lostMasters += remoteAddress
      // Note that this heuristic does not account for the fact that a Master can recover within
      // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This
      // is not currently a concern, however, because this client does not retry submissions.
      if (lostMasters.size >= masterEndpoints.size) {
        logError("No master is available, exiting.")
        System.exit(-1)
      }
    }
  }

  override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
    if (!lostMasters.contains(remoteAddress)) {
      logError(s"Error connecting to master ($remoteAddress).")
      logError(s"Cause was: $cause")
      lostMasters += remoteAddress
      if (lostMasters.size >= masterEndpoints.size) {
        logError("No master is available, exiting.")
        System.exit(-1)
      }
    }
  }

  override def onError(cause: Throwable): Unit = {
    logError("Error processing messages, exiting.")
    cause.printStackTrace()
    System.exit(-1)
  }

  override def onStop(): Unit = {
    forwardMessageThread.shutdownNow()
  }
}

/**
 * Executable utility for starting and terminating drivers inside of a standalone cluster.
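 *
 * A sketch of typical invocations (the exact options are defined by `ClientArguments`; the
 * master URL, jar URL, class name, and driver id below are placeholders):
 * {{{
 *   # Submit a driver to a standalone cluster:
 *   ./bin/spark-class org.apache.spark.deploy.Client launch \
 *     spark://master:7077 http://fileserver/app.jar com.example.MainClass appArg1
 *
 *   # Terminate a previously submitted driver:
 *   ./bin/spark-class org.apache.spark.deploy.Client kill spark://master:7077 <driver-id>
 * }}}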
 */
object Client {
  def main(args: Array[String]): Unit = {
    // scalastyle:off println
    if (!sys.props.contains("SPARK_SUBMIT")) {
      println("WARNING: This client is deprecated and will be removed in a future version of Spark")
      println("Use ./bin/spark-submit with \"--master spark://host:port\"")
    }
    // scalastyle:on println

    val conf = new SparkConf()
    val driverArgs = new ClientArguments(args)

    if (!conf.contains("spark.rpc.askTimeout")) {
      conf.set("spark.rpc.askTimeout", "10s")
    }
    Logger.getRootLogger.setLevel(driverArgs.logLevel)

    val rpcEnv =
      RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))

    val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL)
      .map(rpcEnv.setupEndpointRef(_, Master.ENDPOINT_NAME))
    rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf))

    rpcEnv.awaitTermination()
  }
}