/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.api.python

import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter}
import java.net.{InetAddress, ServerSocket, Socket, SocketException}
import java.nio.charset.StandardCharsets
import java.util.Arrays

import scala.collection.mutable
import scala.collection.JavaConverters._

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.util.{RedirectThread, Utils}

/**
 * Creates Python worker processes that JVM tasks talk to over sockets.
 *
 * On UNIX-like systems workers are forked by a long-running daemon
 * (`pyspark/daemon.py`); elsewhere each worker is a directly launched
 * `pyspark/worker.py` process. All mutable state below is guarded by
 * `synchronized` on this factory instance.
 *
 * @param pythonExec the Python executable to launch (e.g. "python3")
 * @param envVars    extra environment variables for the worker/daemon processes
 */
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
  extends Logging {

  import PythonWorkerFactory._

  // Because forking processes from Java is expensive, we prefer to launch a single Python daemon
  // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently
  // only works on UNIX-based systems now because it uses signals for child management, so we can
  // also fall back to launching workers (pyspark/worker.py) directly.
  // Daemon mode requires UNIX signals for child management, so it is disabled on Windows.
  val useDaemon = !System.getProperty("os.name").startsWith("Windows")

  // Handle to the running daemon process, or null if not started. Guarded by `synchronized`.
  var daemon: Process = null
  // The daemon always listens on loopback; only this JVM connects to it.
  val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
  // Port the daemon reported on startup; 0 while no daemon is running.
  var daemonPort: Int = 0
  // Maps each daemon-forked worker's socket to the worker's PID (used by stopWorker).
  // Weak keys let entries vanish once the socket is otherwise unreachable.
  val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
  // Released-but-reusable worker sockets; reaped by MonitorThread after a timeout.
  val idleWorkers = new mutable.Queue[Socket]()
  // Timestamp of the last release, consulted by MonitorThread for idle-reaping.
  var lastActivity = 0L
  new MonitorThread().start()

  // Maps each directly-launched worker's socket to its Process (non-daemon mode only).
  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()

  // PYTHONPATH for workers: Spark's python dir, then caller-supplied, then inherited.
  val pythonPath = PythonUtils.mergePythonPaths(
    PythonUtils.sparkPythonPath,
    envVars.getOrElse("PYTHONPATH", ""),
    sys.env.getOrElse("PYTHONPATH", ""))

  /**
   * Returns a socket connected to a ready Python worker.
   *
   * In daemon mode an idle (previously released) worker is reused when available;
   * otherwise a fresh one is requested from the daemon. In non-daemon mode a new
   * worker process is launched per call.
   */
  def create(): Socket = {
    if (useDaemon) {
      synchronized {
        if (idleWorkers.size > 0) {
          return idleWorkers.dequeue()
        }
      }
      createThroughDaemon()
    } else {
      createSimpleWorker()
    }
  }

  /**
   * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself
   * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
   */
  private def createThroughDaemon(): Socket = {

    def createSocket(): Socket = {
      val socket = new Socket(daemonHost, daemonPort)
      // The daemon's first message on the new connection is the forked worker's PID,
      // or a negative value if the fork failed.
      val pid = new DataInputStream(socket.getInputStream).readInt()
      if (pid < 0) {
        throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
      }
      daemonWorkers.put(socket, pid)
      socket
    }

    synchronized {
      // Start the daemon if it hasn't been started
      startDaemon()

      // Attempt to connect, restart and retry once if it fails
      try {
        createSocket()
      } catch {
        case exc: SocketException =>
          logWarning("Failed to open socket to Python daemon:", exc)
          logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
          stopDaemon()
          startDaemon()
          createSocket()
      }
    }
  }

  /**
   * Launch a worker by executing worker.py directly and telling it to connect to us.
   */
  private def createSimpleWorker(): Socket = {
    var serverSocket: ServerSocket = null
    try {
      // Ephemeral loopback port (backlog 1) for the single worker to call back on.
      serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))

      // Create and start the worker
      val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker"))
      val workerEnv = pb.environment()
      workerEnv.putAll(envVars.asJava)
      workerEnv.put("PYTHONPATH", pythonPath)
      // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
      workerEnv.put("PYTHONUNBUFFERED", "YES")
      val worker = pb.start()

      // Redirect worker stdout and stderr
      redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)

      // Tell the worker our port (written as a line of text on the worker's stdin)
      val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8)
      out.write(serverSocket.getLocalPort + "\n")
      out.flush()

      // Wait for it to connect to our socket
      serverSocket.setSoTimeout(10000)
      try {
        val socket = serverSocket.accept()
        simpleWorkers.put(socket, worker)
        return socket
      } catch {
        case e: Exception =>
          throw new SparkException("Python worker did not connect back in time", e)
      }
    } finally {
      // The listening socket is only needed for the one callback connection.
      if (serverSocket != null) {
        serverSocket.close()
      }
    }
    // Unreachable in practice (the try either returns the socket or throws);
    // present only to satisfy the Socket return type.
    null
  }

  /**
   * Launch the pyspark/daemon.py process if it is not already running, and record
   * the callback port it prints on stdout. No-op when a daemon is already up.
   * On failure, captures the daemon's stderr (if any) into the thrown exception.
   */
  private def startDaemon() {
    synchronized {
      // Is it already running?
      if (daemon != null) {
        return
      }

      try {
        // Create and start the daemon
        val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon"))
        val workerEnv = pb.environment()
        workerEnv.putAll(envVars.asJava)
        workerEnv.put("PYTHONPATH", pythonPath)
        // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
        workerEnv.put("PYTHONUNBUFFERED", "YES")
        daemon = pb.start()

        // The daemon's first output is the port it listens on, as a raw big-endian int.
        val in = new DataInputStream(daemon.getInputStream)
        daemonPort = in.readInt()

        // Redirect daemon stdout and stderr
        redirectStreamsToStderr(in, daemon.getErrorStream)

      } catch {
        case e: Exception =>

          // If the daemon exists, wait for it to finish and get its stderr
          val stderr = Option(daemon)
            .flatMap { d => Utils.getStderr(d, PROCESS_WAIT_TIMEOUT_MS) }
            .getOrElse("")

          stopDaemon()

          if (stderr != "") {
            val formattedStderr = stderr.replace("\n", "\n ")
            val errorMessage = s"""
              |Error from python worker:
              | $formattedStderr
              |PYTHONPATH was:
              | $pythonPath
              |$e"""

            // Append error message from python daemon, but keep original stack trace
            val wrappedException = new SparkException(errorMessage.stripMargin)
            wrappedException.setStackTrace(e.getStackTrace)
            throw wrappedException
          } else {
            throw e
          }
      }

      // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
      // detect our disappearance.
    }
  }

  /**
   * Redirect the given streams to our stderr in separate threads.
206 */ 207 private def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream) { 208 try { 209 new RedirectThread(stdout, System.err, "stdout reader for " + pythonExec).start() 210 new RedirectThread(stderr, System.err, "stderr reader for " + pythonExec).start() 211 } catch { 212 case e: Exception => 213 logError("Exception in redirecting streams", e) 214 } 215 } 216 217 /** 218 * Monitor all the idle workers, kill them after timeout. 219 */ 220 private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") { 221 222 setDaemon(true) 223 224 override def run() { 225 while (true) { 226 synchronized { 227 if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) { 228 cleanupIdleWorkers() 229 lastActivity = System.currentTimeMillis() 230 } 231 } 232 Thread.sleep(10000) 233 } 234 } 235 } 236 237 private def cleanupIdleWorkers() { 238 while (idleWorkers.nonEmpty) { 239 val worker = idleWorkers.dequeue() 240 try { 241 // the worker will exit after closing the socket 242 worker.close() 243 } catch { 244 case e: Exception => 245 logWarning("Failed to close worker socket", e) 246 } 247 } 248 } 249 250 private def stopDaemon() { 251 synchronized { 252 if (useDaemon) { 253 cleanupIdleWorkers() 254 255 // Request shutdown of existing daemon by sending SIGTERM 256 if (daemon != null) { 257 daemon.destroy() 258 } 259 260 daemon = null 261 daemonPort = 0 262 } else { 263 simpleWorkers.mapValues(_.destroy()) 264 } 265 } 266 } 267 268 def stop() { 269 stopDaemon() 270 } 271 272 def stopWorker(worker: Socket) { 273 synchronized { 274 if (useDaemon) { 275 if (daemon != null) { 276 daemonWorkers.get(worker).foreach { pid => 277 // tell daemon to kill worker by pid 278 val output = new DataOutputStream(daemon.getOutputStream) 279 output.writeInt(pid) 280 output.flush() 281 daemon.getOutputStream.flush() 282 } 283 } 284 } else { 285 simpleWorkers.get(worker).foreach(_.destroy()) 286 } 287 } 288 worker.close() 289 } 290 291 def 
releaseWorker(worker: Socket) { 292 if (useDaemon) { 293 synchronized { 294 lastActivity = System.currentTimeMillis() 295 idleWorkers.enqueue(worker) 296 } 297 } else { 298 // Cleanup the worker socket. This will also cause the Python worker to exit. 299 try { 300 worker.close() 301 } catch { 302 case e: Exception => 303 logWarning("Failed to close worker socket", e) 304 } 305 } 306 } 307} 308 309private object PythonWorkerFactory { 310 val PROCESS_WAIT_TIMEOUT_MS = 10000 311 val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute 312} 313