/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy

import java.io.File
import java.net.URI

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

import org.apache.spark.{SparkConf, SparkUserAppException}
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.internal.config._
import org.apache.spark.util.{RedirectThread, Utils}

/**
 * A main class used to launch Python applications. It executes python as a
 * subprocess and then has it connect back to the JVM to access system properties, etc.
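 *
 * The expected arguments are the primary Python file, a comma-separated list of additional
 * Python files to add to the PYTHONPATH, and any remaining application arguments. For
 * illustration, an invocation might look roughly like (paths hypothetical):
 * {{{
 *   PythonRunner.main(Array("/path/to/app.py", "dep1.py,dep2.py", "--appArg", "value"))
 * }}}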
 */
object PythonRunner {
  def main(args: Array[String]): Unit = {
    val pythonFile = args(0)
    val pyFiles = args(1)
    val otherArgs = args.slice(2, args.length)
    val sparkConf = new SparkConf()
    val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
      .orElse(sparkConf.get(PYSPARK_PYTHON))
      .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
      .orElse(sys.env.get("PYSPARK_PYTHON"))
      .getOrElse("python")

    // Format python file paths before adding them to the PYTHONPATH
    val formattedPythonFile = formatPath(pythonFile)
    val formattedPyFiles = formatPaths(pyFiles)

    // Launch a Py4J gateway server for the process to connect to; this will let it see our
    // Java system properties and such
    val gatewayServer = new py4j.GatewayServer(null, 0)
    val thread = new Thread(new Runnable() {
      override def run(): Unit = Utils.logUncaughtExceptions {
        gatewayServer.start()
      }
    })
    thread.setName("py4j-gateway-init")
    thread.setDaemon(true)
    thread.start()

    // Wait until the gateway server has started, so that we know which port it is bound to.
    // `gatewayServer.start()` will start a new thread and run the server code there, after
    // initializing the socket, so the thread started above will end as soon as the server is
    // ready to serve connections.
    thread.join()

    // Build up a PYTHONPATH that includes the Spark assembly (where this class is), the
    // python directories in SPARK_HOME (if set), and any files in the pyFiles argument
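    // For illustration only: on Linux the merged path below might look roughly like
    // "/opt/app/deps.py:/opt/spark/python/lib/pyspark.zip:/existing/pythonpath" (paths hypothetical)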
    val pathElements = new ArrayBuffer[String]
    pathElements ++= formattedPyFiles
    pathElements += PythonUtils.sparkPythonPath
    pathElements += sys.env.getOrElse("PYTHONPATH", "")
    val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)

    // Launch Python process
    val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
    val env = builder.environment()
    env.put("PYTHONPATH", pythonPath)
    // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
    env.put("PYTHONUNBUFFERED", "YES") // the value must be a non-empty string
    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
    // Pass the conf spark.pyspark.python to the python process; the only way to pass this
    // information to the python process is through an environment variable.
    sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
    builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
    try {
      val process = builder.start()

      new RedirectThread(process.getInputStream, System.out, "redirect output").start()

      val exitCode = process.waitFor()
      if (exitCode != 0) {
        throw new SparkUserAppException(exitCode)
      }
    } finally {
      gatewayServer.shutdown()
    }
  }

  /**
   * Format the python file path so that it can be added to the PYTHONPATH correctly.
   *
   * Python does not understand URI schemes in paths. Before adding python files to the
   * PYTHONPATH, we need to extract the path from the URI. This is safe to do because we
   * currently only support local python files.
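   *
   * For illustration (behavior per the implementation below; paths hypothetical):
   * {{{
   *   formatPath("/home/user/app.py")          // "/home/user/app.py"
   *   formatPath("file:/home/user/app.py")     // "/home/user/app.py"
   *   formatPath("hdfs://host/user/app.py")    // throws IllegalArgumentException
   * }}}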
   */
  def formatPath(path: String, testWindows: Boolean = false): String = {
    if (Utils.nonLocalPaths(path, testWindows).nonEmpty) {
      throw new IllegalArgumentException("Launching Python applications through " +
        s"spark-submit is currently only supported for local files: $path")
    }
    // Get the path when the scheme is "file" or "local"
    val uri = Try(new URI(path)).getOrElse(new File(path).toURI)
    var formattedPath = uri.getScheme match {
      case null => path
      case "file" | "local" => uri.getPath
      case _ => null
    }

    // Guard against malformed paths potentially throwing NPE
    if (formattedPath == null) {
      throw new IllegalArgumentException(s"Python file path is malformed: $path")
    }

    // On Windows, the drive letter should not be prefixed with "/"
    // For instance, python does not understand "/C:/path/to/sheep.py"
    if (Utils.isWindows && formattedPath.matches("/[a-zA-Z]:/.*")) {
      formattedPath = formattedPath.stripPrefix("/")
    }
    formattedPath
  }

  /**
   * Format each python file path in the comma-delimited list of paths, so it can be
   * added to the PYTHONPATH correctly.
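   *
   * For example (paths hypothetical; behavior follows the implementation below):
   * {{{
   *   formatPaths("dep1.py,file:/home/user/dep2.py")   // Array("dep1.py", "/home/user/dep2.py")
   *   formatPaths(null)                                // Array()
   * }}}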
   */
  def formatPaths(paths: String, testWindows: Boolean = false): Array[String] = {
    Option(paths).getOrElse("")
      .split(",")
      .filter(_.nonEmpty)
      .map { p => formatPath(p, testWindows) }
  }

}