/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy

import java.io.File
import java.net.URI

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

import org.apache.spark.{SparkConf, SparkUserAppException}
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.internal.config._
import org.apache.spark.util.{RedirectThread, Utils}

/**
 * A main class used to launch Python applications. It executes python as a
 * subprocess and then has it connect back to the JVM to access system properties, etc.
 */
object PythonRunner {

  /**
   * Entry point invoked by spark-submit for Python applications.
   *
   * Expected arguments: args(0) is the main python file, args(1) is a
   * comma-separated list of additional python files, and the remainder are
   * forwarded verbatim to the python program.
   *
   * Starts a Py4J gateway server so the python process can call back into this
   * JVM, builds a PYTHONPATH containing Spark's python sources and the user's
   * py-files, launches the python interpreter as a subprocess, and mirrors its
   * (merged) stdout/stderr to this process's stdout. Exits by throwing
   * [[SparkUserAppException]] if the python process returns a non-zero code.
   */
  def main(args: Array[String]): Unit = {
    val pythonFile = args(0)
    val pyFiles = args(1)
    val otherArgs = args.slice(2, args.length)
    val sparkConf = new SparkConf()
    // Resolution order for the interpreter: driver-specific conf, generic conf,
    // driver-specific env var, generic env var, then plain "python".
    val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
      .orElse(sparkConf.get(PYSPARK_PYTHON))
      .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
      .orElse(sys.env.get("PYSPARK_PYTHON"))
      .getOrElse("python")

    // Format python file paths before adding them to the PYTHONPATH
    val formattedPythonFile = formatPath(pythonFile)
    val formattedPyFiles = formatPaths(pyFiles)

    // Launch a Py4J gateway server for the process to connect to; this will let it see our
    // Java system properties and such
    val gatewayServer = new py4j.GatewayServer(null, 0)
    val thread = new Thread(new Runnable() {
      override def run(): Unit = Utils.logUncaughtExceptions {
        gatewayServer.start()
      }
    })
    thread.setName("py4j-gateway-init")
    thread.setDaemon(true)
    thread.start()

    // Wait until the gateway server has started, so that we know which port is it bound to.
    // `gatewayServer.start()` will start a new thread and run the server code there, after
    // initializing the socket, so the thread started above will end as soon as the server is
    // ready to serve connections.
    thread.join()

    // Build up a PYTHONPATH that includes the Spark assembly (where this class is), the
    // python directories in SPARK_HOME (if set), and any files in the pyFiles argument
    val pathElements = new ArrayBuffer[String]
    pathElements ++= formattedPyFiles
    pathElements += PythonUtils.sparkPythonPath
    pathElements += sys.env.getOrElse("PYTHONPATH", "")
    val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)

    // Launch Python process
    val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
    val env = builder.environment()
    env.put("PYTHONPATH", pythonPath)
    // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
    env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
    // pass conf spark.pyspark.python to python process, the only way to pass info to
    // python process is through environment variable.
    sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
    builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
    try {
      val process = builder.start()

      new RedirectThread(process.getInputStream, System.out, "redirect output").start()

      val exitCode = process.waitFor()
      if (exitCode != 0) {
        throw new SparkUserAppException(exitCode)
      }
    } finally {
      // Always tear down the gateway, even if the python process failed to start.
      gatewayServer.shutdown()
    }
  }

  /**
   * Format the python file path so that it can be added to the PYTHONPATH correctly.
   *
   * Python does not understand URI schemes in paths. Before adding python files to the
   * PYTHONPATH, we need to extract the path from the URI. This is safe to do because we
   * currently only support local python files.
   *
   * @param path a local path or a `file:`/`local:` URI pointing at a python file
   * @param testWindows forces Windows path semantics (for testing on non-Windows hosts)
   * @return the plain filesystem path, with the leading "/" stripped from Windows
   *         drive-letter paths (python does not understand "/C:/path/to/sheep.py")
   * @throws IllegalArgumentException if the path is non-local or has an unsupported scheme
   */
  def formatPath(path: String, testWindows: Boolean = false): String = {
    if (Utils.nonLocalPaths(path, testWindows).nonEmpty) {
      throw new IllegalArgumentException(
        s"Launching Python applications through spark-submit is currently only supported " +
        s"for local files: $path")
    }
    // Fall back to treating the string as a plain file path if it is not a parseable URI.
    val uri = Try(new URI(path)).getOrElse(new File(path).toURI)
    // Fail fast on unsupported schemes instead of carrying a null sentinel forward.
    val plainPath = uri.getScheme match {
      case null => path
      case "file" | "local" => uri.getPath
      case _ =>
        throw new IllegalArgumentException(s"Python file path is malformed: $path")
    }

    // In Windows, the drive should not be prefixed with "/"
    // For instance, python does not understand "/C:/path/to/sheep.py"
    if (Utils.isWindows && plainPath.matches("/[a-zA-Z]:/.*")) {
      plainPath.stripPrefix("/")
    } else {
      plainPath
    }
  }

  /**
   * Format each python file path in the comma-delimited list of paths, so it can be
   * added to the PYTHONPATH correctly.
   *
   * @param paths comma-separated python file paths; null is treated as the empty list
   * @param testWindows forces Windows path semantics (for testing on non-Windows hosts)
   * @return one formatted path per non-empty entry, in the original order
   */
  def formatPaths(paths: String, testWindows: Boolean = false): Array[String] = {
    Option(paths).getOrElse("")
      .split(",")
      .filter(_.nonEmpty)
      .map { p => formatPath(p, testWindows) }
  }

}