1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.api.python
19
20import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter}
21import java.net.{InetAddress, ServerSocket, Socket, SocketException}
22import java.nio.charset.StandardCharsets
23import java.util.Arrays
24
25import scala.collection.mutable
26import scala.collection.JavaConverters._
27
28import org.apache.spark._
29import org.apache.spark.internal.Logging
30import org.apache.spark.util.{RedirectThread, Utils}
31
/**
 * Creates Python worker processes for PySpark tasks. On UNIX-like systems a single
 * long-running daemon (pyspark/daemon.py) is launched once and asked to fork workers
 * cheaply; on Windows (no signal-based child management) each worker is a directly
 * spawned pyspark.worker process instead.
 */
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
  extends Logging {

  import PythonWorkerFactory._

  // Because forking processes from Java is expensive, we prefer to launch a single Python daemon
  // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently
  // only works on UNIX-based systems now because it uses signals for child management, so we can
  // also fall back to launching workers (pyspark/worker.py) directly.
  val useDaemon = !System.getProperty("os.name").startsWith("Windows")

  // Handle to the daemon process; null until startDaemon() succeeds, reset by stopDaemon().
  var daemon: Process = null
  val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
  var daemonPort: Int = 0
  // Maps each daemon-forked worker socket to the worker's pid so stopWorker() can ask the
  // daemon to kill it. Weak keys: entries disappear once a socket is no longer referenced.
  val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
  // Workers handed back via releaseWorker(), available for reuse by create().
  val idleWorkers = new mutable.Queue[Socket]()
  // Time of the last worker release; MonitorThread reaps idle workers after a quiet period.
  var lastActivity = 0L
  new MonitorThread().start()

  // Non-daemon mode: maps each worker socket to its dedicated Python process.
  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()

  val pythonPath = PythonUtils.mergePythonPaths(
    PythonUtils.sparkPythonPath,
    envVars.getOrElse("PYTHONPATH", ""),
    sys.env.getOrElse("PYTHONPATH", ""))

  /**
   * Return a socket connected to a Python worker, reusing an idle daemon worker if one
   * is available, otherwise creating a new worker.
   */
  def create(): Socket = {
    if (useDaemon) {
      synchronized {
        if (idleWorkers.size > 0) {
          return idleWorkers.dequeue()
        }
      }
      createThroughDaemon()
    } else {
      createSimpleWorker()
    }
  }

  /**
   * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself
   * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
   */
  private def createThroughDaemon(): Socket = {

    def createSocket(): Socket = {
      val socket = new Socket(daemonHost, daemonPort)
      // The daemon's first reply on the socket is the forked worker's pid (negative on failure).
      val pid = new DataInputStream(socket.getInputStream).readInt()
      if (pid < 0) {
        throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
      }
      daemonWorkers.put(socket, pid)
      socket
    }

    synchronized {
      // Start the daemon if it hasn't been started
      startDaemon()

      // Attempt to connect, restart and retry once if it fails
      try {
        createSocket()
      } catch {
        case exc: SocketException =>
          logWarning("Failed to open socket to Python daemon:", exc)
          logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
          stopDaemon()
          startDaemon()
          createSocket()
      }
    }
  }

  /**
   * Launch a worker by executing worker.py directly and telling it to connect to us.
   */
  private def createSimpleWorker(): Socket = {
    var serverSocket: ServerSocket = null
    try {
      serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))

      // Create and start the worker
      val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker"))
      val workerEnv = pb.environment()
      workerEnv.putAll(envVars.asJava)
      workerEnv.put("PYTHONPATH", pythonPath)
      // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
      workerEnv.put("PYTHONUNBUFFERED", "YES")
      val worker = pb.start()

      // Redirect worker stdout and stderr
      redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)

      // Tell the worker our port
      val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8)
      out.write(serverSocket.getLocalPort + "\n")
      out.flush()

      // Wait for it to connect to our socket
      serverSocket.setSoTimeout(10000)
      try {
        val socket = serverSocket.accept()
        simpleWorkers.put(socket, worker)
        socket
      } catch {
        case e: Exception =>
          throw new SparkException("Python worker did not connect back in time", e)
      }
    } finally {
      if (serverSocket != null) {
        // Safe on the success path too: closing the listening socket does not affect the
        // already-accepted connection that we return above.
        serverSocket.close()
      }
    }
  }

  /**
   * Start the daemon (pyspark/daemon.py) if it is not already running and read back the port
   * it bound. On failure, the daemon's stderr (if any) is captured and attached to the thrown
   * exception so the Python-side error is visible, while keeping the original stack trace.
   */
  private def startDaemon() {
    synchronized {
      // Is it already running?
      if (daemon != null) {
        return
      }

      try {
        // Create and start the daemon
        val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon"))
        val workerEnv = pb.environment()
        workerEnv.putAll(envVars.asJava)
        workerEnv.put("PYTHONPATH", pythonPath)
        // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
        workerEnv.put("PYTHONUNBUFFERED", "YES")
        daemon = pb.start()

        // The first thing the daemon writes to stdout is the port it is listening on.
        val in = new DataInputStream(daemon.getInputStream)
        daemonPort = in.readInt()

        // Redirect daemon stdout and stderr
        redirectStreamsToStderr(in, daemon.getErrorStream)

      } catch {
        case e: Exception =>

          // If the daemon exists, wait for it to finish and get its stderr
          val stderr = Option(daemon)
            .flatMap { d => Utils.getStderr(d, PROCESS_WAIT_TIMEOUT_MS) }
            .getOrElse("")

          stopDaemon()

          if (stderr != "") {
            val formattedStderr = stderr.replace("\n", "\n  ")
            val errorMessage = s"""
              |Error from python worker:
              |  $formattedStderr
              |PYTHONPATH was:
              |  $pythonPath
              |$e"""

            // Append error message from python daemon, but keep original stack trace
            val wrappedException = new SparkException(errorMessage.stripMargin)
            wrappedException.setStackTrace(e.getStackTrace)
            throw wrappedException
          } else {
            throw e
          }
      }

      // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
      // detect our disappearance.
    }
  }

  /**
   * Redirect the given streams to our stderr in separate threads.
   */
  private def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream) {
    try {
      new RedirectThread(stdout, System.err, "stdout reader for " + pythonExec).start()
      new RedirectThread(stderr, System.err, "stderr reader for " + pythonExec).start()
    } catch {
      case e: Exception =>
        logError("Exception in redirecting streams", e)
    }
  }

  /**
   * Monitor all the idle workers, kill them after timeout.
   */
  private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {

    setDaemon(true)

    override def run() {
      while (true) {
        // Lock the enclosing factory, not this thread object: a bare `synchronized` here would
        // use MonitorThread's own monitor and therefore not exclude releaseWorker()/stopDaemon(),
        // which guard lastActivity and idleWorkers with the factory's lock.
        PythonWorkerFactory.this.synchronized {
          if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
            cleanupIdleWorkers()
            lastActivity = System.currentTimeMillis()
          }
        }
        Thread.sleep(10000)
      }
    }
  }

  /**
   * Close every idle worker socket; the daemon-forked workers exit when their socket closes.
   */
  private def cleanupIdleWorkers() {
    while (idleWorkers.nonEmpty) {
      val worker = idleWorkers.dequeue()
      try {
        // the worker will exit after closing the socket
        worker.close()
      } catch {
        case e: Exception =>
          logWarning("Failed to close worker socket", e)
      }
    }
  }

  /**
   * Tear down the daemon (or, in non-daemon mode, all directly launched workers).
   */
  private def stopDaemon() {
    synchronized {
      if (useDaemon) {
        cleanupIdleWorkers()

        // Request shutdown of existing daemon by sending SIGTERM
        if (daemon != null) {
          daemon.destroy()
        }

        daemon = null
        daemonPort = 0
      } else {
        // Iterate strictly: the previous `mapValues(_.destroy())` built a lazy view, so the
        // destroy() side effect never actually ran and worker processes were leaked.
        simpleWorkers.values.foreach(_.destroy())
      }
    }
  }

  /** Stop this factory, shutting down the daemon or any directly launched workers. */
  def stop() {
    stopDaemon()
  }

  /**
   * Stop a single worker: in daemon mode, send its pid to the daemon (which kills it);
   * otherwise destroy its dedicated process. The worker socket is closed in both cases.
   */
  def stopWorker(worker: Socket) {
    synchronized {
      if (useDaemon) {
        if (daemon != null) {
          daemonWorkers.get(worker).foreach { pid =>
            // tell daemon to kill worker by pid
            val output = new DataOutputStream(daemon.getOutputStream)
            output.writeInt(pid)
            output.flush()
            daemon.getOutputStream.flush()
          }
        }
      } else {
        simpleWorkers.get(worker).foreach(_.destroy())
      }
    }
    worker.close()
  }

  /**
   * Return a worker after use. Daemon workers are queued for reuse (and later reaped by the
   * monitor thread); direct workers are simply closed, which makes the Python process exit.
   */
  def releaseWorker(worker: Socket) {
    if (useDaemon) {
      synchronized {
        lastActivity = System.currentTimeMillis()
        idleWorkers.enqueue(worker)
      }
    } else {
      // Cleanup the worker socket. This will also cause the Python worker to exit.
      try {
        worker.close()
      } catch {
        case e: Exception =>
          logWarning("Failed to close worker socket", e)
      }
    }
  }
}
308
/** Timing constants shared by [[PythonWorkerFactory]] instances. */
private object PythonWorkerFactory {

  // Idle workers that see no activity for this long are closed by the monitor thread (1 minute).
  val IDLE_WORKER_TIMEOUT_MS = 60000

  // Upper bound on how long we wait to collect a failed daemon process's stderr output.
  val PROCESS_WAIT_TIMEOUT_MS = 10000
}
313