1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.util
19
20import java.io._
21import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
22import java.net._
23import java.nio.ByteBuffer
24import java.nio.channels.Channels
25import java.nio.charset.StandardCharsets
26import java.nio.file.{Files, Paths}
27import java.util.{Locale, Properties, Random, UUID}
28import java.util.concurrent._
29import java.util.concurrent.atomic.AtomicBoolean
30import java.util.zip.GZIPInputStream
31import javax.net.ssl.HttpsURLConnection
32
33import scala.annotation.tailrec
34import scala.collection.JavaConverters._
35import scala.collection.Map
36import scala.collection.mutable.ArrayBuffer
37import scala.io.Source
38import scala.reflect.ClassTag
39import scala.util.Try
40import scala.util.control.{ControlThrowable, NonFatal}
41
42import _root_.io.netty.channel.unix.Errors.NativeIoException
43import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
44import com.google.common.io.{ByteStreams, Files => GFiles}
45import com.google.common.net.InetAddresses
46import org.apache.commons.lang3.SystemUtils
47import org.apache.hadoop.conf.Configuration
48import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
49import org.apache.hadoop.security.UserGroupInformation
50import org.apache.log4j.PropertyConfigurator
51import org.eclipse.jetty.util.MultiException
52import org.json4s._
53import org.slf4j.Logger
54
55import org.apache.spark._
56import org.apache.spark.deploy.SparkHadoopUtil
57import org.apache.spark.internal.Logging
58import org.apache.spark.internal.config.{DYN_ALLOCATION_INITIAL_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS, EXECUTOR_INSTANCES}
59import org.apache.spark.network.util.JavaUtils
60import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
61import org.apache.spark.util.logging.RollingFileAppender
62
63/** CallSite represents a place in user code. It can have a short and a long form. */
64private[spark] case class CallSite(shortForm: String, longForm: String)
65
66private[spark] object CallSite {
67  val SHORT_FORM = "callSite.short"
68  val LONG_FORM = "callSite.long"
69  val empty = CallSite("", "")
70}
71
72/**
73 * Various utility methods used by Spark.
74 */
75private[spark] object Utils extends Logging {
76  val random = new Random()
77
78  /**
79   * Define a default value for driver memory here since this value is referenced across the code
80   * base and nearly all files already use Utils.scala
81   */
82  val DEFAULT_DRIVER_MEM_MB = JavaUtils.DEFAULT_DRIVER_MEM_MB.toInt
83
84  private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
85  @volatile private var localRootDirs: Array[String] = null
86
87  /**
88   * The performance overhead of creating and logging strings for wide schemas can be large. To
89   * limit the impact, we bound the number of fields to include by default. This can be overridden
90   * by setting the 'spark.debug.maxToStringFields' conf in SparkEnv.
91   */
92  val DEFAULT_MAX_TO_STRING_FIELDS = 25
93
94  private def maxNumToStringFields = {
95    if (SparkEnv.get != null) {
96      SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
97    } else {
98      DEFAULT_MAX_TO_STRING_FIELDS
99    }
100  }
101
102  /** Whether we have warned about plan string truncation yet. */
103  private val truncationWarningPrinted = new AtomicBoolean(false)
104
105  /**
106   * Format a sequence with semantics similar to calling .mkString(). Any elements beyond
   * maxNumFields will be dropped and replaced by a "... N more fields" placeholder.
108   *
109   * @return the trimmed and formatted string.
110   */
111  def truncatedString[T](
112      seq: Seq[T],
113      start: String,
114      sep: String,
115      end: String,
116      maxNumFields: Int = maxNumToStringFields): String = {
117    if (seq.length > maxNumFields) {
118      if (truncationWarningPrinted.compareAndSet(false, true)) {
119        logWarning(
120          "Truncated the string representation of a plan since it was too large. This " +
121          "behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.")
122      }
123      val numFields = math.max(0, maxNumFields - 1)
124      seq.take(numFields).mkString(
125        start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end)
126    } else {
127      seq.mkString(start, sep, end)
128    }
129  }
130
131  /** Shorthand for calling truncatedString() without start or end strings. */
132  def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "")
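
  // Illustrative examples (hypothetical values):
  //   truncatedString(Seq(1, 2, 3, 4), "[", ", ", "]", maxNumFields = 3)
  //     returns "[1, 2, ... 2 more fields]"
  //   truncatedString(Seq(1, 2), ", ") returns "1, 2"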
133
134  /** Serialize an object using Java serialization */
135  def serialize[T](o: T): Array[Byte] = {
136    val bos = new ByteArrayOutputStream()
137    val oos = new ObjectOutputStream(bos)
138    oos.writeObject(o)
139    oos.close()
140    bos.toByteArray
141  }
142
143  /** Deserialize an object using Java serialization */
144  def deserialize[T](bytes: Array[Byte]): T = {
145    val bis = new ByteArrayInputStream(bytes)
146    val ois = new ObjectInputStream(bis)
147    ois.readObject.asInstanceOf[T]
148  }
149
150  /** Deserialize an object using Java serialization and the given ClassLoader */
151  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
152    val bis = new ByteArrayInputStream(bytes)
153    val ois = new ObjectInputStream(bis) {
154      override def resolveClass(desc: ObjectStreamClass): Class[_] = {
155        // scalastyle:off classforname
156        Class.forName(desc.getName, false, loader)
157        // scalastyle:on classforname
158      }
159    }
160    ois.readObject.asInstanceOf[T]
161  }
162
163  /** Deserialize a Long value (used for [[org.apache.spark.api.python.PythonPartitioner]]) */
164  def deserializeLongValue(bytes: Array[Byte]) : Long = {
165    // Note: we assume that we are given a Long value encoded in network (big-endian) byte order
166    var result = bytes(7) & 0xFFL
167    result = result + ((bytes(6) & 0xFFL) << 8)
168    result = result + ((bytes(5) & 0xFFL) << 16)
169    result = result + ((bytes(4) & 0xFFL) << 24)
170    result = result + ((bytes(3) & 0xFFL) << 32)
171    result = result + ((bytes(2) & 0xFFL) << 40)
172    result = result + ((bytes(1) & 0xFFL) << 48)
173    result + ((bytes(0) & 0xFFL) << 56)
174  }
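
  // Illustrative example (hypothetical bytes): the value is decoded in big-endian order, so
  //   deserializeLongValue(Array[Byte](0, 0, 0, 0, 0, 0, 1, 2)) returns 258L (1 << 8 | 2).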
175
176  /** Serialize via nested stream using specific serializer */
177  def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(
178      f: SerializationStream => Unit): Unit = {
179    val osWrapper = ser.serializeStream(new OutputStream {
180      override def write(b: Int): Unit = os.write(b)
181      override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len)
182    })
183    try {
184      f(osWrapper)
185    } finally {
186      osWrapper.close()
187    }
188  }
189
190  /** Deserialize via nested stream using specific serializer */
191  def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(
192      f: DeserializationStream => Unit): Unit = {
193    val isWrapper = ser.deserializeStream(new InputStream {
194      override def read(): Int = is.read()
195      override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len)
196    })
197    try {
198      f(isWrapper)
199    } finally {
200      isWrapper.close()
201    }
202  }
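
  // Illustrative round trip (assumes `ser` is a SerializerInstance, e.g. from a JavaSerializer):
  //   val buffer = new ByteArrayOutputStream()
  //   serializeViaNestedStream(buffer, ser) { out => out.writeObject("hello") }
  //   deserializeViaNestedStream(new ByteArrayInputStream(buffer.toByteArray), ser) { in =>
  //     assert(in.readObject[String]() == "hello")
  //   }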
203
204  /**
205   * Get the ClassLoader which loaded Spark.
206   */
207  def getSparkClassLoader: ClassLoader = getClass.getClassLoader
208
209  /**
210   * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that
211   * loaded Spark.
212   *
   * This should be used whenever passing a ClassLoader to Class.forName or finding the currently
214   * active loader when setting up ClassLoader delegation chains.
215   */
216  def getContextOrSparkClassLoader: ClassLoader =
217    Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
218
219  /** Determines whether the provided class is loadable in the current thread. */
220  def classIsLoadable(clazz: String): Boolean = {
221    // scalastyle:off classforname
222    Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess
223    // scalastyle:on classforname
224  }
225
226  // scalastyle:off classforname
227  /** Preferred alternative to Class.forName(className) */
228  def classForName(className: String): Class[_] = {
229    Class.forName(className, true, getContextOrSparkClassLoader)
230    // scalastyle:on classforname
231  }
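
  // Illustrative uses (class names are only examples):
  //   classForName("org.apache.spark.SparkContext")  // loads and initializes the class
  //   classIsLoadable("com.example.Missing")         // false if not on the classpath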
232
233  /**
234   * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
235   */
236  def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
237    if (bb.hasArray) {
238      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
239    } else {
240      val bbval = new Array[Byte](bb.remaining())
241      bb.get(bbval)
242      out.write(bbval)
243    }
244  }
245
246  /**
247   * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
248   */
249  def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
250    if (bb.hasArray) {
251      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
252    } else {
253      val bbval = new Array[Byte](bb.remaining())
254      bb.get(bbval)
255      out.write(bbval)
256    }
257  }
258
259  /**
260   * JDK equivalent of `chmod 700 file`.
261   *
262   * @param file the file whose permissions will be modified
263   * @return true if the permissions were successfully changed, false otherwise.
264   */
265  def chmod700(file: File): Boolean = {
266    file.setReadable(false, false) &&
267    file.setReadable(true, true) &&
268    file.setWritable(false, false) &&
269    file.setWritable(true, true) &&
270    file.setExecutable(false, false) &&
271    file.setExecutable(true, true)
272  }
273
274  /**
275   * Create a directory inside the given parent directory. The directory is guaranteed to be
276   * newly created, and is not marked for automatic deletion.
277   */
278  def createDirectory(root: String, namePrefix: String = "spark"): File = {
279    var attempts = 0
280    val maxAttempts = MAX_DIR_CREATION_ATTEMPTS
281    var dir: File = null
282    while (dir == null) {
283      attempts += 1
284      if (attempts > maxAttempts) {
285        throw new IOException("Failed to create a temp directory (under " + root + ") after " +
286          maxAttempts + " attempts!")
287      }
288      try {
289        dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString)
290        if (dir.exists() || !dir.mkdirs()) {
291          dir = null
292        }
293      } catch { case e: SecurityException => dir = null; }
294    }
295
296    dir.getCanonicalFile
297  }
298
299  /**
300   * Create a temporary directory inside the given parent directory. The directory will be
301   * automatically deleted when the VM shuts down.
302   */
303  def createTempDir(
304      root: String = System.getProperty("java.io.tmpdir"),
305      namePrefix: String = "spark"): File = {
306    val dir = createDirectory(root, namePrefix)
307    ShutdownHookManager.registerShutdownDeleteDir(dir)
308    dir
309  }
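
  // Illustrative use: create a scratch directory under java.io.tmpdir, deleted on JVM exit.
  //   val scratch = createTempDir(namePrefix = "spark-scratch")  // e.g. /tmp/spark-scratch-<uuid>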
310
311  /**
   * Copy all data from an InputStream to an OutputStream. NIO-based copying between file streams
   * is disabled by default and is only used when transferToEnabled is explicitly set to true;
   * callers should derive that parameter from spark.file.transferTo = [true|false].
315   */
316  def copyStream(in: InputStream,
317                 out: OutputStream,
318                 closeStreams: Boolean = false,
319                 transferToEnabled: Boolean = false): Long =
320  {
321    var count = 0L
322    tryWithSafeFinally {
323      if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
324        && transferToEnabled) {
325        // When both streams are File stream, use transferTo to improve copy performance.
326        val inChannel = in.asInstanceOf[FileInputStream].getChannel()
327        val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
328        val initialPos = outChannel.position()
329        val size = inChannel.size()
330
        // Loop in case the transferTo call transferred less data than we requested.
332        while (count < size) {
333          count += inChannel.transferTo(count, size - count, outChannel)
334        }
335
336        // Check the position after transferTo loop to see if it is in the right position and
337        // give user information if not.
338        // Position will not be increased to the expected length after calling transferTo in
339        // kernel version 2.6.32, this issue can be seen in
340        // https://bugs.openjdk.java.net/browse/JDK-7052359
341        // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
342        val finalPos = outChannel.position()
343        assert(finalPos == initialPos + size,
344          s"""
             |Current position $finalPos does not equal the expected position ${initialPos + size}
             |after transferTo. Please check whether your kernel version is 2.6.32, as it has a
             |known bug that leads to unexpected behavior when using transferTo.
348             |You can set spark.file.transferTo = false to disable this NIO feature.
349           """.stripMargin)
350      } else {
351        val buf = new Array[Byte](8192)
352        var n = 0
353        while (n != -1) {
354          n = in.read(buf)
355          if (n != -1) {
356            out.write(buf, 0, n)
357            count += n
358          }
359        }
360      }
361      count
362    } {
363      if (closeStreams) {
364        try {
365          in.close()
366        } finally {
367          out.close()
368        }
369      }
370    }
371  }
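
  // Illustrative use (hypothetical paths): copy one local file to another and close both streams.
  //   val bytesCopied = copyStream(new FileInputStream("/tmp/in.bin"),
  //     new FileOutputStream("/tmp/out.bin"), closeStreams = true)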
372
373  /**
   * Construct a URI containing the information used for authentication.
   * This also sets the default authenticator to properly negotiate the
376   * user/password based on the URI.
377   *
378   * Note this relies on the Authenticator.setDefault being set properly to decode
379   * the user name and password. This is currently set in the SecurityManager.
380   */
381  def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = {
382    val userCred = securityMgr.getSecretKey()
383    if (userCred == null) throw new Exception("Secret key is null with authentication on")
384    val userInfo = securityMgr.getHttpUser()  + ":" + userCred
385    new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(),
386      uri.getQuery(), uri.getFragment())
387  }
388
389  /**
390   * A file name may contain some invalid URI characters, such as " ". This method will convert the
391   * file name to a raw path accepted by `java.net.URI(String)`.
392   *
393   * Note: the file name must not contain "/" or "\"
394   */
395  def encodeFileNameToURIRawPath(fileName: String): String = {
396    require(!fileName.contains("/") && !fileName.contains("\\"))
397    // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as
398    // scheme or host. The prefix "/" is required because URI doesn't accept a relative path.
399    // We should remove it after we get the raw path.
400    new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1)
401  }
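
  // Illustrative example: a file name containing a space is escaped into a raw URI path, e.g.
  //   encodeFileNameToURIRawPath("abc 123.jar") returns "abc%20123.jar"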
402
403  /**
404   * Get the file name from uri's raw path and decode it. If the raw path of uri ends with "/",
405   * return the name before the last "/".
406   */
407  def decodeFileNameInURI(uri: URI): String = {
408    val rawPath = uri.getRawPath
409    val rawFileName = rawPath.split("/").last
410    new URI("file:///" + rawFileName).getPath.substring(1)
411  }
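
  // Illustrative example (hypothetical URI):
  //   decodeFileNameInURI(new URI("file:///tmp/abc%20123.jar")) returns "abc 123.jar"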
412
  /**
414   * Download a file or directory to target directory. Supports fetching the file in a variety of
415   * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based
416   * on the URL parameter. Fetching directories is only supported from Hadoop-compatible
417   * filesystems.
418   *
419   * If `useCache` is true, first attempts to fetch the file to a local cache that's shared
420   * across executors running the same application. `useCache` is used mainly for
421   * the executors, and not in local mode.
422   *
423   * Throws SparkException if the target file already exists and has different contents than
424   * the requested file.
425   */
426  def fetchFile(
427      url: String,
428      targetDir: File,
429      conf: SparkConf,
430      securityMgr: SecurityManager,
431      hadoopConf: Configuration,
432      timestamp: Long,
433      useCache: Boolean) {
434    val fileName = decodeFileNameInURI(new URI(url))
435    val targetFile = new File(targetDir, fileName)
436    val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true)
437    if (useCache && fetchCacheEnabled) {
438      val cachedFileName = s"${url.hashCode}${timestamp}_cache"
439      val lockFileName = s"${url.hashCode}${timestamp}_lock"
440      val localDir = new File(getLocalDir(conf))
441      val lockFile = new File(localDir, lockFileName)
442      val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel()
      // Only one executor should download the file into the cache at a time.
      // The FileLock is only used to synchronize executors downloading the file;
445      // it's always safe regardless of lock type (mandatory or advisory).
446      val lock = lockFileChannel.lock()
447      val cachedFile = new File(localDir, cachedFileName)
448      try {
449        if (!cachedFile.exists()) {
450          doFetchFile(url, localDir, cachedFileName, conf, securityMgr, hadoopConf)
451        }
452      } finally {
453        lock.release()
454        lockFileChannel.close()
455      }
456      copyFile(
457        url,
458        cachedFile,
459        targetFile,
460        conf.getBoolean("spark.files.overwrite", false)
461      )
462    } else {
463      doFetchFile(url, targetDir, fileName, conf, securityMgr, hadoopConf)
464    }
465
    // Decompress the file if it's a .tar, .tar.gz, or .tgz archive
467    if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) {
468      logInfo("Untarring " + fileName)
469      executeAndGetOutput(Seq("tar", "-xzf", fileName), targetDir)
470    } else if (fileName.endsWith(".tar")) {
471      logInfo("Untarring " + fileName)
472      executeAndGetOutput(Seq("tar", "-xf", fileName), targetDir)
473    }
    // Make the file executable - that's necessary for scripts
475    FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
476
477    // Windows does not grant read permission by default to non-admin users
478    // Add read permission to owner explicitly
479    if (isWindows) {
480      FileUtil.chmod(targetFile.getAbsolutePath, "u+r")
481    }
482  }
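
  // Illustrative call (hypothetical values; assumes `conf` and `securityMgr` are in scope):
  //   fetchFile("https://example.com/app.jar", new File("/tmp/work"), conf, securityMgr,
  //     SparkHadoopUtil.get.newConfiguration(conf), System.currentTimeMillis(), useCache = false)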
483
484  /**
485   * Download `in` to `tempFile`, then move it to `destFile`.
486   *
487   * If `destFile` already exists:
488   *   - no-op if its contents equal those of `sourceFile`,
489   *   - throw an exception if `fileOverwrite` is false,
490   *   - attempt to overwrite it otherwise.
491   *
492   * @param url URL that `sourceFile` originated from, for logging purposes.
493   * @param in InputStream to download.
494   * @param destFile File path to move `tempFile` to.
495   * @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match
496   *                      `sourceFile`
497   */
498  private def downloadFile(
499      url: String,
500      in: InputStream,
501      destFile: File,
502      fileOverwrite: Boolean): Unit = {
503    val tempFile = File.createTempFile("fetchFileTemp", null,
504      new File(destFile.getParentFile.getAbsolutePath))
505    logInfo(s"Fetching $url to $tempFile")
506
507    try {
508      val out = new FileOutputStream(tempFile)
509      Utils.copyStream(in, out, closeStreams = true)
510      copyFile(url, tempFile, destFile, fileOverwrite, removeSourceFile = true)
511    } finally {
512      // Catch-all for the couple of cases where for some reason we didn't move `tempFile` to
513      // `destFile`.
514      if (tempFile.exists()) {
515        tempFile.delete()
516      }
517    }
518  }
519
520  /**
521   * Copy `sourceFile` to `destFile`.
522   *
523   * If `destFile` already exists:
524   *   - no-op if its contents equal those of `sourceFile`,
525   *   - throw an exception if `fileOverwrite` is false,
526   *   - attempt to overwrite it otherwise.
527   *
528   * @param url URL that `sourceFile` originated from, for logging purposes.
529   * @param sourceFile File path to copy/move from.
530   * @param destFile File path to copy/move to.
531   * @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match
532   *                      `sourceFile`
533   * @param removeSourceFile Whether to remove `sourceFile` after / as part of moving/copying it to
534   *                         `destFile`.
535   */
536  private def copyFile(
537      url: String,
538      sourceFile: File,
539      destFile: File,
540      fileOverwrite: Boolean,
541      removeSourceFile: Boolean = false): Unit = {
542
543    if (destFile.exists) {
544      if (!filesEqualRecursive(sourceFile, destFile)) {
545        if (fileOverwrite) {
546          logInfo(
547            s"File $destFile exists and does not match contents of $url, replacing it with $url"
548          )
549          if (!destFile.delete()) {
550            throw new SparkException(
551              "Failed to delete %s while attempting to overwrite it with %s".format(
552                destFile.getAbsolutePath,
553                sourceFile.getAbsolutePath
554              )
555            )
556          }
557        } else {
558          throw new SparkException(
559            s"File $destFile exists and does not match contents of $url")
560        }
561      } else {
562        // Do nothing if the file contents are the same, i.e. this file has been copied
563        // previously.
564        logInfo(
565          "%s has been previously copied to %s".format(
566            sourceFile.getAbsolutePath,
567            destFile.getAbsolutePath
568          )
569        )
570        return
571      }
572    }
573
574    // The file does not exist in the target directory. Copy or move it there.
575    if (removeSourceFile) {
576      Files.move(sourceFile.toPath, destFile.toPath)
577    } else {
578      logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}")
579      copyRecursive(sourceFile, destFile)
580    }
581  }
582
583  private def filesEqualRecursive(file1: File, file2: File): Boolean = {
584    if (file1.isDirectory && file2.isDirectory) {
585      val subfiles1 = file1.listFiles()
586      val subfiles2 = file2.listFiles()
587      if (subfiles1.size != subfiles2.size) {
588        return false
589      }
590      subfiles1.sortBy(_.getName).zip(subfiles2.sortBy(_.getName)).forall {
591        case (f1, f2) => filesEqualRecursive(f1, f2)
592      }
593    } else if (file1.isFile && file2.isFile) {
594      GFiles.equal(file1, file2)
595    } else {
596      false
597    }
598  }
599
600  private def copyRecursive(source: File, dest: File): Unit = {
601    if (source.isDirectory) {
602      if (!dest.mkdir()) {
603        throw new IOException(s"Failed to create directory ${dest.getPath}")
604      }
605      val subfiles = source.listFiles()
606      subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName)))
607    } else {
608      Files.copy(source.toPath, dest.toPath)
609    }
610  }
611
612  /**
613   * Download a file or directory to target directory. Supports fetching the file in a variety of
614   * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based
615   * on the URL parameter. Fetching directories is only supported from Hadoop-compatible
616   * filesystems.
617   *
618   * Throws SparkException if the target file already exists and has different contents than
619   * the requested file.
620   */
621  private def doFetchFile(
622      url: String,
623      targetDir: File,
624      filename: String,
625      conf: SparkConf,
626      securityMgr: SecurityManager,
627      hadoopConf: Configuration) {
628    val targetFile = new File(targetDir, filename)
629    val uri = new URI(url)
630    val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
631    Option(uri.getScheme).getOrElse("file") match {
632      case "spark" =>
633        if (SparkEnv.get == null) {
634          throw new IllegalStateException(
635            "Cannot retrieve files with 'spark' scheme without an active SparkEnv.")
636        }
637        val source = SparkEnv.get.rpcEnv.openChannel(url)
638        val is = Channels.newInputStream(source)
639        downloadFile(url, is, targetFile, fileOverwrite)
640      case "http" | "https" | "ftp" =>
641        var uc: URLConnection = null
642        if (securityMgr.isAuthenticationEnabled()) {
643          logDebug("fetchFile with security enabled")
644          val newuri = constructURIForAuthentication(uri, securityMgr)
645          uc = newuri.toURL().openConnection()
646          uc.setAllowUserInteraction(false)
647        } else {
648          logDebug("fetchFile not using security")
649          uc = new URL(url).openConnection()
650        }
651        Utils.setupSecureURLConnection(uc, securityMgr)
652
653        val timeoutMs =
654          conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000
655        uc.setConnectTimeout(timeoutMs)
656        uc.setReadTimeout(timeoutMs)
657        uc.connect()
658        val in = uc.getInputStream()
659        downloadFile(url, in, targetFile, fileOverwrite)
660      case "file" =>
661        // In the case of a local file, copy the local file to the target directory.
662        // Note the difference between uri vs url.
663        val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
664        copyFile(url, sourceFile, targetFile, fileOverwrite)
665      case _ =>
666        val fs = getHadoopFileSystem(uri, hadoopConf)
667        val path = new Path(uri)
668        fetchHcfsFile(path, targetDir, fs, conf, hadoopConf, fileOverwrite,
669                      filename = Some(filename))
670    }
671  }
672
673  /**
674   * Fetch a file or directory from a Hadoop-compatible filesystem.
675   *
676   * Visible for testing
677   */
678  private[spark] def fetchHcfsFile(
679      path: Path,
680      targetDir: File,
681      fs: FileSystem,
682      conf: SparkConf,
683      hadoopConf: Configuration,
684      fileOverwrite: Boolean,
685      filename: Option[String] = None): Unit = {
686    if (!targetDir.exists() && !targetDir.mkdir()) {
687      throw new IOException(s"Failed to create directory ${targetDir.getPath}")
688    }
689    val dest = new File(targetDir, filename.getOrElse(path.getName))
690    if (fs.isFile(path)) {
691      val in = fs.open(path)
692      try {
693        downloadFile(path.toString, in, dest, fileOverwrite)
694      } finally {
695        in.close()
696      }
697    } else {
698      fs.listStatus(path).foreach { fileStatus =>
699        fetchHcfsFile(fileStatus.getPath(), dest, fs, conf, hadoopConf, fileOverwrite)
700      }
701    }
702  }
703
704  /**
705   * Validate that a given URI is actually a valid URL as well.
706   * @param uri The URI to validate
707   */
708  @throws[MalformedURLException]("when the URI is an invalid URL")
709  def validateURL(uri: URI): Unit = {
710    Option(uri.getScheme).getOrElse("file") match {
711      case "http" | "https" | "ftp" =>
712        try {
713          uri.toURL
714        } catch {
715          case e: MalformedURLException =>
716            val ex = new MalformedURLException(s"URI (${uri.toString}) is not a valid URL.")
717            ex.initCause(e)
718            throw ex
719        }
720      case _ => // will not be turned into a URL anyway
721    }
722  }
723
724  /**
725   * Get the path of a temporary directory.  Spark's local directories can be configured through
726   * multiple settings, which are used with the following precedence:
727   *
728   *   - If called from inside of a YARN container, this will return a directory chosen by YARN.
729   *   - If the SPARK_LOCAL_DIRS environment variable is set, this will return a directory from it.
730   *   - Otherwise, if the spark.local.dir is set, this will return a directory from it.
731   *   - Otherwise, this will return java.io.tmpdir.
732   *
733   * Some of these configuration options might be lists of multiple paths, but this method will
734   * always return a single directory.
735   */
736  def getLocalDir(conf: SparkConf): String = {
737    getOrCreateLocalRootDirs(conf)(0)
738  }
739
740  private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = {
741    // These environment variables are set by YARN.
742    conf.getenv("CONTAINER_ID") != null
743  }
744
745  /**
746   * Gets or creates the directories listed in spark.local.dir or SPARK_LOCAL_DIRS,
747   * and returns only the directories that exist / could be created.
748   *
749   * If no directories could be created, this will return an empty list.
750   *
751   * This method will cache the local directories for the application when it's first invoked.
752   * So calling it multiple times with a different configuration will always return the same
753   * set of directories.
754   */
755  private[spark] def getOrCreateLocalRootDirs(conf: SparkConf): Array[String] = {
756    if (localRootDirs == null) {
757      this.synchronized {
758        if (localRootDirs == null) {
759          localRootDirs = getOrCreateLocalRootDirsImpl(conf)
760        }
761      }
762    }
763    localRootDirs
764  }
765
766  /**
767   * Return the configured local directories where Spark can write files. This
   * method does not create any directories on its own; it only encapsulates the
   * logic of locating the local directories according to the deployment mode.
770   */
771  def getConfiguredLocalDirs(conf: SparkConf): Array[String] = {
772    val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
773    if (isRunningInYarnContainer(conf)) {
774      // If we are in yarn mode, systems can have different disk layouts so we must set it
775      // to what Yarn on this system said was available. Note this assumes that Yarn has
776      // created the directories already, and that they are secured so that only the
777      // user has access to them.
778      getYarnLocalDirs(conf).split(",")
779    } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) {
780      conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator)
781    } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) {
782      conf.getenv("SPARK_LOCAL_DIRS").split(",")
783    } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) {
784      // Mesos already creates a directory per Mesos task. Spark should use that directory
785      // instead so all temporary files are automatically cleaned up when the Mesos task ends.
786      // Note that we don't want this if the shuffle service is enabled because we want to
787      // continue to serve shuffle files after the executors that wrote them have already exited.
788      Array(conf.getenv("MESOS_DIRECTORY"))
789    } else {
790      if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) {
791        logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " +
792          "spark.shuffle.service.enabled is enabled.")
793      }
794      // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user
795      // configuration to point to a secure directory. So create a subdirectory with restricted
796      // permissions under each listed directory.
797      conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(",")
798    }
799  }
800
801  private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = {
802    getConfiguredLocalDirs(conf).flatMap { root =>
803      try {
804        val rootDir = new File(root)
805        if (rootDir.exists || rootDir.mkdirs()) {
806          val dir = createTempDir(root)
807          chmod700(dir)
808          Some(dir.getAbsolutePath)
809        } else {
810          logError(s"Failed to create dir in $root. Ignoring this directory.")
811          None
812        }
813      } catch {
814        case e: IOException =>
815          logError(s"Failed to create local root dir in $root. Ignoring this directory.")
816          None
817      }
818    }
819  }
820
821  /** Get the Yarn approved local directories. */
822  private def getYarnLocalDirs(conf: SparkConf): String = {
823    val localDirs = Option(conf.getenv("LOCAL_DIRS")).getOrElse("")
824
825    if (localDirs.isEmpty) {
826      throw new Exception("Yarn Local dirs can't be empty")
827    }
828    localDirs
829  }
830
831  /** Used by unit tests. Do not call from other places. */
832  private[spark] def clearLocalRootDirs(): Unit = {
833    localRootDirs = null
834  }
835
836  /**
837   * Shuffle the elements of a collection into a random order, returning the
838   * result in a new collection. Unlike scala.util.Random.shuffle, this method
839   * uses a local random number generator, avoiding inter-thread contention.
840   */
841  def randomize[T: ClassTag](seq: TraversableOnce[T]): Seq[T] = {
842    randomizeInPlace(seq.toArray)
843  }
844
845  /**
846   * Shuffle the elements of an array into a random order, modifying the
847   * original array. Returns the original array.
848   */
849  def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
850    for (i <- (arr.length - 1) to 1 by -1) {
851      val j = rand.nextInt(i + 1)
852      val tmp = arr(j)
853      arr(j) = arr(i)
854      arr(i) = tmp
855    }
856    arr
857  }
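
  // Illustrative example: a seeded Random gives a deterministic shuffle for tests.
  //   randomizeInPlace(Array(1, 2, 3, 4, 5), new Random(42L))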
858
859  /**
860   * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
861   * Note, this is typically not used from within core spark.
862   */
863  private lazy val localIpAddress: InetAddress = findLocalInetAddress()
864
865  private def findLocalInetAddress(): InetAddress = {
866    val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
867    if (defaultIpOverride != null) {
868      InetAddress.getByName(defaultIpOverride)
869    } else {
870      val address = InetAddress.getLocalHost
871      if (address.isLoopbackAddress) {
872        // Address resolves to something like 127.0.1.1, which happens on Debian; try to find
        // a better address using the local network interfaces.
        // getNetworkInterfaces returns interfaces in reverse order compared to ifconfig output
        // order on Unix-like systems. On Windows, it returns them in index order.
        // It's more correct to pick the IP address following the system output order.
877        val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq
878        val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse
879
880        for (ni <- reOrderedNetworkIFs) {
881          val addresses = ni.getInetAddresses.asScala
882            .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq
883          if (addresses.nonEmpty) {
884            val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head)
            // because Inet6Address.toHostName may append the interface name if it knows about it
886            val strippedAddress = InetAddress.getByAddress(addr.getAddress)
887            // We've found an address that looks reasonable!
888            logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
889              " a loopback address: " + address.getHostAddress + "; using " +
890              strippedAddress.getHostAddress + " instead (on interface " + ni.getName + ")")
891            logWarning("Set SPARK_LOCAL_IP if you need to bind to another address")
892            return strippedAddress
893          }
894        }
895        logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
896          " a loopback address: " + address.getHostAddress + ", but we couldn't find any" +
897          " external IP address!")
898        logWarning("Set SPARK_LOCAL_IP if you need to bind to another address")
899      }
900      address
901    }
902  }
903
904  private var customHostname: Option[String] = sys.env.get("SPARK_LOCAL_HOSTNAME")
905
906  /**
907   * Allow setting a custom host name because when we run on Mesos we need to use the same
908   * hostname it reports to the master.
909   */
910  def setCustomHostname(hostname: String) {
911    // DEBUG code
912    Utils.checkHost(hostname)
913    customHostname = Some(hostname)
914  }
915
916  /**
917   * Get the local machine's hostname.
918   */
919  def localHostName(): String = {
920    customHostname.getOrElse(localIpAddress.getHostAddress)
921  }
922
923  /**
924   * Get the local machine's URI.
925   */
926  def localHostNameForURI(): String = {
927    customHostname.getOrElse(InetAddresses.toUriString(localIpAddress))
928  }
929
930  def checkHost(host: String, message: String = "") {
931    assert(host.indexOf(':') == -1, message)
932  }
933
934  def checkHostPort(hostPort: String, message: String = "") {
935    assert(hostPort.indexOf(':') != -1, message)
936  }
937
  // Typically, this will be on the order of the number of nodes in the cluster.
  // If not, we should change it to an LRU cache or something.
940  private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
941
942  def parseHostPort(hostPort: String): (String, Int) = {
943    // Check cache first.
944    val cached = hostPortParseResults.get(hostPort)
945    if (cached != null) {
946      return cached
947    }
948
949    val indx: Int = hostPort.lastIndexOf(':')
    // This is potentially broken - when dealing with IPv6 addresses, for example - but
    // Hadoop does not support IPv6 right now.
    // For now, assume that if a port exists, it is valid; we don't check that it is an int > 0.
953    if (-1 == indx) {
954      val retval = (hostPort, 0)
955      hostPortParseResults.put(hostPort, retval)
956      return retval
957    }
958
959    val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
960    hostPortParseResults.putIfAbsent(hostPort, retval)
961    hostPortParseResults.get(hostPort)
962  }
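
  // Illustrative examples (hypothetical host names):
  //   parseHostPort("host.example.com:7077") returns ("host.example.com", 7077)
  //   parseHostPort("host.example.com")      returns ("host.example.com", 0)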
963
964  /**
   * Return a string describing how much time has passed since startTimeMs, in milliseconds.
966   */
967  def getUsedTimeMs(startTimeMs: Long): String = {
968    " " + (System.currentTimeMillis - startTimeMs) + " ms"
969  }
970
971  private def listFilesSafely(file: File): Seq[File] = {
972    if (file.exists()) {
973      val files = file.listFiles()
974      if (files == null) {
975        throw new IOException("Failed to list files for dir: " + file)
976      }
977      files
978    } else {
979      List()
980    }
981  }
982
983  /**
984   * Delete a file or directory and its contents recursively.
985   * Don't follow directories if they are symlinks.
986   * Throws an exception if deletion is unsuccessful.
987   */
988  def deleteRecursively(file: File) {
989    if (file != null) {
990      try {
991        if (file.isDirectory && !isSymlink(file)) {
992          var savedIOException: IOException = null
993          for (child <- listFilesSafely(file)) {
994            try {
995              deleteRecursively(child)
996            } catch {
              // In case of multiple exceptions, only the last one will be thrown
998              case ioe: IOException => savedIOException = ioe
999            }
1000          }
1001          if (savedIOException != null) {
1002            throw savedIOException
1003          }
1004          ShutdownHookManager.removeShutdownDeleteDir(file)
1005        }
1006      } finally {
1007        if (!file.delete()) {
1008          // Delete can also fail if the file simply did not exist
1009          if (file.exists()) {
1010            throw new IOException("Failed to delete: " + file.getAbsolutePath)
1011          }
1012        }
1013      }
1014    }
1015  }
1016
1017  /**
1018   * Check to see if file is a symbolic link.
1019   */
1020  def isSymlink(file: File): Boolean = {
1021    return Files.isSymbolicLink(Paths.get(file.toURI))
1022  }
1023
1024  /**
1025   * Determines if a directory contains any files newer than cutoff seconds.
1026   *
1027   * @param dir must be the path to a directory, or IllegalArgumentException is thrown
1028   * @param cutoff measured in seconds. Returns true if there are any files or directories in the
1029   *               given directory whose last modified time is later than this many seconds ago
1030   */
1031  def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = {
1032    if (!dir.isDirectory) {
1033      throw new IllegalArgumentException(s"$dir is not a directory!")
1034    }
1035    val filesAndDirs = dir.listFiles()
1036    val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000)
1037
1038    filesAndDirs.exists(_.lastModified() > cutoffTimeInMillis) ||
1039    filesAndDirs.filter(_.isDirectory).exists(
1040      subdir => doesDirectoryContainAnyNewFiles(subdir, cutoff)
1041    )
1042  }
1043
1044  /**
   * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If
1046   * no suffix is provided, the passed number is assumed to be in ms.
1047   */
1048  def timeStringAsMs(str: String): Long = {
1049    JavaUtils.timeStringAsMs(str)
1050  }
1051
1052  /**
1053   * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If
1054   * no suffix is provided, the passed number is assumed to be in seconds.
1055   */
1056  def timeStringAsSeconds(str: String): Long = {
1057    JavaUtils.timeStringAsSec(str)
1058  }
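
  // Illustrative conversions:
  //   timeStringAsMs("50s")         // 50000
  //   timeStringAsSeconds("100ms")  // 0 (truncated)
  //   timeStringAsSeconds("2min")   // 120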
1059
1060  /**
1061   * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for internal use.
1062   *
1063   * If no suffix is provided, the passed number is assumed to be in bytes.
1064   */
1065  def byteStringAsBytes(str: String): Long = {
1066    JavaUtils.byteStringAsBytes(str)
1067  }
1068
1069  /**
1070   * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for internal use.
1071   *
1072   * If no suffix is provided, the passed number is assumed to be in kibibytes.
1073   */
1074  def byteStringAsKb(str: String): Long = {
1075    JavaUtils.byteStringAsKb(str)
1076  }
1077
1078  /**
1079   * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for internal use.
1080   *
1081   * If no suffix is provided, the passed number is assumed to be in mebibytes.
1082   */
1083  def byteStringAsMb(str: String): Long = {
1084    JavaUtils.byteStringAsMb(str)
1085  }
1086
1087  /**
1088   * Convert a passed byte string (e.g. 50b, 100k, or 250m, 500g) to gibibytes for internal use.
1089   *
1090   * If no suffix is provided, the passed number is assumed to be in gibibytes.
1091   */
1092  def byteStringAsGb(str: String): Long = {
1093    JavaUtils.byteStringAsGb(str)
1094  }
1095
1096  /**
1097   * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes.
1098   */
1099  def memoryStringToMb(str: String): Int = {
1100    // Convert to bytes, rather than directly to MB, because when no units are specified the unit
1101    // is assumed to be bytes
1102    (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt
1103  }
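
  // Illustrative conversions:
  //   byteStringAsBytes("1m")   // 1048576
  //   byteStringAsMb("1g")      // 1024
  //   memoryStringToMb("300m")  // 300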
1104
1105  /**
1106   * Convert a quantity in bytes to a human-readable string such as "4.0 MB".
1107   */
1108  def bytesToString(size: Long): String = {
1109    val TB = 1L << 40
1110    val GB = 1L << 30
1111    val MB = 1L << 20
1112    val KB = 1L << 10
1113
1114    val (value, unit) = {
1115      if (size >= 2*TB) {
1116        (size.asInstanceOf[Double] / TB, "TB")
1117      } else if (size >= 2*GB) {
1118        (size.asInstanceOf[Double] / GB, "GB")
1119      } else if (size >= 2*MB) {
1120        (size.asInstanceOf[Double] / MB, "MB")
1121      } else if (size >= 2*KB) {
1122        (size.asInstanceOf[Double] / KB, "KB")
1123      } else {
1124        (size.asInstanceOf[Double], "B")
1125      }
1126    }
1127    "%.1f %s".formatLocal(Locale.US, value, unit)
1128  }
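
  // Illustrative examples:
  //   bytesToString(2L * 1024 * 1024)  // "2.0 MB"
  //   bytesToString(1500L)             // "1500.0 B"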
1129
1130  /**
1131   * Returns a human-readable string representing a duration such as "35ms"
1132   */
1133  def msDurationToString(ms: Long): String = {
1134    val second = 1000
1135    val minute = 60 * second
1136    val hour = 60 * minute
1137
1138    ms match {
1139      case t if t < second =>
1140        "%d ms".format(t)
1141      case t if t < minute =>
1142        "%.1f s".format(t.toFloat / second)
1143      case t if t < hour =>
1144        "%.1f m".format(t.toFloat / minute)
1145      case t =>
1146        "%.2f h".format(t.toFloat / hour)
1147    }
1148  }
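
  // Illustrative examples:
  //   msDurationToString(35)     // "35 ms"
  //   msDurationToString(90000)  // "1.5 m"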
1149
1150  /**
1151   * Convert a quantity in megabytes to a human-readable string such as "4.0 MB".
1152   */
1153  def megabytesToString(megabytes: Long): String = {
1154    bytesToString(megabytes * 1024L * 1024L)
1155  }
1156
1157  /**
1158   * Execute a command and return the process running the command.
1159   */
1160  def executeCommand(
1161      command: Seq[String],
1162      workingDir: File = new File("."),
1163      extraEnvironment: Map[String, String] = Map.empty,
1164      redirectStderr: Boolean = true): Process = {
1165    val builder = new ProcessBuilder(command: _*).directory(workingDir)
1166    val environment = builder.environment()
1167    for ((key, value) <- extraEnvironment) {
1168      environment.put(key, value)
1169    }
1170    val process = builder.start()
1171    if (redirectStderr) {
1172      val threadName = "redirect stderr for command " + command(0)
1173      def log(s: String): Unit = logInfo(s)
1174      processStreamByLine(threadName, process.getErrorStream, log)
1175    }
1176    process
1177  }
1178
1179  /**
1180   * Execute a command and get its output, throwing an exception if it yields a code other than 0.
1181   */
1182  def executeAndGetOutput(
1183      command: Seq[String],
1184      workingDir: File = new File("."),
1185      extraEnvironment: Map[String, String] = Map.empty,
1186      redirectStderr: Boolean = true): String = {
1187    val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr)
1188    val output = new StringBuilder
1189    val threadName = "read stdout for " + command(0)
1190    def appendToOutput(s: String): Unit = output.append(s).append("\n")
1191    val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput)
1192    val exitCode = process.waitFor()
1193    stdoutThread.join()   // Wait for it to finish reading output
1194    if (exitCode != 0) {
1195      logError(s"Process $command exited with code $exitCode: $output")
1196      throw new SparkException(s"Process $command exited with code $exitCode")
1197    }
1198    output.toString
1199  }
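
  // Illustrative call (platform-dependent; assumes a `tar` binary on the PATH):
  //   executeAndGetOutput(Seq("tar", "--version"))  // returns the command's stdout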
1200
1201  /**
1202   * Return and start a daemon thread that processes the content of the input stream line by line.
1203   */
1204  def processStreamByLine(
1205      threadName: String,
1206      inputStream: InputStream,
1207      processLine: String => Unit): Thread = {
1208    val t = new Thread(threadName) {
1209      override def run() {
1210        for (line <- Source.fromInputStream(inputStream).getLines()) {
1211          processLine(line)
1212        }
1213      }
1214    }
1215    t.setDaemon(true)
1216    t.start()
1217    t
1218  }
1219
1220  /**
1221   * Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the
1222   * default UncaughtExceptionHandler
1223   *
1224   * NOTE: This method is to be called by the spark-started JVM process.
1225   */
1226  def tryOrExit(block: => Unit) {
1227    try {
1228      block
1229    } catch {
1230      case e: ControlThrowable => throw e
1231      case t: Throwable => SparkUncaughtExceptionHandler.uncaughtException(t)
1232    }
1233  }
1234
1235  /**
1236   * Execute a block of code that evaluates to Unit, stop SparkContext if there is any uncaught
1237   * exception
1238   *
1239   * NOTE: This method is to be called by the driver-side components to avoid stopping the
1240   * user-started JVM process completely; in contrast, tryOrExit is to be called in the
   * spark-started JVM process.
1242   */
1243  def tryOrStopSparkContext(sc: SparkContext)(block: => Unit) {
1244    try {
1245      block
1246    } catch {
1247      case e: ControlThrowable => throw e
1248      case t: Throwable =>
1249        val currentThreadName = Thread.currentThread().getName
1250        if (sc != null) {
1251          logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t)
1252          sc.stopInNewThread()
1253        }
1254        if (!NonFatal(t)) {
1255          logError(s"throw uncaught fatal error in thread $currentThreadName", t)
1256          throw t
1257        }
1258    }
1259  }
1260
1261  /**
1262   * Execute a block of code that returns a value, re-throwing any non-fatal uncaught
1263   * exceptions as IOException. This is used when implementing Externalizable and Serializable's
1264   * read and write methods, since Java's serializer will not report non-IOExceptions properly;
1265   * see SPARK-4080 for more context.
1266   */
1267  def tryOrIOException[T](block: => T): T = {
1268    try {
1269      block
1270    } catch {
1271      case e: IOException =>
1272        logError("Exception encountered", e)
1273        throw e
1274      case NonFatal(e) =>
1275        logError("Exception encountered", e)
1276        throw new IOException(e)
1277    }
1278  }
1279
1280  /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */
1281  def tryLogNonFatalError(block: => Unit) {
1282    try {
1283      block
1284    } catch {
1285      case NonFatal(t) =>
1286        logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
1287    }
1288  }
1289
1290  /**
1291   * Execute a block of code, then a finally block, but if exceptions happen in
1292   * the finally block, do not suppress the original exception.
1293   *
1294   * This is primarily an issue with `finally { out.close() }` blocks, where
1295   * close needs to be called to clean up `out`, but if an exception happened
1296   * in `out.write`, it's likely `out` may be corrupted and `out.close` will
1297   * fail as well. This would then suppress the original/likely more meaningful
1298   * exception from the original `out.write` call.
1299   */
1300  def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
1301    var originalThrowable: Throwable = null
1302    try {
1303      block
1304    } catch {
1305      case t: Throwable =>
        // Purposefully not using NonFatal, because we don't want even fatal
        // exceptions to be suppressed by our finallyBlock
1308        originalThrowable = t
1309        throw originalThrowable
1310    } finally {
1311      try {
1312        finallyBlock
1313      } catch {
1314        case t: Throwable =>
1315          if (originalThrowable != null) {
1316            originalThrowable.addSuppressed(t)
1317            logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
1318            throw originalThrowable
1319          } else {
1320            throw t
1321          }
1322      }
1323    }
1324  }
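
  // Illustrative use (hypothetical stream and data): a failure in write propagates, and a failure
  // in close is added as a suppressed exception rather than replacing it.
  //   val out = new FileOutputStream("/tmp/data.bin")
  //   tryWithSafeFinally {
  //     out.write(data)  // `data` is a hypothetical Array[Byte]
  //   } {
  //     out.close()
  //   }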
1325
1326  /**
1327   * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur
1328   * in either the catch or the finally block, they are appended to the list of suppressed
1329   * exceptions in original exception which is then rethrown.
1330   *
1331   * This is primarily an issue with `catch { abort() }` or `finally { out.close() }` blocks,
1332   * where the abort/close needs to be called to clean up `out`, but if an exception happened
1333   * in `out.write`, it's likely `out` may be corrupted and `abort` or `out.close` will
1334   * fail as well. This would then suppress the original/likely more meaningful
1335   * exception from the original `out.write` call.
1336   */
1337  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)
1338      (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
1339    var originalThrowable: Throwable = null
1340    try {
1341      block
1342    } catch {
1343      case cause: Throwable =>
        // Purposefully not using NonFatal, because we don't want even fatal
        // exceptions to be suppressed by our finallyBlock
1346        originalThrowable = cause
1347        try {
1348          logError("Aborting task", originalThrowable)
1349          TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable)
1350          catchBlock
1351        } catch {
1352          case t: Throwable =>
1353            originalThrowable.addSuppressed(t)
1354            logWarning(s"Suppressing exception in catch: " + t.getMessage, t)
1355        }
1356        throw originalThrowable
1357    } finally {
1358      try {
1359        finallyBlock
1360      } catch {
1361        case t: Throwable =>
1362          if (originalThrowable != null) {
1363            originalThrowable.addSuppressed(t)
1364            logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
1365            throw originalThrowable
1366          } else {
1367            throw t
1368          }
1369      }
1370    }
1371  }
1372
1373  /** Default filtering function for finding call sites using `getCallSite`. */
1374  private def sparkInternalExclusionFunction(className: String): Boolean = {
    // A regular expression to match classes of the internal Spark APIs
1376    // that we want to skip when finding the call site of a method.
1377    val SPARK_CORE_CLASS_REGEX =
1378      """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r
1379    val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r
1380    val SCALA_CORE_CLASS_PREFIX = "scala"
1381    val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined ||
1382      SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined
1383    val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX)
1384    // If the class is a Spark internal class or a Scala class, then exclude.
1385    isSparkClass || isScalaClass
1386  }
1387
1388  /**
1389   * When called inside a class in the spark package, returns the name of the user code class
1390   * (outside the spark package) that called into Spark, as well as which Spark method they called.
1391   * This is used, for example, to tell users where in their code each RDD got created.
1392   *
1393   * @param skipClass Function that is used to exclude non-user-code classes.
1394   */
1395  def getCallSite(skipClass: String => Boolean = sparkInternalExclusionFunction): CallSite = {
1396    // Keep crawling up the stack trace until we find the first function not inside of the spark
1397    // package. We track the last (shallowest) contiguous Spark method. This might be an RDD
1398    // transformation, a SparkContext function (such as parallelize), or anything else that leads
1399    // to instantiation of an RDD. We also track the first (deepest) user method, file, and line.
1400    var lastSparkMethod = "<unknown>"
1401    var firstUserFile = "<unknown>"
1402    var firstUserLine = 0
1403    var insideSpark = true
1404    var callStack = new ArrayBuffer[String]() :+ "<unknown>"
1405
1406    Thread.currentThread.getStackTrace().foreach { ste: StackTraceElement =>
1407      // When running under some profilers, the current stack trace might contain some bogus
1408      // frames. This is intended to ensure that we don't crash in these situations by
1409      // ignoring any frames that we can't examine.
1410      if (ste != null && ste.getMethodName != null
1411        && !ste.getMethodName.contains("getStackTrace")) {
1412        if (insideSpark) {
1413          if (skipClass(ste.getClassName)) {
1414            lastSparkMethod = if (ste.getMethodName == "<init>") {
1415              // Spark method is a constructor; get its class name
1416              ste.getClassName.substring(ste.getClassName.lastIndexOf('.') + 1)
1417            } else {
1418              ste.getMethodName
1419            }
1420            callStack(0) = ste.toString // Put last Spark method on top of the stack trace.
1421          } else {
1422            if (ste.getFileName != null) {
1423              firstUserFile = ste.getFileName
1424              if (ste.getLineNumber >= 0) {
1425                firstUserLine = ste.getLineNumber
1426              }
1427            }
1428            callStack += ste.toString
1429            insideSpark = false
1430          }
1431        } else {
1432          callStack += ste.toString
1433        }
1434      }
1435    }
1436
1437    val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt
1438    val shortForm =
1439      if (firstUserFile == "HiveSessionImpl.java") {
1440        // To be more user friendly, show a nicer string for queries submitted from the JDBC
1441        // server.
1442        "Spark JDBC Server Query"
1443      } else {
1444        s"$lastSparkMethod at $firstUserFile:$firstUserLine"
1445      }
1446    val longForm = callStack.take(callStackDepth).mkString("\n")
1447
1448    CallSite(shortForm, longForm)
1449  }
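  // Illustrative usage sketch: record where user code triggered an RDD-creating call,
  // e.g. to annotate the RDD in the web UI with something like "textFile at MyApp.scala:42".
  //   val callSite = Utils.getCallSite()
  //   logInfo(s"RDD created at ${callSite.shortForm}")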
1450
1451  private val UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF =
1452    "spark.worker.ui.compressedLogFileLengthCacheSize"
1453  private val DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE = 100
1454  private var compressedLogFileLengthCache: LoadingCache[String, java.lang.Long] = null
1455  private def getCompressedLogFileLengthCache(
1456      sparkConf: SparkConf): LoadingCache[String, java.lang.Long] = this.synchronized {
1457    if (compressedLogFileLengthCache == null) {
1458      val compressedLogFileLengthCacheSize = sparkConf.getInt(
1459        UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF,
1460        DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE)
1461      compressedLogFileLengthCache = CacheBuilder.newBuilder()
1462        .maximumSize(compressedLogFileLengthCacheSize)
1463        .build[String, java.lang.Long](new CacheLoader[String, java.lang.Long]() {
1464        override def load(path: String): java.lang.Long = {
1465          Utils.getCompressedFileLength(new File(path))
1466        }
1467      })
1468    }
1469    compressedLogFileLengthCache
1470  }
1471
1472  /**
   * Return the file length; if the file is compressed, return the uncompressed length instead.
   * The uncompressed size is cached to avoid repeated decompression. The cache size is
   * read from the given `workConf`.
1476   */
1477  def getFileLength(file: File, workConf: SparkConf): Long = {
1478    if (file.getName.endsWith(".gz")) {
1479      getCompressedLogFileLengthCache(workConf).get(file.getAbsolutePath)
1480    } else {
1481      file.length
1482    }
1483  }
1484
1485  /** Return uncompressed file length of a compressed file. */
1486  private def getCompressedFileLength(file: File): Long = {
1487    try {
1488      // Uncompress .gz file to determine file size.
1489      var fileSize = 0L
1490      val gzInputStream = new GZIPInputStream(new FileInputStream(file))
1491      val bufSize = 1024
1492      val buf = new Array[Byte](bufSize)
1493      var numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize)
1494      while (numBytes > 0) {
1495        fileSize += numBytes
1496        numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize)
1497      }
1498      fileSize
1499    } catch {
1500      case e: Throwable =>
1501        logError(s"Cannot get file length of ${file}", e)
1502        throw e
1503    }
1504  }
1505
1506  /** Return a string containing part of a file from byte 'start' to 'end'. */
1507  def offsetBytes(path: String, length: Long, start: Long, end: Long): String = {
1508    val file = new File(path)
1509    val effectiveEnd = math.min(length, end)
1510    val effectiveStart = math.max(0, start)
1511    val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt)
1512    val stream = if (path.endsWith(".gz")) {
1513      new GZIPInputStream(new FileInputStream(file))
1514    } else {
1515      new FileInputStream(file)
1516    }
1517
1518    try {
1519      ByteStreams.skipFully(stream, effectiveStart)
1520      ByteStreams.readFully(stream, buff)
1521    } finally {
1522      stream.close()
1523    }
1524    Source.fromBytes(buff).mkString
1525  }
1526
1527  /**
   * Return a string containing data across a set of files. The `start`
   * and `end` offsets are based on the cumulative size of all the files taken in
   * the given order. See the figure below for more details.
1531   */
1532  def offsetBytes(files: Seq[File], fileLengths: Seq[Long], start: Long, end: Long): String = {
1533    assert(files.length == fileLengths.length)
1534    val startIndex = math.max(start, 0)
1535    val endIndex = math.min(end, fileLengths.sum)
1536    val fileToLength = files.zip(fileLengths).toMap
1537    logDebug("Log files: \n" + fileToLength.mkString("\n"))
1538
1539    val stringBuffer = new StringBuffer((endIndex - startIndex).toInt)
1540    var sum = 0L
1541    files.zip(fileLengths).foreach { case (file, fileLength) =>
1542      val startIndexOfFile = sum
1543      val endIndexOfFile = sum + fileToLength(file)
      logDebug(s"Processing file $file, " +
        s"with start index = $startIndexOfFile, end index = $endIndexOfFile")
1546
1547      /*
1548                                      ____________
1549       range 1:                      |            |
1550                                     |   case A   |
1551
1552       files:   |==== file 1 ====|====== file 2 ======|===== file 3 =====|
1553
1554                     |   case B  .       case C       .    case D    |
1555       range 2:      |___________.____________________.______________|
1556       */
1557
      if (startIndex <= startIndexOfFile && endIndex >= endIndexOfFile) {
1559        // Case C: read the whole file
1560        stringBuffer.append(offsetBytes(file.getAbsolutePath, fileLength, 0, fileToLength(file)))
1561      } else if (startIndex > startIndexOfFile && startIndex < endIndexOfFile) {
1562        // Case A and B: read from [start of required range] to [end of file / end of range]
1563        val effectiveStartIndex = startIndex - startIndexOfFile
1564        val effectiveEndIndex = math.min(endIndex - startIndexOfFile, fileToLength(file))
1565        stringBuffer.append(Utils.offsetBytes(
1566          file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex))
1567      } else if (endIndex > startIndexOfFile && endIndex < endIndexOfFile) {
        // Case D: read from [start of file] to [end of required range]
1569        val effectiveStartIndex = math.max(startIndex - startIndexOfFile, 0)
1570        val effectiveEndIndex = endIndex - startIndexOfFile
1571        stringBuffer.append(Utils.offsetBytes(
1572          file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex))
1573      }
1574      sum += fileToLength(file)
1575      logDebug(s"After processing file $file, string built is ${stringBuffer.toString}")
1576    }
1577    stringBuffer.toString
1578  }
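  // Illustrative usage sketch: read bytes [10, 20) of the logical concatenation of two
  // rolled-over log files (`part1` and `part2` are assumed to be existing java.io.File handles).
  //   val text = Utils.offsetBytes(Seq(part1, part2), Seq(part1.length, part2.length), 10, 20)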
1579
1580  /**
1581   * Clone an object using a Spark serializer.
1582   */
1583  def clone[T: ClassTag](value: T, serializer: SerializerInstance): T = {
1584    serializer.deserialize[T](serializer.serialize(value))
1585  }
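  // Illustrative usage sketch, assuming a SparkEnv is available in the calling context:
  //   val copied = Utils.clone(originalValue, SparkEnv.get.serializer.newInstance())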
1586
1587  private def isSpace(c: Char): Boolean = {
1588    " \t\r\n".indexOf(c) != -1
1589  }
1590
1591  /**
1592   * Split a string of potentially quoted arguments from the command line the way that a shell
1593   * would do it to determine arguments to a command. For example, if the string is 'a "b c" d',
1594   * then it would be parsed as three arguments: 'a', 'b c' and 'd'.
1595   */
1596  def splitCommandString(s: String): Seq[String] = {
1597    val buf = new ArrayBuffer[String]
1598    var inWord = false
1599    var inSingleQuote = false
1600    var inDoubleQuote = false
1601    val curWord = new StringBuilder
1602    def endWord() {
1603      buf += curWord.toString
1604      curWord.clear()
1605    }
1606    var i = 0
1607    while (i < s.length) {
1608      val nextChar = s.charAt(i)
1609      if (inDoubleQuote) {
1610        if (nextChar == '"') {
1611          inDoubleQuote = false
1612        } else if (nextChar == '\\') {
1613          if (i < s.length - 1) {
1614            // Append the next character directly, because only " and \ may be escaped in
1615            // double quotes after the shell's own expansion
1616            curWord.append(s.charAt(i + 1))
1617            i += 1
1618          }
1619        } else {
1620          curWord.append(nextChar)
1621        }
1622      } else if (inSingleQuote) {
1623        if (nextChar == '\'') {
1624          inSingleQuote = false
1625        } else {
1626          curWord.append(nextChar)
1627        }
1628        // Backslashes are not treated specially in single quotes
1629      } else if (nextChar == '"') {
1630        inWord = true
1631        inDoubleQuote = true
1632      } else if (nextChar == '\'') {
1633        inWord = true
1634        inSingleQuote = true
1635      } else if (!isSpace(nextChar)) {
1636        curWord.append(nextChar)
1637        inWord = true
1638      } else if (inWord && isSpace(nextChar)) {
1639        endWord()
1640        inWord = false
1641      }
1642      i += 1
1643    }
1644    if (inWord || inDoubleQuote || inSingleQuote) {
1645      endWord()
1646    }
1647    buf
1648  }
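  // Illustrative example matching the scaladoc above:
  //   Utils.splitCommandString("a \"b c\" d") == Seq("a", "b c", "d")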
1649
  /* Calculates 'x' modulo 'mod', taking into account the sign of x:
   * if 'x' is negative, 'x' % 'mod' is negative too,
   * so the function returns (x % mod) + mod in that case.
   */
1654  def nonNegativeMod(x: Int, mod: Int): Int = {
1655    val rawMod = x % mod
1656    rawMod + (if (rawMod < 0) mod else 0)
1657  }
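  // Illustrative example: nonNegativeMod(-3, 5) == 2, whereas -3 % 5 == -3 in Scala.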
1658
1659  // Handles idiosyncrasies with hash (add more as required)
1660  // This method should be kept in sync with
1661  // org.apache.spark.network.util.JavaUtils#nonNegativeHash().
1662  def nonNegativeHash(obj: AnyRef): Int = {
1663
    // Guard against null input.
1665    if (obj eq null) return 0
1666
1667    val hash = obj.hashCode
1668    // math.abs fails for Int.MinValue
1669    val hashAbs = if (Int.MinValue != hash) math.abs(hash) else 0
1670
1671    // Nothing else to guard against ?
1672    hashAbs
1673  }
1674
1675  /**
1676   * NaN-safe version of `java.lang.Double.compare()` which allows NaN values to be compared
1677   * according to semantics where NaN == NaN and NaN is greater than any non-NaN double.
1678   */
1679  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
1680    val xIsNan: Boolean = java.lang.Double.isNaN(x)
1681    val yIsNan: Boolean = java.lang.Double.isNaN(y)
1682    if ((xIsNan && yIsNan) || (x == y)) 0
1683    else if (xIsNan) 1
1684    else if (yIsNan) -1
1685    else if (x > y) 1
1686    else -1
1687  }
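  // Illustrative examples of the NaN semantics described above:
  //   nanSafeCompareDoubles(Double.NaN, Double.NaN) == 0
  //   nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) == 1  // NaN sorts above all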
1688
1689  /**
1690   * NaN-safe version of `java.lang.Float.compare()` which allows NaN values to be compared
1691   * according to semantics where NaN == NaN and NaN is greater than any non-NaN float.
1692   */
1693  def nanSafeCompareFloats(x: Float, y: Float): Int = {
1694    val xIsNan: Boolean = java.lang.Float.isNaN(x)
1695    val yIsNan: Boolean = java.lang.Float.isNaN(y)
1696    if ((xIsNan && yIsNan) || (x == y)) 0
1697    else if (xIsNan) 1
1698    else if (yIsNan) -1
1699    else if (x > y) 1
1700    else -1
1701  }
1702
1703  /**
   * Returns the system properties map, which is thread-safe to iterate over. It includes
   * properties that have been set explicitly, as well as those for which only a default value
   * has been defined.
1707   */
1708  def getSystemProperties: Map[String, String] = {
1709    System.getProperties.stringPropertyNames().asScala
1710      .map(key => (key, System.getProperty(key))).toMap
1711  }
1712
1713  /**
   * Executes a task repeatedly for its side effects.
   * Unlike a for comprehension, this pattern permits JVM JIT optimization.
1716   */
1717  def times(numIters: Int)(f: => Unit): Unit = {
1718    var i = 0
1719    while (i < numIters) {
1720      f
1721      i += 1
1722    }
1723  }
1724
1725  /**
1726   * Timing method based on iterations that permit JVM JIT optimization.
1727   *
1728   * @param numIters number of iterations
1729   * @param f function to be executed. If prepare is not None, the running time of each call to f
1730   *          must be an order of magnitude longer than one millisecond for accurate timing.
1731   * @param prepare function to be executed before each call to f. Its running time doesn't count.
1732   * @return the total time across all iterations (not counting preparation time)
1733   */
1734  def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = {
1735    if (prepare.isEmpty) {
1736      val start = System.currentTimeMillis
1737      times(numIters)(f)
1738      System.currentTimeMillis - start
1739    } else {
1740      var i = 0
1741      var sum = 0L
1742      while (i < numIters) {
1743        prepare.get.apply()
1744        val start = System.currentTimeMillis
1745        f
1746        sum += System.currentTimeMillis - start
1747        i += 1
1748      }
1749      sum
1750    }
1751  }
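  // Illustrative usage sketch (`work()` and `resetState()` are hypothetical helpers):
  //   val elapsedMs = Utils.timeIt(100)(work(), Some(() => resetState()))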
1752
1753  /**
   * Counts the number of elements of an iterator using a while loop rather than calling
   * [[scala.collection.Iterator#size]], because the latter uses a for loop, which is slightly
   * slower in the current version of Scala.
1757   */
1758  def getIteratorSize[T](iterator: Iterator[T]): Long = {
1759    var count = 0L
1760    while (iterator.hasNext) {
1761      count += 1L
1762      iterator.next()
1763    }
1764    count
1765  }
1766
1767  /**
   * Generate a zipWithIndex iterator, avoiding the index overflow problem
   * in Scala's zipWithIndex, which uses an Int index.
1770   */
1771  def getIteratorZipWithIndex[T](iterator: Iterator[T], startIndex: Long): Iterator[(T, Long)] = {
1772    new Iterator[(T, Long)] {
1773      require(startIndex >= 0, "startIndex should be >= 0.")
1774      var index: Long = startIndex - 1L
1775      def hasNext: Boolean = iterator.hasNext
1776      def next(): (T, Long) = {
1777        index += 1L
1778        (iterator.next(), index)
1779      }
1780    }
1781  }
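  // Illustrative example: indices keep counting past Int.MaxValue without overflowing.
  //   Utils.getIteratorZipWithIndex(Iterator("a", "b"), Int.MaxValue.toLong + 1L).toList ==
  //     List(("a", 2147483648L), ("b", 2147483649L))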
1782
1783  /**
1784   * Creates a symlink.
1785   *
1786   * @param src absolute path to the source
1787   * @param dst relative path for the destination
1788   */
1789  def symlink(src: File, dst: File): Unit = {
1790    if (!src.isAbsolute()) {
1791      throw new IOException("Source must be absolute")
1792    }
1793    if (dst.isAbsolute()) {
1794      throw new IOException("Destination must be relative")
1795    }
1796    Files.createSymbolicLink(dst.toPath, src.toPath)
1797  }
1798
1799
1800  /** Return the class name of the given object, removing all dollar signs */
1801  def getFormattedClassName(obj: AnyRef): String = {
1802    obj.getClass.getSimpleName.replace("$", "")
1803  }
1804
1805  /** Return an option that translates JNothing to None */
1806  def jsonOption(json: JValue): Option[JValue] = {
1807    json match {
1808      case JNothing => None
1809      case value: JValue => Some(value)
1810    }
1811  }
1812
1813  /** Return an empty JSON object */
1814  def emptyJson: JsonAST.JObject = JObject(List[JField]())
1815
1816  /**
1817   * Return a Hadoop FileSystem with the scheme encoded in the given path.
1818   */
1819  def getHadoopFileSystem(path: URI, conf: Configuration): FileSystem = {
1820    FileSystem.get(path, conf)
1821  }
1822
1823  /**
1824   * Return a Hadoop FileSystem with the scheme encoded in the given path.
1825   */
1826  def getHadoopFileSystem(path: String, conf: Configuration): FileSystem = {
1827    getHadoopFileSystem(new URI(path), conf)
1828  }
1829
1830  /**
1831   * Return the absolute path of a file in the given directory.
1832   */
1833  def getFilePath(dir: File, fileName: String): Path = {
1834    assert(dir.isDirectory)
1835    val path = new File(dir, fileName).getAbsolutePath
1836    new Path(path)
1837  }
1838
1839  /**
1840   * Whether the underlying operating system is Windows.
1841   */
1842  val isWindows = SystemUtils.IS_OS_WINDOWS
1843
1844  /**
1845   * Whether the underlying operating system is Mac OS X.
1846   */
1847  val isMac = SystemUtils.IS_OS_MAC_OSX
1848
1849  /**
   * Pattern for matching a Windows drive, which contains only a single alphabetic character.
1851   */
1852  val windowsDrive = "([a-zA-Z])".r
1853
1854  /**
1855   * Indicates whether Spark is currently running unit tests.
1856   */
1857  def isTesting: Boolean = {
1858    sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing")
1859  }
1860
1861  /**
1862   * Strip the directory from a path name
1863   */
1864  def stripDirectory(path: String): String = {
1865    new File(path).getName
1866  }
1867
1868  /**
   * Terminates a process, waiting at most the specified duration for it to exit.
1870   *
1871   * @return the process exit value if it was successfully terminated, else None
1872   */
1873  def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = {
1874    // Politely destroy first
1875    process.destroy()
1876
1877    if (waitForProcess(process, timeoutMs)) {
1878      // Successful exit
1879      Option(process.exitValue())
1880    } else {
1881      // Java 8 added a new API which will more forcibly kill the process. Use that if available.
1882      try {
1883        classOf[Process].getMethod("destroyForcibly").invoke(process)
1884      } catch {
1885        case _: NoSuchMethodException => return None // Not available; give up
1886        case NonFatal(e) => logWarning("Exception when attempting to kill process", e)
1887      }
1888      // Wait, again, although this really should return almost immediately
1889      if (waitForProcess(process, timeoutMs)) {
1890        Option(process.exitValue())
1891      } else {
1892        logWarning("Timed out waiting to forcibly kill process")
1893        None
1894      }
1895    }
1896  }
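  // Illustrative usage sketch (`process` is a previously launched java.lang.Process):
  //   Utils.terminateProcess(process, timeoutMs = 10000L) match {
  //     case Some(exitCode) => logInfo(s"Child exited with code $exitCode")
  //     case None => logWarning("Child process could not be terminated")
  //   }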
1897
1898  /**
1899   * Wait for a process to terminate for at most the specified duration.
1900   *
1901   * @return whether the process actually terminated before the given timeout.
1902   */
1903  def waitForProcess(process: Process, timeoutMs: Long): Boolean = {
1904    try {
1905      // Use Java 8 method if available
1906      classOf[Process].getMethod("waitFor", java.lang.Long.TYPE, classOf[TimeUnit])
1907        .invoke(process, timeoutMs.asInstanceOf[java.lang.Long], TimeUnit.MILLISECONDS)
1908        .asInstanceOf[Boolean]
1909    } catch {
1910      case _: NoSuchMethodException =>
1911        // Otherwise implement it manually
1912        var terminated = false
1913        val startTime = System.currentTimeMillis
1914        while (!terminated) {
1915          try {
1916            process.exitValue()
1917            terminated = true
1918          } catch {
1919            case e: IllegalThreadStateException =>
1920              // Process not terminated yet
1921              if (System.currentTimeMillis - startTime > timeoutMs) {
1922                return false
1923              }
1924              Thread.sleep(100)
1925          }
1926        }
1927        true
1928    }
1929  }
1930
1931  /**
1932   * Return the stderr of a process after waiting for the process to terminate.
1933   * If the process does not terminate within the specified timeout, return None.
1934   */
1935  def getStderr(process: Process, timeoutMs: Long): Option[String] = {
1936    val terminated = Utils.waitForProcess(process, timeoutMs)
1937    if (terminated) {
1938      Some(Source.fromInputStream(process.getErrorStream).getLines().mkString("\n"))
1939    } else {
1940      None
1941    }
1942  }
1943
1944  /**
1945   * Execute the given block, logging and re-throwing any uncaught exception.
1946   * This is particularly useful for wrapping code that runs in a thread, to ensure
1947   * that exceptions are printed, and to avoid having to catch Throwable.
1948   */
1949  def logUncaughtExceptions[T](f: => T): T = {
1950    try {
1951      f
1952    } catch {
1953      case ct: ControlThrowable =>
1954        throw ct
1955      case t: Throwable =>
1956        logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
1957        throw t
1958    }
1959  }
1960
1961  /** Executes the given block in a Try, logging any uncaught exceptions. */
1962  def tryLog[T](f: => T): Try[T] = {
1963    try {
1964      val res = f
1965      scala.util.Success(res)
1966    } catch {
1967      case ct: ControlThrowable =>
1968        throw ct
1969      case t: Throwable =>
1970        logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
1971        scala.util.Failure(t)
1972    }
1973  }
1974
1975  /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */
1976  def isFatalError(e: Throwable): Boolean = {
1977    e match {
1978      case NonFatal(_) |
1979           _: InterruptedException |
1980           _: NotImplementedError |
1981           _: ControlThrowable |
1982           _: LinkageError =>
1983        false
1984      case _ =>
1985        true
1986    }
1987  }
1988
1989  /**
1990   * Return a well-formed URI for the file described by a user input string.
1991   *
1992   * If the supplied path does not contain a scheme, or is a relative path, it will be
1993   * converted into an absolute path with a file:// scheme.
1994   */
1995  def resolveURI(path: String): URI = {
1996    try {
1997      val uri = new URI(path)
1998      if (uri.getScheme() != null) {
1999        return uri
2000      }
2001      // make sure to handle if the path has a fragment (applies to yarn
2002      // distributed cache)
2003      if (uri.getFragment() != null) {
2004        val absoluteURI = new File(uri.getPath()).getAbsoluteFile().toURI()
2005        return new URI(absoluteURI.getScheme(), absoluteURI.getHost(), absoluteURI.getPath(),
2006          uri.getFragment())
2007      }
2008    } catch {
2009      case e: URISyntaxException =>
2010    }
2011    new File(path).getAbsoluteFile().toURI()
2012  }
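  // Illustrative examples:
  //   Utils.resolveURI("hdfs:/root/spark.jar")  // returned unchanged, since it has a scheme
  //   Utils.resolveURI("spark.jar")             // becomes file:/<working dir>/spark.jar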
2013
2014  /** Resolve a comma-separated list of paths. */
2015  def resolveURIs(paths: String): String = {
2016    if (paths == null || paths.trim.isEmpty) {
2017      ""
2018    } else {
2019      paths.split(",").filter(_.trim.nonEmpty).map { p => Utils.resolveURI(p) }.mkString(",")
2020    }
2021  }
2022
2023  /** Return all non-local paths from a comma-separated list of paths. */
2024  def nonLocalPaths(paths: String, testWindows: Boolean = false): Array[String] = {
2025    val windows = isWindows || testWindows
2026    if (paths == null || paths.trim.isEmpty) {
2027      Array.empty
2028    } else {
2029      paths.split(",").filter { p =>
2030        val uri = resolveURI(p)
2031        Option(uri.getScheme).getOrElse("file") match {
2032          case windowsDrive(d) if windows => false
2033          case "local" | "file" => false
2034          case _ => true
2035        }
2036      }
2037    }
2038  }
2039
2040  /**
2041   * Load default Spark properties from the given file. If no file is provided,
2042   * use the common defaults file. This mutates state in the given SparkConf and
2043   * in this JVM's system properties if the config specified in the file is not
2044   * already set. Return the path of the properties file used.
2045   */
2046  def loadDefaultSparkProperties(conf: SparkConf, filePath: String = null): String = {
2047    val path = Option(filePath).getOrElse(getDefaultPropertiesFile())
2048    Option(path).foreach { confFile =>
2049      getPropertiesFromFile(confFile).filter { case (k, v) =>
2050        k.startsWith("spark.")
2051      }.foreach { case (k, v) =>
2052        conf.setIfMissing(k, v)
2053        sys.props.getOrElseUpdate(k, v)
2054      }
2055    }
2056    path
2057  }
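  // Illustrative usage sketch: populate a fresh SparkConf from spark-defaults.conf, if present.
  //   val conf = new SparkConf()
  //   val propertiesFile = Utils.loadDefaultSparkProperties(conf)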
2058
2059  /** Load properties present in the given file. */
2060  def getPropertiesFromFile(filename: String): Map[String, String] = {
2061    val file = new File(filename)
2062    require(file.exists(), s"Properties file $file does not exist")
2063    require(file.isFile(), s"Properties file $file is not a normal file")
2064
2065    val inReader = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8)
2066    try {
2067      val properties = new Properties()
2068      properties.load(inReader)
2069      properties.stringPropertyNames().asScala.map(
2070        k => (k, properties.getProperty(k).trim)).toMap
2071    } catch {
2072      case e: IOException =>
2073        throw new SparkException(s"Failed when loading Spark properties from $filename", e)
2074    } finally {
2075      inReader.close()
2076    }
2077  }
2078
2079  /** Return the path of the default Spark properties file. */
2080  def getDefaultPropertiesFile(env: Map[String, String] = sys.env): String = {
2081    env.get("SPARK_CONF_DIR")
2082      .orElse(env.get("SPARK_HOME").map { t => s"$t${File.separator}conf" })
2083      .map { t => new File(s"$t${File.separator}spark-defaults.conf")}
2084      .filter(_.isFile)
2085      .map(_.getAbsolutePath)
2086      .orNull
2087  }
2088
2089  /**
2090   * Return a nice string representation of the exception. It will call "printStackTrace" to
2091   * recursively generate the stack trace including the exception and its causes.
2092   */
2093  def exceptionString(e: Throwable): String = {
2094    if (e == null) {
2095      ""
2096    } else {
2097      // Use e.printStackTrace here because e.getStackTrace doesn't include the cause
2098      val stringWriter = new StringWriter()
2099      e.printStackTrace(new PrintWriter(stringWriter))
2100      stringWriter.toString
2101    }
2102  }
2103
2104  private implicit class Lock(lock: LockInfo) {
2105    def lockString: String = {
2106      lock match {
2107        case monitor: MonitorInfo =>
          s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode})"
        case _ =>
          s"Lock(${lock.getClassName}@${lock.getIdentityHashCode})"
2111      }
2112    }
2113  }
2114
  /** Return a thread dump of all threads' stack traces. Used to capture dumps for the web UI. */
2116  def getThreadDump(): Array[ThreadStackTrace] = {
2117    // We need to filter out null values here because dumpAllThreads() may return null array
2118    // elements for threads that are dead / don't exist.
2119    val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
2120    threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace)
2121  }
2122
2123  def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = {
2124    if (threadId <= 0) {
2125      None
2126    } else {
2127      // The Int.MaxValue here requests the entire untruncated stack trace of the thread:
2128      val threadInfo =
2129        Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue))
2130      threadInfo.map(threadInfoToThreadStackTrace)
2131    }
2132  }
2133
2134  private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = {
2135    val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap
2136    val stackTrace = threadInfo.getStackTrace.map { frame =>
2137      monitors.get(frame) match {
2138        case Some(monitor) =>
2139          monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}"
2140        case None =>
2141          frame.toString
2142      }
2143    }.mkString("\n")
2144
2145    // use a set to dedup re-entrant locks that are held at multiple places
2146    val heldLocks =
2147      (threadInfo.getLockedSynchronizers ++ threadInfo.getLockedMonitors).map(_.lockString).toSet
2148
2149    ThreadStackTrace(
2150      threadId = threadInfo.getThreadId,
2151      threadName = threadInfo.getThreadName,
2152      threadState = threadInfo.getThreadState,
2153      stackTrace = stackTrace,
2154      blockedByThreadId =
2155        if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId),
2156      blockedByLock = Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""),
2157      holdingLocks = heldLocks.toSeq)
2158  }
2159
2160  /**
2161   * Convert all spark properties set in the given SparkConf to a sequence of java options.
2162   */
2163  def sparkJavaOpts(conf: SparkConf, filterKey: (String => Boolean) = _ => true): Seq[String] = {
2164    conf.getAll
2165      .filter { case (k, _) => filterKey(k) }
2166      .map { case (k, v) => s"-D$k=$v" }
2167  }
2168
2169  /**
2170   * Maximum number of retries when binding to a port before giving up.
2171   */
2172  def portMaxRetries(conf: SparkConf): Int = {
2173    val maxRetries = conf.getOption("spark.port.maxRetries").map(_.toInt)
2174    if (conf.contains("spark.testing")) {
2175      // Set a higher number of retries for tests...
2176      maxRetries.getOrElse(100)
2177    } else {
2178      maxRetries.getOrElse(16)
2179    }
2180  }
2181
2182  /**
2183   * Attempt to start a service on the given port, or fail after a number of attempts.
2184   * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0).
2185   *
2186   * @param startPort The initial port to start the service on.
2187   * @param startService Function to start service on a given port.
2188   *                     This is expected to throw java.net.BindException on port collision.
2189   * @param conf A SparkConf used to get the maximum number of retries when binding to a port.
2190   * @param serviceName Name of the service.
2191   * @return (service: T, port: Int)
2192   */
2193  def startServiceOnPort[T](
2194      startPort: Int,
2195      startService: Int => (T, Int),
2196      conf: SparkConf,
2197      serviceName: String = ""): (T, Int) = {
2198
2199    require(startPort == 0 || (1024 <= startPort && startPort < 65536),
2200      "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.")
2201
2202    val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
2203    val maxRetries = portMaxRetries(conf)
2204    for (offset <- 0 to maxRetries) {
2205      // Do not increment port if startPort is 0, which is treated as a special port
2206      val tryPort = if (startPort == 0) {
2207        startPort
2208      } else {
        // If the new port wraps around, do not try a privileged port
2210        ((startPort + offset - 1024) % (65536 - 1024)) + 1024
2211      }
2212      try {
2213        val (service, port) = startService(tryPort)
2214        logInfo(s"Successfully started service$serviceString on port $port.")
2215        return (service, port)
2216      } catch {
2217        case e: Exception if isBindCollision(e) =>
2218          if (offset >= maxRetries) {
2219            val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " +
2220              s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " +
2221              s"the appropriate port for the service$serviceString (for example spark.ui.port " +
2222              s"for SparkUI) to an available port or increasing spark.port.maxRetries."
2223            val exception = new BindException(exceptionMessage)
2224            // restore original stack trace
2225            exception.setStackTrace(e.getStackTrace)
2226            throw exception
2227          }
2228          logWarning(s"Service$serviceString could not bind on port $tryPort. " +
2229            s"Attempting port ${tryPort + 1}.")
2230      }
2231    }
2232    // Should never happen
2233    throw new SparkException(s"Failed to start service$serviceString on port $startPort")
2234  }
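  // Illustrative usage sketch (`startMyServer` is a hypothetical function that binds to the
  // given port and returns the running service together with the port it actually bound to):
  //   val (server, boundPort) =
  //     Utils.startServiceOnPort(8080, port => startMyServer(port), conf, "my-service")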
2235
2236  /**
2237   * Return whether the exception is caused by an address-port collision when binding.
2238   */
2239  def isBindCollision(exception: Throwable): Boolean = {
2240    exception match {
2241      case e: BindException =>
2242        if (e.getMessage != null) {
2243          return true
2244        }
2245        isBindCollision(e.getCause)
2246      case e: MultiException =>
2247        e.getThrowables.asScala.exists(isBindCollision)
2248      case e: NativeIoException =>
2249        (e.getMessage != null && e.getMessage.startsWith("bind() failed: ")) ||
2250          isBindCollision(e.getCause)
2251      case e: Exception => isBindCollision(e.getCause)
2252      case _ => false
2253    }
2254  }
2255
2256  /**
   * Configure a new log4j level.
2258   */
2259  def setLogLevel(l: org.apache.log4j.Level) {
2260    org.apache.log4j.Logger.getRootLogger().setLevel(l)
2261  }
2262
2263  /**
   * Configure a log4j properties object used for test suites.
2265   */
2266  def configTestLog4j(level: String): Unit = {
2267    val pro = new Properties()
2268    pro.put("log4j.rootLogger", s"$level, console")
2269    pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender")
2270    pro.put("log4j.appender.console.target", "System.err")
2271    pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout")
2272    pro.put("log4j.appender.console.layout.ConversionPattern",
2273      "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n")
2274    PropertyConfigurator.configure(pro)
2275  }
2276
2277  /**
2278   * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and
2279   * the host verifier from the given security manager.
2280   */
2281  def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = {
2282    urlConnection match {
2283      case https: HttpsURLConnection =>
2284        sm.sslSocketFactory.foreach(https.setSSLSocketFactory)
2285        sm.hostnameVerifier.foreach(https.setHostnameVerifier)
2286        https
2287      case connection => connection
2288    }
2289  }
2290
2291  def invoke(
2292      clazz: Class[_],
2293      obj: AnyRef,
2294      methodName: String,
2295      args: (Class[_], AnyRef)*): AnyRef = {
2296    val (types, values) = args.unzip
2297    val method = clazz.getDeclaredMethod(methodName, types: _*)
2298    method.setAccessible(true)
2299    method.invoke(obj, values.toSeq: _*)
2300  }
2301
2302  // Limit of bytes for total size of results (default is 1GB)
2303  def getMaxResultSize(conf: SparkConf): Long = {
2304    memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20
2305  }
2306
2307  /**
   * Return the name of the current system's library path environment variable
   * (e.g. LD_LIBRARY_PATH on Linux, DYLD_LIBRARY_PATH on Mac OS X, PATH on Windows).
2309   */
2310  def libraryPathEnvName: String = {
2311    if (isWindows) {
2312      "PATH"
2313    } else if (isMac) {
2314      "DYLD_LIBRARY_PATH"
2315    } else {
2316      "LD_LIBRARY_PATH"
2317    }
2318  }
2319
2320  /**
2321   * Return the prefix of a command that appends the given library paths to the
2322   * system-specific library path environment variable. On Unix, for instance,
2323   * this returns the string LD_LIBRARY_PATH="path1:path2:$LD_LIBRARY_PATH".
2324   */
2325  def libraryPathEnvPrefix(libraryPaths: Seq[String]): String = {
2326    val libraryPathScriptVar = if (isWindows) {
2327      s"%${libraryPathEnvName}%"
2328    } else {
2329      "$" + libraryPathEnvName
2330    }
2331    val libraryPath = (libraryPaths :+ libraryPathScriptVar).mkString("\"",
2332      File.pathSeparator, "\"")
2333    val ampersand = if (Utils.isWindows) {
2334      " &"
2335    } else {
2336      ""
2337    }
2338    s"$libraryPathEnvName=$libraryPath$ampersand"
2339  }
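  // Illustrative example on Linux:
  //   Utils.libraryPathEnvPrefix(Seq("/opt/native"))
  //   // returns: LD_LIBRARY_PATH="/opt/native:$LD_LIBRARY_PATH"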
2340
2341  /**
2342   * Return the value of a config either through the SparkConf or the Hadoop configuration
2343   * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf
2344   * if the key is not set in the Hadoop configuration.
2345   */
2346  def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = {
2347    val sparkValue = conf.get(key, default)
2348    if (SparkHadoopUtil.get.isYarnMode) {
2349      SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue)
2350    } else {
2351      sparkValue
2352    }
2353  }
2354
2355  /**
2356   * Return a pair of host and port extracted from the `sparkUrl`.
2357   *
   * A spark URL (`spark://host:port`) is a special URI whose scheme is `spark` and which
   * contains only a host and a port.
2360   *
2361   * @throws org.apache.spark.SparkException if sparkUrl is invalid.
2362   */
2363  @throws(classOf[SparkException])
2364  def extractHostPortFromSparkUrl(sparkUrl: String): (String, Int) = {
2365    try {
2366      val uri = new java.net.URI(sparkUrl)
2367      val host = uri.getHost
2368      val port = uri.getPort
2369      if (uri.getScheme != "spark" ||
2370        host == null ||
2371        port < 0 ||
2372        (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
2373        uri.getFragment != null ||
2374        uri.getQuery != null ||
2375        uri.getUserInfo != null) {
2376        throw new SparkException("Invalid master URL: " + sparkUrl)
2377      }
2378      (host, port)
2379    } catch {
2380      case e: java.net.URISyntaxException =>
2381        throw new SparkException("Invalid master URL: " + sparkUrl, e)
2382    }
2383  }
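  // Illustrative examples:
  //   Utils.extractHostPortFromSparkUrl("spark://master.example.com:7077") ==
  //     ("master.example.com", 7077)
  //   Utils.extractHostPortFromSparkUrl("spark://host:7077/extra")  // throws SparkException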
2384
2385  /**
2386   * Returns the current user name. This is the currently logged in user, unless that's been
2387   * overridden by the `SPARK_USER` environment variable.
2388   */
2389  def getCurrentUserName(): String = {
2390    Option(System.getenv("SPARK_USER"))
2391      .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName())
2392  }
2393
2394  val EMPTY_USER_GROUPS = Set[String]()
2395
2396  // Returns the groups to which the current user belongs.
2397  def getCurrentUserGroups(sparkConf: SparkConf, username: String): Set[String] = {
2398    val groupProviderClassName = sparkConf.get("spark.user.groups.mapping",
2399      "org.apache.spark.security.ShellBasedGroupsMappingProvider")
2400    if (groupProviderClassName != "") {
2401      try {
2402        val groupMappingServiceProvider = classForName(groupProviderClassName).newInstance.
2403          asInstanceOf[org.apache.spark.security.GroupMappingServiceProvider]
2404        val currentUserGroups = groupMappingServiceProvider.getGroups(username)
2405        return currentUserGroups
2406      } catch {
2407        case e: Exception => logError(s"Error getting groups for user=$username", e)
2408      }
2409    }
2410    EMPTY_USER_GROUPS
2411  }
2412
2413  /**
2414   * Split the comma delimited string of master URLs into a list.
2415   * For instance, "spark://abc,def" becomes [spark://abc, spark://def].
2416   */
2417  def parseStandaloneMasterUrls(masterUrls: String): Array[String] = {
2418    masterUrls.stripPrefix("spark://").split(",").map("spark://" + _)
2419  }
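  // Illustrative example matching the scaladoc above:
  //   Utils.parseStandaloneMasterUrls("spark://abc,def").toSeq ==
  //     Seq("spark://abc", "spark://def")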
2420
2421  /** An identifier that backup masters use in their responses. */
2422  val BACKUP_STANDALONE_MASTER_PREFIX = "Current state is not alive"
2423
2424  /** Return true if the response message is sent from a backup Master on standby. */
2425  def responseFromBackup(msg: String): Boolean = {
2426    msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX)
2427  }
2428
2429  /**
2430   * To avoid calling `Utils.getCallSite` for every single RDD we create in the body,
2431   * set a dummy call site that RDDs use instead. This is for performance optimization.
2432   */
2433  def withDummyCallSite[T](sc: SparkContext)(body: => T): T = {
2434    val oldShortCallSite = sc.getLocalProperty(CallSite.SHORT_FORM)
2435    val oldLongCallSite = sc.getLocalProperty(CallSite.LONG_FORM)
2436    try {
2437      sc.setLocalProperty(CallSite.SHORT_FORM, "")
2438      sc.setLocalProperty(CallSite.LONG_FORM, "")
2439      body
2440    } finally {
2441      // Restore the old ones here
2442      sc.setLocalProperty(CallSite.SHORT_FORM, oldShortCallSite)
2443      sc.setLocalProperty(CallSite.LONG_FORM, oldLongCallSite)
2444    }
2445  }
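  // Illustrative usage sketch, assuming a SparkContext `sc` is in scope:
  //   val rdd = Utils.withDummyCallSite(sc) { sc.parallelize(1 to 1000).map(_ * 2) }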
2446
2447  /**
   * Return whether the child file is contained in (or equal to) the parent directory.
2449   */
2450  @tailrec
2451  def isInDirectory(parent: File, child: File): Boolean = {
2452    if (child == null || parent == null) {
2453      return false
2454    }
2455    if (!child.exists() || !parent.exists() || !parent.isDirectory()) {
2456      return false
2457    }
2458    if (parent.equals(child)) {
2459      return true
2460    }
2461    isInDirectory(parent, child.getParentFile)
2462  }
2463
2464
  /**
   * @return whether the given SparkConf specifies a local master
   */
2469  def isLocalMaster(conf: SparkConf): Boolean = {
2470    val master = conf.get("spark.master", "")
2471    master == "local" || master.startsWith("local[")
2472  }
2473
2474  /**
2475   * Return whether dynamic allocation is enabled in the given conf.
2476   */
2477  def isDynamicAllocationEnabled(conf: SparkConf): Boolean = {
2478    val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false)
2479    dynamicAllocationEnabled &&
2480      (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false))
2481  }
2482
2483  /**
2484   * Return the initial number of executors for dynamic allocation.
2485   */
2486  def getDynamicAllocationInitialExecutors(conf: SparkConf): Int = {
2487    if (conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS) < conf.get(DYN_ALLOCATION_MIN_EXECUTORS)) {
2488      logWarning(s"${DYN_ALLOCATION_INITIAL_EXECUTORS.key} less than " +
2489        s"${DYN_ALLOCATION_MIN_EXECUTORS.key} is invalid, ignoring its setting, " +
2490          "please update your configs.")
2491    }
2492
2493    if (conf.get(EXECUTOR_INSTANCES).getOrElse(0) < conf.get(DYN_ALLOCATION_MIN_EXECUTORS)) {
2494      logWarning(s"${EXECUTOR_INSTANCES.key} less than " +
2495        s"${DYN_ALLOCATION_MIN_EXECUTORS.key} is invalid, ignoring its setting, " +
2496          "please update your configs.")
2497    }
2498
2499    val initialExecutors = Seq(
2500      conf.get(DYN_ALLOCATION_MIN_EXECUTORS),
2501      conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS),
2502      conf.get(EXECUTOR_INSTANCES).getOrElse(0)).max
2503
2504    logInfo(s"Using initial executors = $initialExecutors, max of " +
2505      s"${DYN_ALLOCATION_INITIAL_EXECUTORS.key}, ${DYN_ALLOCATION_MIN_EXECUTORS.key} and " +
2506        s"${EXECUTOR_INSTANCES.key}")
2507    initialExecutors
2508  }
2509
2510  def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = {
2511    val resource = createResource
2512    try f.apply(resource) finally resource.close()
2513  }
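  // Illustrative usage sketch: the stream is closed even if reading throws.
  //   val firstByte = Utils.tryWithResource(new FileInputStream("/etc/hostname"))(_.read())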
2514
2515  /**
   * Returns the path of a temporary file in the same directory as `path`.
2517   */
2518  def tempFileWith(path: File): File = {
2519    new File(path.getAbsolutePath + "." + UUID.randomUUID())
2520  }
2521
2522  /**
   * Returns the name of this JVM process. The format is OS dependent, but on typical platforms
   * (OS X, Linux, Windows) it is PID@hostname.
2525   */
2526  def getProcessName(): String = {
2527    ManagementFactory.getRuntimeMXBean().getName()
2528  }
2529
2530  /**
2531   * Utility function that should be called early in `main()` for daemons to set up some common
2532   * diagnostic state.
2533   */
2534  def initDaemon(log: Logger): Unit = {
2535    log.info(s"Started daemon with process name: ${Utils.getProcessName()}")
2536    SignalUtils.registerLogger(log)
2537  }
2538
2539  /**
2540   * Unions two comma-separated lists of files and filters out empty strings.
2541   */
2542  def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = {
2543    var allFiles = Set[String]()
2544    leftList.foreach { value => allFiles ++= value.split(",") }
2545    rightList.foreach { value => allFiles ++= value.split(",") }
2546    allFiles.filter { _.nonEmpty }
2547  }
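  // Illustrative example: empty entries produced by trailing commas are dropped.
  //   Utils.unionFileLists(Some("a.jar,b.jar"), Some("b.jar,")) == Set("a.jar", "b.jar")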
2548
2549  /**
   * In YARN mode this method returns a union of the jar files pointed to by "spark.jars" and the
   * "spark.yarn.dist.jars" properties, while in other modes it returns the jar files pointed to
   * by the "spark.jars" property only.
2553   */
2554  def getUserJars(conf: SparkConf, isShell: Boolean = false): Seq[String] = {
2555    val sparkJars = conf.getOption("spark.jars")
2556    if (conf.get("spark.master") == "yarn" && isShell) {
2557      val yarnJars = conf.getOption("spark.yarn.dist.jars")
2558      unionFileLists(sparkJars, yarnJars).toSeq
2559    } else {
2560      sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten
2561    }
2562  }
2563}
2564
2565private[util] object CallerContext extends Logging {
2566  val callerContextSupported: Boolean = {
2567    SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
2568      try {
2569        Utils.classForName("org.apache.hadoop.ipc.CallerContext")
2570        Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
2571        true
2572      } catch {
2573        case _: ClassNotFoundException =>
2574          false
2575        case NonFatal(e) =>
          logWarning("Failed to load the CallerContext class", e)
2577          false
2578      }
2579    }
2580  }
2581}
2582
2583/**
 * A utility class used to set up Spark caller contexts to HDFS and Yarn. The `context` will be
 * constructed from the parameters passed in.
 * When Spark applications run on Yarn and HDFS, their caller contexts will be written into the
 * Yarn RM audit log and hdfs-audit.log. That can help users better diagnose and understand how
 * specific applications are impacting parts of the Hadoop system and what potential problems
 * they may be creating (e.g. overloading the NameNode). As mentioned in HDFS-9184, for a given
 * HDFS operation it is very helpful to track which upper-level job issued it.
2591 *
2592 * @param from who sets up the caller context (TASK, CLIENT, APPMASTER)
2593 *
2594 * The parameters below are optional:
2595 * @param appId id of the app this task belongs to
2596 * @param appAttemptId attempt id of the app this task belongs to
2597 * @param jobId id of the job this task belongs to
2598 * @param stageId id of the stage this task belongs to
2599 * @param stageAttemptId attempt id of the stage this task belongs to
2600 * @param taskId task id
2601 * @param taskAttemptNumber task attempt id
2602 */
2603private[spark] class CallerContext(
2604   from: String,
2605   appId: Option[String] = None,
2606   appAttemptId: Option[String] = None,
2607   jobId: Option[Int] = None,
2608   stageId: Option[Int] = None,
2609   stageAttemptId: Option[Int] = None,
2610   taskId: Option[Long] = None,
2611   taskAttemptNumber: Option[Int] = None) extends Logging {
2612
2613   val appIdStr = if (appId.isDefined) s"_${appId.get}" else ""
2614   val appAttemptIdStr = if (appAttemptId.isDefined) s"_${appAttemptId.get}" else ""
2615   val jobIdStr = if (jobId.isDefined) s"_JId_${jobId.get}" else ""
2616   val stageIdStr = if (stageId.isDefined) s"_SId_${stageId.get}" else ""
2617   val stageAttemptIdStr = if (stageAttemptId.isDefined) s"_${stageAttemptId.get}" else ""
2618   val taskIdStr = if (taskId.isDefined) s"_TId_${taskId.get}" else ""
2619   val taskAttemptNumberStr =
2620     if (taskAttemptNumber.isDefined) s"_${taskAttemptNumber.get}" else ""
2621
2622   val context = "SPARK_" + from + appIdStr + appAttemptIdStr +
2623     jobIdStr + stageIdStr + stageAttemptIdStr + taskIdStr + taskAttemptNumberStr
2624
2625  /**
   * Set up the caller context [[context]] by invoking the Hadoop CallerContext API
   * ([[org.apache.hadoop.ipc.CallerContext]]), which was added in Hadoop 2.8.
2628   */
2629  def setCurrentContext(): Unit = {
2630    if (CallerContext.callerContextSupported) {
2631      try {
2632        val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
2633        val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
2634        val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
2635        val hdfsContext = builder.getMethod("build").invoke(builderInst)
2636        callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
2637      } catch {
2638        case NonFatal(e) =>
          logWarning("Failed to set Spark caller context", e)
2640      }
2641    }
2642  }
2643}
2644
2645/**
2646 * A utility class to redirect the child process's stdout or stderr.
2647 */
2648private[spark] class RedirectThread(
2649    in: InputStream,
2650    out: OutputStream,
2651    name: String,
2652    propagateEof: Boolean = false)
2653  extends Thread(name) {
2654
2655  setDaemon(true)
2656  override def run() {
2657    scala.util.control.Exception.ignoring(classOf[IOException]) {
2658      // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
2659      Utils.tryWithSafeFinally {
2660        val buf = new Array[Byte](1024)
2661        var len = in.read(buf)
2662        while (len != -1) {
2663          out.write(buf, 0, len)
2664          out.flush()
2665          len = in.read(buf)
2666        }
2667      } {
2668        if (propagateEof) {
2669          out.close()
2670        }
2671      }
2672    }
2673  }
2674}
2675
2676/**
2677 * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it
2678 * in a circular buffer. The current contents of the buffer can be accessed using
2679 * the toString method.
2680 */
2681private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream {
2682  private var pos: Int = 0
2683  private var isBufferFull = false
2684  private val buffer = new Array[Byte](sizeInBytes)
2685
2686  def write(input: Int): Unit = {
2687    buffer(pos) = input.toByte
2688    pos = (pos + 1) % buffer.length
2689    isBufferFull = isBufferFull || (pos == 0)
2690  }
2691
2692  override def toString: String = {
2693    if (!isBufferFull) {
2694      return new String(buffer, 0, pos, StandardCharsets.UTF_8)
2695    }
2696
2697    val nonCircularBuffer = new Array[Byte](sizeInBytes)
2698    System.arraycopy(buffer, pos, nonCircularBuffer, 0, buffer.length - pos)
2699    System.arraycopy(buffer, 0, nonCircularBuffer, buffer.length - pos, pos)
2700    new String(nonCircularBuffer, StandardCharsets.UTF_8)
2701  }
2702}
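// Illustrative usage sketch (`process` is a previously launched java.lang.Process): capture
// only the tail of a child process's output without unbounded memory growth.
//   val buffer = new CircularBuffer(1024)
//   new RedirectThread(process.getInputStream, buffer, "capture-stdout").start()
//   // later: buffer.toString holds at most the last 1024 bytes written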
2703