1/* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18package org.apache.spark.util 19 20import java.io._ 21import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} 22import java.net._ 23import java.nio.ByteBuffer 24import java.nio.channels.Channels 25import java.nio.charset.StandardCharsets 26import java.nio.file.{Files, Paths} 27import java.util.{Locale, Properties, Random, UUID} 28import java.util.concurrent._ 29import java.util.concurrent.atomic.AtomicBoolean 30import java.util.zip.GZIPInputStream 31import javax.net.ssl.HttpsURLConnection 32 33import scala.annotation.tailrec 34import scala.collection.JavaConverters._ 35import scala.collection.Map 36import scala.collection.mutable.ArrayBuffer 37import scala.io.Source 38import scala.reflect.ClassTag 39import scala.util.Try 40import scala.util.control.{ControlThrowable, NonFatal} 41 42import _root_.io.netty.channel.unix.Errors.NativeIoException 43import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} 44import com.google.common.io.{ByteStreams, Files => GFiles} 45import com.google.common.net.InetAddresses 46import org.apache.commons.lang3.SystemUtils 47import org.apache.hadoop.conf.Configuration 48import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} 49import org.apache.hadoop.security.UserGroupInformation 50import org.apache.log4j.PropertyConfigurator 51import org.eclipse.jetty.util.MultiException 52import org.json4s._ 53import org.slf4j.Logger 54 55import org.apache.spark._ 56import org.apache.spark.deploy.SparkHadoopUtil 57import org.apache.spark.internal.Logging 58import org.apache.spark.internal.config.{DYN_ALLOCATION_INITIAL_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS, EXECUTOR_INSTANCES} 59import org.apache.spark.network.util.JavaUtils 60import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} 61import org.apache.spark.util.logging.RollingFileAppender 62 63/** CallSite represents a place in user code. It can have a short and a long form. */ 64private[spark] case class CallSite(shortForm: String, longForm: String) 65 66private[spark] object CallSite { 67 val SHORT_FORM = "callSite.short" 68 val LONG_FORM = "callSite.long" 69 val empty = CallSite("", "") 70} 71 72/** 73 * Various utility methods used by Spark. 
74 */ 75private[spark] object Utils extends Logging { 76 val random = new Random() 77 78 /** 79 * Define a default value for driver memory here since this value is referenced across the code 80 * base and nearly all files already use Utils.scala 81 */ 82 val DEFAULT_DRIVER_MEM_MB = JavaUtils.DEFAULT_DRIVER_MEM_MB.toInt 83 84 private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 85 @volatile private var localRootDirs: Array[String] = null 86 87 /** 88 * The performance overhead of creating and logging strings for wide schemas can be large. To 89 * limit the impact, we bound the number of fields to include by default. This can be overridden 90 * by setting the 'spark.debug.maxToStringFields' conf in SparkEnv. 91 */ 92 val DEFAULT_MAX_TO_STRING_FIELDS = 25 93 94 private def maxNumToStringFields = { 95 if (SparkEnv.get != null) { 96 SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) 97 } else { 98 DEFAULT_MAX_TO_STRING_FIELDS 99 } 100 } 101 102 /** Whether we have warned about plan string truncation yet. */ 103 private val truncationWarningPrinted = new AtomicBoolean(false) 104 105 /** 106 * Format a sequence with semantics similar to calling .mkString(). Any elements beyond 107 * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. 108 * 109 * @return the trimmed and formatted string. 110 */ 111 def truncatedString[T]( 112 seq: Seq[T], 113 start: String, 114 sep: String, 115 end: String, 116 maxNumFields: Int = maxNumToStringFields): String = { 117 if (seq.length > maxNumFields) { 118 if (truncationWarningPrinted.compareAndSet(false, true)) { 119 logWarning( 120 "Truncated the string representation of a plan since it was too large. This " + 121 "behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.") 122 } 123 val numFields = math.max(0, maxNumFields - 1) 124 seq.take(numFields).mkString( 125 start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) 126 } else { 127 seq.mkString(start, sep, end) 128 } 129 } 130 131 /** Shorthand for calling truncatedString() without start or end strings. 
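   *
   * A small illustrative sketch of the expected output (assuming the default field limit of 25
   * is in effect; the values are arbitrary):
   * {{{
   *   Utils.truncatedString(Seq(1, 2, 3), ", ")                // "1, 2, 3"
   *   Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2)   // "[1, ... 2 more fields]"
   * }}}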
*/ 132 def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") 133 134 /** Serialize an object using Java serialization */ 135 def serialize[T](o: T): Array[Byte] = { 136 val bos = new ByteArrayOutputStream() 137 val oos = new ObjectOutputStream(bos) 138 oos.writeObject(o) 139 oos.close() 140 bos.toByteArray 141 } 142 143 /** Deserialize an object using Java serialization */ 144 def deserialize[T](bytes: Array[Byte]): T = { 145 val bis = new ByteArrayInputStream(bytes) 146 val ois = new ObjectInputStream(bis) 147 ois.readObject.asInstanceOf[T] 148 } 149 150 /** Deserialize an object using Java serialization and the given ClassLoader */ 151 def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { 152 val bis = new ByteArrayInputStream(bytes) 153 val ois = new ObjectInputStream(bis) { 154 override def resolveClass(desc: ObjectStreamClass): Class[_] = { 155 // scalastyle:off classforname 156 Class.forName(desc.getName, false, loader) 157 // scalastyle:on classforname 158 } 159 } 160 ois.readObject.asInstanceOf[T] 161 } 162 163 /** Deserialize a Long value (used for [[org.apache.spark.api.python.PythonPartitioner]]) */ 164 def deserializeLongValue(bytes: Array[Byte]) : Long = { 165 // Note: we assume that we are given a Long value encoded in network (big-endian) byte order 166 var result = bytes(7) & 0xFFL 167 result = result + ((bytes(6) & 0xFFL) << 8) 168 result = result + ((bytes(5) & 0xFFL) << 16) 169 result = result + ((bytes(4) & 0xFFL) << 24) 170 result = result + ((bytes(3) & 0xFFL) << 32) 171 result = result + ((bytes(2) & 0xFFL) << 40) 172 result = result + ((bytes(1) & 0xFFL) << 48) 173 result + ((bytes(0) & 0xFFL) << 56) 174 } 175 176 /** Serialize via nested stream using specific serializer */ 177 def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)( 178 f: SerializationStream => Unit): Unit = { 179 val osWrapper = ser.serializeStream(new OutputStream { 180 override def write(b: Int): Unit = os.write(b) 181 override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len) 182 }) 183 try { 184 f(osWrapper) 185 } finally { 186 osWrapper.close() 187 } 188 } 189 190 /** Deserialize via nested stream using specific serializer */ 191 def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)( 192 f: DeserializationStream => Unit): Unit = { 193 val isWrapper = ser.deserializeStream(new InputStream { 194 override def read(): Int = is.read() 195 override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) 196 }) 197 try { 198 f(isWrapper) 199 } finally { 200 isWrapper.close() 201 } 202 } 203 204 /** 205 * Get the ClassLoader which loaded Spark. 206 */ 207 def getSparkClassLoader: ClassLoader = getClass.getClassLoader 208 209 /** 210 * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that 211 * loaded Spark. 212 * 213 * This should be used whenever passing a ClassLoader to Class.ForName or finding the currently 214 * active loader when setting up ClassLoader delegation chains. 215 */ 216 def getContextOrSparkClassLoader: ClassLoader = 217 Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader) 218 219 /** Determines whether the provided class is loadable in the current thread. 
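   *
   * For example (illustrative; any fully qualified class name can be checked):
   * {{{
   *   Utils.classIsLoadable("org.apache.spark.SparkContext")   // true on a Spark classpath
   *   Utils.classIsLoadable("com.example.DoesNotExist")         // false
   * }}}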
   */
  def classIsLoadable(clazz: String): Boolean = {
    // scalastyle:off classforname
    Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess
    // scalastyle:on classforname
  }

  // scalastyle:off classforname
  /** Preferred alternative to Class.forName(className) */
  def classForName(className: String): Class[_] = {
    Class.forName(className, true, getContextOrSparkClassLoader)
    // scalastyle:on classforname
  }

  /**
   * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
   */
  def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
    if (bb.hasArray) {
      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
    } else {
      val bbval = new Array[Byte](bb.remaining())
      bb.get(bbval)
      out.write(bbval)
    }
  }

  /**
   * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
   */
  def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
    if (bb.hasArray) {
      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
    } else {
      val bbval = new Array[Byte](bb.remaining())
      bb.get(bbval)
      out.write(bbval)
    }
  }

  /**
   * JDK equivalent of `chmod 700 file`.
   *
   * @param file the file whose permissions will be modified
   * @return true if the permissions were successfully changed, false otherwise.
   */
  def chmod700(file: File): Boolean = {
    file.setReadable(false, false) &&
    file.setReadable(true, true) &&
    file.setWritable(false, false) &&
    file.setWritable(true, true) &&
    file.setExecutable(false, false) &&
    file.setExecutable(true, true)
  }

  /**
   * Create a directory inside the given parent directory. The directory is guaranteed to be
   * newly created, and is not marked for automatic deletion.
   */
  def createDirectory(root: String, namePrefix: String = "spark"): File = {
    var attempts = 0
    val maxAttempts = MAX_DIR_CREATION_ATTEMPTS
    var dir: File = null
    while (dir == null) {
      attempts += 1
      if (attempts > maxAttempts) {
        throw new IOException("Failed to create a temp directory (under " + root + ") after " +
          maxAttempts + " attempts!")
      }
      try {
        dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString)
        if (dir.exists() || !dir.mkdirs()) {
          dir = null
        }
      } catch { case e: SecurityException => dir = null; }
    }

    dir.getCanonicalFile
  }

  /**
   * Create a temporary directory inside the given parent directory. The directory will be
   * automatically deleted when the VM shuts down.
   */
  def createTempDir(
      root: String = System.getProperty("java.io.tmpdir"),
      namePrefix: String = "spark"): File = {
    val dir = createDirectory(root, namePrefix)
    ShutdownHookManager.registerShutdownDeleteDir(dir)
    dir
  }

  /**
   * Copy all data from an InputStream to an OutputStream. The NIO `transferTo` path for
   * file-to-file copies is disabled by default and is used only when `transferToEnabled` is
   * true; callers should derive that flag from the spark.file.transferTo = [true|false] setting.
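   *
   * A minimal usage sketch (the source and destination paths here are hypothetical):
   * {{{
   *   val in = new FileInputStream("/tmp/source.dat")
   *   val out = new FileOutputStream("/tmp/dest.dat")
   *   val bytesCopied = Utils.copyStream(in, out, closeStreams = true)
   * }}}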
   */
  def copyStream(in: InputStream,
      out: OutputStream,
      closeStreams: Boolean = false,
      transferToEnabled: Boolean = false): Long =
  {
    var count = 0L
    tryWithSafeFinally {
      if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
        && transferToEnabled) {
        // When both streams are file streams, use transferTo to improve copy performance.
        val inChannel = in.asInstanceOf[FileInputStream].getChannel()
        val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
        val initialPos = outChannel.position()
        val size = inChannel.size()

        // In case transferTo transferred less data than was requested.
        while (count < size) {
          count += inChannel.transferTo(count, size - count, outChannel)
        }

        // Check the position after the transferTo loop to see if it is in the right position and
        // give the user information if not.
        // Position will not be increased to the expected length after calling transferTo in
        // kernel version 2.6.32; this issue can be seen in
        // https://bugs.openjdk.java.net/browse/JDK-7052359
        // This will lead to stream corruption when using sort-based shuffle (SPARK-3948).
        val finalPos = outChannel.position()
        assert(finalPos == initialPos + size,
          s"""
             |Current position $finalPos does not equal the expected position ${initialPos + size}
             |after transferTo. Please check whether your kernel version is 2.6.32;
             |this is a kernel bug which will lead to unexpected behavior when using transferTo.
             |You can set spark.file.transferTo = false to disable this NIO feature.
           """.stripMargin)
      } else {
        val buf = new Array[Byte](8192)
        var n = 0
        while (n != -1) {
          n = in.read(buf)
          if (n != -1) {
            out.write(buf, 0, n)
            count += n
          }
        }
      }
      count
    } {
      if (closeStreams) {
        try {
          in.close()
        } finally {
          out.close()
        }
      }
    }
  }

  /**
   * Construct a URI containing the information used for authentication.
   * This also sets the default authenticator to properly negotiate the
   * user/password based on the URI.
   *
   * Note this relies on Authenticator.setDefault being set properly to decode
   * the user name and password. This is currently set in the SecurityManager.
   */
  def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = {
    val userCred = securityMgr.getSecretKey()
    if (userCred == null) throw new Exception("Secret key is null with authentication on")
    val userInfo = securityMgr.getHttpUser() + ":" + userCred
    new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(),
      uri.getQuery(), uri.getFragment())
  }

  /**
   * A file name may contain some invalid URI characters, such as " ". This method will convert the
   * file name to a raw path accepted by `java.net.URI(String)`.
   *
   * Note: the file name must not contain "/" or "\"
   */
  def encodeFileNameToURIRawPath(fileName: String): String = {
    require(!fileName.contains("/") && !fileName.contains("\\"))
    // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as
    // scheme or host. The prefix "/" is required because URI doesn't accept a relative path.
    // We should remove it after we get the raw path.
    new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1)
  }

  /**
   * Get the file name from uri's raw path and decode it.
If the raw path of uri ends with "/", 405 * return the name before the last "/". 406 */ 407 def decodeFileNameInURI(uri: URI): String = { 408 val rawPath = uri.getRawPath 409 val rawFileName = rawPath.split("/").last 410 new URI("file:///" + rawFileName).getPath.substring(1) 411 } 412 413 /** 414 * Download a file or directory to target directory. Supports fetching the file in a variety of 415 * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based 416 * on the URL parameter. Fetching directories is only supported from Hadoop-compatible 417 * filesystems. 418 * 419 * If `useCache` is true, first attempts to fetch the file to a local cache that's shared 420 * across executors running the same application. `useCache` is used mainly for 421 * the executors, and not in local mode. 422 * 423 * Throws SparkException if the target file already exists and has different contents than 424 * the requested file. 425 */ 426 def fetchFile( 427 url: String, 428 targetDir: File, 429 conf: SparkConf, 430 securityMgr: SecurityManager, 431 hadoopConf: Configuration, 432 timestamp: Long, 433 useCache: Boolean) { 434 val fileName = decodeFileNameInURI(new URI(url)) 435 val targetFile = new File(targetDir, fileName) 436 val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) 437 if (useCache && fetchCacheEnabled) { 438 val cachedFileName = s"${url.hashCode}${timestamp}_cache" 439 val lockFileName = s"${url.hashCode}${timestamp}_lock" 440 val localDir = new File(getLocalDir(conf)) 441 val lockFile = new File(localDir, lockFileName) 442 val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel() 443 // Only one executor entry. 444 // The FileLock is only used to control synchronization for executors download file, 445 // it's always safe regardless of lock type (mandatory or advisory). 446 val lock = lockFileChannel.lock() 447 val cachedFile = new File(localDir, cachedFileName) 448 try { 449 if (!cachedFile.exists()) { 450 doFetchFile(url, localDir, cachedFileName, conf, securityMgr, hadoopConf) 451 } 452 } finally { 453 lock.release() 454 lockFileChannel.close() 455 } 456 copyFile( 457 url, 458 cachedFile, 459 targetFile, 460 conf.getBoolean("spark.files.overwrite", false) 461 ) 462 } else { 463 doFetchFile(url, targetDir, fileName, conf, securityMgr, hadoopConf) 464 } 465 466 // Decompress the file if it's a .tar or .tar.gz 467 if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) { 468 logInfo("Untarring " + fileName) 469 executeAndGetOutput(Seq("tar", "-xzf", fileName), targetDir) 470 } else if (fileName.endsWith(".tar")) { 471 logInfo("Untarring " + fileName) 472 executeAndGetOutput(Seq("tar", "-xf", fileName), targetDir) 473 } 474 // Make the file executable - That's necessary for scripts 475 FileUtil.chmod(targetFile.getAbsolutePath, "a+x") 476 477 // Windows does not grant read permission by default to non-admin users 478 // Add read permission to owner explicitly 479 if (isWindows) { 480 FileUtil.chmod(targetFile.getAbsolutePath, "u+r") 481 } 482 } 483 484 /** 485 * Download `in` to `tempFile`, then move it to `destFile`. 486 * 487 * If `destFile` already exists: 488 * - no-op if its contents equal those of `sourceFile`, 489 * - throw an exception if `fileOverwrite` is false, 490 * - attempt to overwrite it otherwise. 491 * 492 * @param url URL that `sourceFile` originated from, for logging purposes. 493 * @param in InputStream to download. 494 * @param destFile File path to move `tempFile` to. 
495 * @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match 496 * `sourceFile` 497 */ 498 private def downloadFile( 499 url: String, 500 in: InputStream, 501 destFile: File, 502 fileOverwrite: Boolean): Unit = { 503 val tempFile = File.createTempFile("fetchFileTemp", null, 504 new File(destFile.getParentFile.getAbsolutePath)) 505 logInfo(s"Fetching $url to $tempFile") 506 507 try { 508 val out = new FileOutputStream(tempFile) 509 Utils.copyStream(in, out, closeStreams = true) 510 copyFile(url, tempFile, destFile, fileOverwrite, removeSourceFile = true) 511 } finally { 512 // Catch-all for the couple of cases where for some reason we didn't move `tempFile` to 513 // `destFile`. 514 if (tempFile.exists()) { 515 tempFile.delete() 516 } 517 } 518 } 519 520 /** 521 * Copy `sourceFile` to `destFile`. 522 * 523 * If `destFile` already exists: 524 * - no-op if its contents equal those of `sourceFile`, 525 * - throw an exception if `fileOverwrite` is false, 526 * - attempt to overwrite it otherwise. 527 * 528 * @param url URL that `sourceFile` originated from, for logging purposes. 529 * @param sourceFile File path to copy/move from. 530 * @param destFile File path to copy/move to. 531 * @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match 532 * `sourceFile` 533 * @param removeSourceFile Whether to remove `sourceFile` after / as part of moving/copying it to 534 * `destFile`. 535 */ 536 private def copyFile( 537 url: String, 538 sourceFile: File, 539 destFile: File, 540 fileOverwrite: Boolean, 541 removeSourceFile: Boolean = false): Unit = { 542 543 if (destFile.exists) { 544 if (!filesEqualRecursive(sourceFile, destFile)) { 545 if (fileOverwrite) { 546 logInfo( 547 s"File $destFile exists and does not match contents of $url, replacing it with $url" 548 ) 549 if (!destFile.delete()) { 550 throw new SparkException( 551 "Failed to delete %s while attempting to overwrite it with %s".format( 552 destFile.getAbsolutePath, 553 sourceFile.getAbsolutePath 554 ) 555 ) 556 } 557 } else { 558 throw new SparkException( 559 s"File $destFile exists and does not match contents of $url") 560 } 561 } else { 562 // Do nothing if the file contents are the same, i.e. this file has been copied 563 // previously. 564 logInfo( 565 "%s has been previously copied to %s".format( 566 sourceFile.getAbsolutePath, 567 destFile.getAbsolutePath 568 ) 569 ) 570 return 571 } 572 } 573 574 // The file does not exist in the target directory. Copy or move it there. 
575 if (removeSourceFile) { 576 Files.move(sourceFile.toPath, destFile.toPath) 577 } else { 578 logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}") 579 copyRecursive(sourceFile, destFile) 580 } 581 } 582 583 private def filesEqualRecursive(file1: File, file2: File): Boolean = { 584 if (file1.isDirectory && file2.isDirectory) { 585 val subfiles1 = file1.listFiles() 586 val subfiles2 = file2.listFiles() 587 if (subfiles1.size != subfiles2.size) { 588 return false 589 } 590 subfiles1.sortBy(_.getName).zip(subfiles2.sortBy(_.getName)).forall { 591 case (f1, f2) => filesEqualRecursive(f1, f2) 592 } 593 } else if (file1.isFile && file2.isFile) { 594 GFiles.equal(file1, file2) 595 } else { 596 false 597 } 598 } 599 600 private def copyRecursive(source: File, dest: File): Unit = { 601 if (source.isDirectory) { 602 if (!dest.mkdir()) { 603 throw new IOException(s"Failed to create directory ${dest.getPath}") 604 } 605 val subfiles = source.listFiles() 606 subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName))) 607 } else { 608 Files.copy(source.toPath, dest.toPath) 609 } 610 } 611 612 /** 613 * Download a file or directory to target directory. Supports fetching the file in a variety of 614 * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based 615 * on the URL parameter. Fetching directories is only supported from Hadoop-compatible 616 * filesystems. 617 * 618 * Throws SparkException if the target file already exists and has different contents than 619 * the requested file. 620 */ 621 private def doFetchFile( 622 url: String, 623 targetDir: File, 624 filename: String, 625 conf: SparkConf, 626 securityMgr: SecurityManager, 627 hadoopConf: Configuration) { 628 val targetFile = new File(targetDir, filename) 629 val uri = new URI(url) 630 val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) 631 Option(uri.getScheme).getOrElse("file") match { 632 case "spark" => 633 if (SparkEnv.get == null) { 634 throw new IllegalStateException( 635 "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") 636 } 637 val source = SparkEnv.get.rpcEnv.openChannel(url) 638 val is = Channels.newInputStream(source) 639 downloadFile(url, is, targetFile, fileOverwrite) 640 case "http" | "https" | "ftp" => 641 var uc: URLConnection = null 642 if (securityMgr.isAuthenticationEnabled()) { 643 logDebug("fetchFile with security enabled") 644 val newuri = constructURIForAuthentication(uri, securityMgr) 645 uc = newuri.toURL().openConnection() 646 uc.setAllowUserInteraction(false) 647 } else { 648 logDebug("fetchFile not using security") 649 uc = new URL(url).openConnection() 650 } 651 Utils.setupSecureURLConnection(uc, securityMgr) 652 653 val timeoutMs = 654 conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 655 uc.setConnectTimeout(timeoutMs) 656 uc.setReadTimeout(timeoutMs) 657 uc.connect() 658 val in = uc.getInputStream() 659 downloadFile(url, in, targetFile, fileOverwrite) 660 case "file" => 661 // In the case of a local file, copy the local file to the target directory. 662 // Note the difference between uri vs url. 
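        // For example (illustrative): "file:///tmp/jar1.jar" parses to an absolute URI and is
        // opened via new File(uri), while a bare path such as "/tmp/jar1.jar" has no scheme
        // component and falls back to new File(url), i.e. the raw string.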
663 val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) 664 copyFile(url, sourceFile, targetFile, fileOverwrite) 665 case _ => 666 val fs = getHadoopFileSystem(uri, hadoopConf) 667 val path = new Path(uri) 668 fetchHcfsFile(path, targetDir, fs, conf, hadoopConf, fileOverwrite, 669 filename = Some(filename)) 670 } 671 } 672 673 /** 674 * Fetch a file or directory from a Hadoop-compatible filesystem. 675 * 676 * Visible for testing 677 */ 678 private[spark] def fetchHcfsFile( 679 path: Path, 680 targetDir: File, 681 fs: FileSystem, 682 conf: SparkConf, 683 hadoopConf: Configuration, 684 fileOverwrite: Boolean, 685 filename: Option[String] = None): Unit = { 686 if (!targetDir.exists() && !targetDir.mkdir()) { 687 throw new IOException(s"Failed to create directory ${targetDir.getPath}") 688 } 689 val dest = new File(targetDir, filename.getOrElse(path.getName)) 690 if (fs.isFile(path)) { 691 val in = fs.open(path) 692 try { 693 downloadFile(path.toString, in, dest, fileOverwrite) 694 } finally { 695 in.close() 696 } 697 } else { 698 fs.listStatus(path).foreach { fileStatus => 699 fetchHcfsFile(fileStatus.getPath(), dest, fs, conf, hadoopConf, fileOverwrite) 700 } 701 } 702 } 703 704 /** 705 * Validate that a given URI is actually a valid URL as well. 706 * @param uri The URI to validate 707 */ 708 @throws[MalformedURLException]("when the URI is an invalid URL") 709 def validateURL(uri: URI): Unit = { 710 Option(uri.getScheme).getOrElse("file") match { 711 case "http" | "https" | "ftp" => 712 try { 713 uri.toURL 714 } catch { 715 case e: MalformedURLException => 716 val ex = new MalformedURLException(s"URI (${uri.toString}) is not a valid URL.") 717 ex.initCause(e) 718 throw ex 719 } 720 case _ => // will not be turned into a URL anyway 721 } 722 } 723 724 /** 725 * Get the path of a temporary directory. Spark's local directories can be configured through 726 * multiple settings, which are used with the following precedence: 727 * 728 * - If called from inside of a YARN container, this will return a directory chosen by YARN. 729 * - If the SPARK_LOCAL_DIRS environment variable is set, this will return a directory from it. 730 * - Otherwise, if the spark.local.dir is set, this will return a directory from it. 731 * - Otherwise, this will return java.io.tmpdir. 732 * 733 * Some of these configuration options might be lists of multiple paths, but this method will 734 * always return a single directory. 735 */ 736 def getLocalDir(conf: SparkConf): String = { 737 getOrCreateLocalRootDirs(conf)(0) 738 } 739 740 private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { 741 // These environment variables are set by YARN. 742 conf.getenv("CONTAINER_ID") != null 743 } 744 745 /** 746 * Gets or creates the directories listed in spark.local.dir or SPARK_LOCAL_DIRS, 747 * and returns only the directories that exist / could be created. 748 * 749 * If no directories could be created, this will return an empty list. 750 * 751 * This method will cache the local directories for the application when it's first invoked. 752 * So calling it multiple times with a different configuration will always return the same 753 * set of directories. 
754 */ 755 private[spark] def getOrCreateLocalRootDirs(conf: SparkConf): Array[String] = { 756 if (localRootDirs == null) { 757 this.synchronized { 758 if (localRootDirs == null) { 759 localRootDirs = getOrCreateLocalRootDirsImpl(conf) 760 } 761 } 762 } 763 localRootDirs 764 } 765 766 /** 767 * Return the configured local directories where Spark can write files. This 768 * method does not create any directories on its own, it only encapsulates the 769 * logic of locating the local directories according to deployment mode. 770 */ 771 def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { 772 val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) 773 if (isRunningInYarnContainer(conf)) { 774 // If we are in yarn mode, systems can have different disk layouts so we must set it 775 // to what Yarn on this system said was available. Note this assumes that Yarn has 776 // created the directories already, and that they are secured so that only the 777 // user has access to them. 778 getYarnLocalDirs(conf).split(",") 779 } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) { 780 conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) 781 } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { 782 conf.getenv("SPARK_LOCAL_DIRS").split(",") 783 } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { 784 // Mesos already creates a directory per Mesos task. Spark should use that directory 785 // instead so all temporary files are automatically cleaned up when the Mesos task ends. 786 // Note that we don't want this if the shuffle service is enabled because we want to 787 // continue to serve shuffle files after the executors that wrote them have already exited. 788 Array(conf.getenv("MESOS_DIRECTORY")) 789 } else { 790 if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { 791 logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + 792 "spark.shuffle.service.enabled is enabled.") 793 } 794 // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user 795 // configuration to point to a secure directory. So create a subdirectory with restricted 796 // permissions under each listed directory. 797 conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(",") 798 } 799 } 800 801 private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { 802 getConfiguredLocalDirs(conf).flatMap { root => 803 try { 804 val rootDir = new File(root) 805 if (rootDir.exists || rootDir.mkdirs()) { 806 val dir = createTempDir(root) 807 chmod700(dir) 808 Some(dir.getAbsolutePath) 809 } else { 810 logError(s"Failed to create dir in $root. Ignoring this directory.") 811 None 812 } 813 } catch { 814 case e: IOException => 815 logError(s"Failed to create local root dir in $root. Ignoring this directory.") 816 None 817 } 818 } 819 } 820 821 /** Get the Yarn approved local directories. */ 822 private def getYarnLocalDirs(conf: SparkConf): String = { 823 val localDirs = Option(conf.getenv("LOCAL_DIRS")).getOrElse("") 824 825 if (localDirs.isEmpty) { 826 throw new Exception("Yarn Local dirs can't be empty") 827 } 828 localDirs 829 } 830 831 /** Used by unit tests. Do not call from other places. */ 832 private[spark] def clearLocalRootDirs(): Unit = { 833 localRootDirs = null 834 } 835 836 /** 837 * Shuffle the elements of a collection into a random order, returning the 838 * result in a new collection. 
Unlike scala.util.Random.shuffle, this method 839 * uses a local random number generator, avoiding inter-thread contention. 840 */ 841 def randomize[T: ClassTag](seq: TraversableOnce[T]): Seq[T] = { 842 randomizeInPlace(seq.toArray) 843 } 844 845 /** 846 * Shuffle the elements of an array into a random order, modifying the 847 * original array. Returns the original array. 848 */ 849 def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = { 850 for (i <- (arr.length - 1) to 1 by -1) { 851 val j = rand.nextInt(i + 1) 852 val tmp = arr(j) 853 arr(j) = arr(i) 854 arr(i) = tmp 855 } 856 arr 857 } 858 859 /** 860 * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). 861 * Note, this is typically not used from within core spark. 862 */ 863 private lazy val localIpAddress: InetAddress = findLocalInetAddress() 864 865 private def findLocalInetAddress(): InetAddress = { 866 val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") 867 if (defaultIpOverride != null) { 868 InetAddress.getByName(defaultIpOverride) 869 } else { 870 val address = InetAddress.getLocalHost 871 if (address.isLoopbackAddress) { 872 // Address resolves to something like 127.0.1.1, which happens on Debian; try to find 873 // a better address using the local network interfaces 874 // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order 875 // on unix-like system. On windows, it returns in index order. 876 // It's more proper to pick ip address following system output order. 877 val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq 878 val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse 879 880 for (ni <- reOrderedNetworkIFs) { 881 val addresses = ni.getInetAddresses.asScala 882 .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq 883 if (addresses.nonEmpty) { 884 val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) 885 // because of Inet6Address.toHostName may add interface at the end if it knows about it 886 val strippedAddress = InetAddress.getByAddress(addr.getAddress) 887 // We've found an address that looks reasonable! 888 logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + 889 " a loopback address: " + address.getHostAddress + "; using " + 890 strippedAddress.getHostAddress + " instead (on interface " + ni.getName + ")") 891 logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") 892 return strippedAddress 893 } 894 } 895 logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + 896 " a loopback address: " + address.getHostAddress + ", but we couldn't find any" + 897 " external IP address!") 898 logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") 899 } 900 address 901 } 902 } 903 904 private var customHostname: Option[String] = sys.env.get("SPARK_LOCAL_HOSTNAME") 905 906 /** 907 * Allow setting a custom host name because when we run on Mesos we need to use the same 908 * hostname it reports to the master. 909 */ 910 def setCustomHostname(hostname: String) { 911 // DEBUG code 912 Utils.checkHost(hostname) 913 customHostname = Some(hostname) 914 } 915 916 /** 917 * Get the local machine's hostname. 918 */ 919 def localHostName(): String = { 920 customHostname.getOrElse(localIpAddress.getHostAddress) 921 } 922 923 /** 924 * Get the local machine's URI. 
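   *
   * Illustrative sketch of the difference from `localHostName()` (actual values depend on the
   * machine's network configuration):
   * {{{
   *   Utils.localHostName()          // e.g. "192.168.1.15"
   *   Utils.localHostNameForURI()    // e.g. "192.168.1.15", or a bracketed form for IPv6
   * }}}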
   */
  def localHostNameForURI(): String = {
    customHostname.getOrElse(InetAddresses.toUriString(localIpAddress))
  }

  def checkHost(host: String, message: String = "") {
    assert(host.indexOf(':') == -1, message)
  }

  def checkHostPort(hostPort: String, message: String = "") {
    assert(hostPort.indexOf(':') != -1, message)
  }

  // Typically, this will be on the order of the number of nodes in the cluster.
  // If not, we should change it to an LRU cache or something similar.
  private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()

  def parseHostPort(hostPort: String): (String, Int) = {
    // Check cache first.
    val cached = hostPortParseResults.get(hostPort)
    if (cached != null) {
      return cached
    }

    val indx: Int = hostPort.lastIndexOf(':')
    // This is potentially broken - when dealing with ipv6 addresses for example, sigh ...
    // but then hadoop does not support ipv6 right now.
    // For now, we assume that if a port exists, it is valid - we do not check that it is an
    // int > 0.
    if (-1 == indx) {
      val retval = (hostPort, 0)
      hostPortParseResults.put(hostPort, retval)
      return retval
    }

    val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
    hostPortParseResults.putIfAbsent(hostPort, retval)
    hostPortParseResults.get(hostPort)
  }

  /**
   * Return a string describing how many milliseconds have passed since `startTimeMs`.
   */
  def getUsedTimeMs(startTimeMs: Long): String = {
    " " + (System.currentTimeMillis - startTimeMs) + " ms"
  }

  private def listFilesSafely(file: File): Seq[File] = {
    if (file.exists()) {
      val files = file.listFiles()
      if (files == null) {
        throw new IOException("Failed to list files for dir: " + file)
      }
      files
    } else {
      List()
    }
  }

  /**
   * Delete a file or directory and its contents recursively.
   * Don't follow directories if they are symlinks.
   * Throws an exception if deletion is unsuccessful.
   */
  def deleteRecursively(file: File) {
    if (file != null) {
      try {
        if (file.isDirectory && !isSymlink(file)) {
          var savedIOException: IOException = null
          for (child <- listFilesSafely(file)) {
            try {
              deleteRecursively(child)
            } catch {
              // In case of multiple exceptions, only the last one will be thrown
              case ioe: IOException => savedIOException = ioe
            }
          }
          if (savedIOException != null) {
            throw savedIOException
          }
          ShutdownHookManager.removeShutdownDeleteDir(file)
        }
      } finally {
        if (!file.delete()) {
          // Delete can also fail if the file simply did not exist
          if (file.exists()) {
            throw new IOException("Failed to delete: " + file.getAbsolutePath)
          }
        }
      }
    }
  }

  /**
   * Check to see if file is a symbolic link.
   */
  def isSymlink(file: File): Boolean = {
    return Files.isSymbolicLink(Paths.get(file.toURI))
  }

  /**
   * Determines if a directory contains any files newer than cutoff seconds.
   *
   * @param dir must be the path to a directory, or IllegalArgumentException is thrown
   * @param cutoff measured in seconds.
Returns true if there are any files or directories in the 1029 * given directory whose last modified time is later than this many seconds ago 1030 */ 1031 def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { 1032 if (!dir.isDirectory) { 1033 throw new IllegalArgumentException(s"$dir is not a directory!") 1034 } 1035 val filesAndDirs = dir.listFiles() 1036 val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000) 1037 1038 filesAndDirs.exists(_.lastModified() > cutoffTimeInMillis) || 1039 filesAndDirs.filter(_.isDirectory).exists( 1040 subdir => doesDirectoryContainAnyNewFiles(subdir, cutoff) 1041 ) 1042 } 1043 1044 /** 1045 * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If 1046 * no suffix is provided, the passed number is assumed to be in ms. 1047 */ 1048 def timeStringAsMs(str: String): Long = { 1049 JavaUtils.timeStringAsMs(str) 1050 } 1051 1052 /** 1053 * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If 1054 * no suffix is provided, the passed number is assumed to be in seconds. 1055 */ 1056 def timeStringAsSeconds(str: String): Long = { 1057 JavaUtils.timeStringAsSec(str) 1058 } 1059 1060 /** 1061 * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for internal use. 1062 * 1063 * If no suffix is provided, the passed number is assumed to be in bytes. 1064 */ 1065 def byteStringAsBytes(str: String): Long = { 1066 JavaUtils.byteStringAsBytes(str) 1067 } 1068 1069 /** 1070 * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for internal use. 1071 * 1072 * If no suffix is provided, the passed number is assumed to be in kibibytes. 1073 */ 1074 def byteStringAsKb(str: String): Long = { 1075 JavaUtils.byteStringAsKb(str) 1076 } 1077 1078 /** 1079 * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for internal use. 1080 * 1081 * If no suffix is provided, the passed number is assumed to be in mebibytes. 1082 */ 1083 def byteStringAsMb(str: String): Long = { 1084 JavaUtils.byteStringAsMb(str) 1085 } 1086 1087 /** 1088 * Convert a passed byte string (e.g. 50b, 100k, or 250m, 500g) to gibibytes for internal use. 1089 * 1090 * If no suffix is provided, the passed number is assumed to be in gibibytes. 1091 */ 1092 def byteStringAsGb(str: String): Long = { 1093 JavaUtils.byteStringAsGb(str) 1094 } 1095 1096 /** 1097 * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes. 1098 */ 1099 def memoryStringToMb(str: String): Int = { 1100 // Convert to bytes, rather than directly to MB, because when no units are specified the unit 1101 // is assumed to be bytes 1102 (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt 1103 } 1104 1105 /** 1106 * Convert a quantity in bytes to a human-readable string such as "4.0 MB". 
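   *
   * For example (illustrative):
   * {{{
   *   Utils.bytesToString(5L * 1024 * 1024)   // "5.0 MB"
   *   Utils.bytesToString(1500)               // "1500.0 B" (below the 2 KB threshold)
   * }}}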
1107 */ 1108 def bytesToString(size: Long): String = { 1109 val TB = 1L << 40 1110 val GB = 1L << 30 1111 val MB = 1L << 20 1112 val KB = 1L << 10 1113 1114 val (value, unit) = { 1115 if (size >= 2*TB) { 1116 (size.asInstanceOf[Double] / TB, "TB") 1117 } else if (size >= 2*GB) { 1118 (size.asInstanceOf[Double] / GB, "GB") 1119 } else if (size >= 2*MB) { 1120 (size.asInstanceOf[Double] / MB, "MB") 1121 } else if (size >= 2*KB) { 1122 (size.asInstanceOf[Double] / KB, "KB") 1123 } else { 1124 (size.asInstanceOf[Double], "B") 1125 } 1126 } 1127 "%.1f %s".formatLocal(Locale.US, value, unit) 1128 } 1129 1130 /** 1131 * Returns a human-readable string representing a duration such as "35ms" 1132 */ 1133 def msDurationToString(ms: Long): String = { 1134 val second = 1000 1135 val minute = 60 * second 1136 val hour = 60 * minute 1137 1138 ms match { 1139 case t if t < second => 1140 "%d ms".format(t) 1141 case t if t < minute => 1142 "%.1f s".format(t.toFloat / second) 1143 case t if t < hour => 1144 "%.1f m".format(t.toFloat / minute) 1145 case t => 1146 "%.2f h".format(t.toFloat / hour) 1147 } 1148 } 1149 1150 /** 1151 * Convert a quantity in megabytes to a human-readable string such as "4.0 MB". 1152 */ 1153 def megabytesToString(megabytes: Long): String = { 1154 bytesToString(megabytes * 1024L * 1024L) 1155 } 1156 1157 /** 1158 * Execute a command and return the process running the command. 1159 */ 1160 def executeCommand( 1161 command: Seq[String], 1162 workingDir: File = new File("."), 1163 extraEnvironment: Map[String, String] = Map.empty, 1164 redirectStderr: Boolean = true): Process = { 1165 val builder = new ProcessBuilder(command: _*).directory(workingDir) 1166 val environment = builder.environment() 1167 for ((key, value) <- extraEnvironment) { 1168 environment.put(key, value) 1169 } 1170 val process = builder.start() 1171 if (redirectStderr) { 1172 val threadName = "redirect stderr for command " + command(0) 1173 def log(s: String): Unit = logInfo(s) 1174 processStreamByLine(threadName, process.getErrorStream, log) 1175 } 1176 process 1177 } 1178 1179 /** 1180 * Execute a command and get its output, throwing an exception if it yields a code other than 0. 1181 */ 1182 def executeAndGetOutput( 1183 command: Seq[String], 1184 workingDir: File = new File("."), 1185 extraEnvironment: Map[String, String] = Map.empty, 1186 redirectStderr: Boolean = true): String = { 1187 val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr) 1188 val output = new StringBuilder 1189 val threadName = "read stdout for " + command(0) 1190 def appendToOutput(s: String): Unit = output.append(s).append("\n") 1191 val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput) 1192 val exitCode = process.waitFor() 1193 stdoutThread.join() // Wait for it to finish reading output 1194 if (exitCode != 0) { 1195 logError(s"Process $command exited with code $exitCode: $output") 1196 throw new SparkException(s"Process $command exited with code $exitCode") 1197 } 1198 output.toString 1199 } 1200 1201 /** 1202 * Return and start a daemon thread that processes the content of the input stream line by line. 
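   *
   * A minimal usage sketch (the command and log destination are illustrative):
   * {{{
   *   val proc = new ProcessBuilder("ls", "-l").start()
   *   val t = Utils.processStreamByLine("read ls output", proc.getInputStream, line => logInfo(line))
   *   proc.waitFor()
   *   t.join()   // wait until all output has been processed
   * }}}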
1203 */ 1204 def processStreamByLine( 1205 threadName: String, 1206 inputStream: InputStream, 1207 processLine: String => Unit): Thread = { 1208 val t = new Thread(threadName) { 1209 override def run() { 1210 for (line <- Source.fromInputStream(inputStream).getLines()) { 1211 processLine(line) 1212 } 1213 } 1214 } 1215 t.setDaemon(true) 1216 t.start() 1217 t 1218 } 1219 1220 /** 1221 * Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the 1222 * default UncaughtExceptionHandler 1223 * 1224 * NOTE: This method is to be called by the spark-started JVM process. 1225 */ 1226 def tryOrExit(block: => Unit) { 1227 try { 1228 block 1229 } catch { 1230 case e: ControlThrowable => throw e 1231 case t: Throwable => SparkUncaughtExceptionHandler.uncaughtException(t) 1232 } 1233 } 1234 1235 /** 1236 * Execute a block of code that evaluates to Unit, stop SparkContext if there is any uncaught 1237 * exception 1238 * 1239 * NOTE: This method is to be called by the driver-side components to avoid stopping the 1240 * user-started JVM process completely; in contrast, tryOrExit is to be called in the 1241 * spark-started JVM process . 1242 */ 1243 def tryOrStopSparkContext(sc: SparkContext)(block: => Unit) { 1244 try { 1245 block 1246 } catch { 1247 case e: ControlThrowable => throw e 1248 case t: Throwable => 1249 val currentThreadName = Thread.currentThread().getName 1250 if (sc != null) { 1251 logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t) 1252 sc.stopInNewThread() 1253 } 1254 if (!NonFatal(t)) { 1255 logError(s"throw uncaught fatal error in thread $currentThreadName", t) 1256 throw t 1257 } 1258 } 1259 } 1260 1261 /** 1262 * Execute a block of code that returns a value, re-throwing any non-fatal uncaught 1263 * exceptions as IOException. This is used when implementing Externalizable and Serializable's 1264 * read and write methods, since Java's serializer will not report non-IOExceptions properly; 1265 * see SPARK-4080 for more context. 1266 */ 1267 def tryOrIOException[T](block: => T): T = { 1268 try { 1269 block 1270 } catch { 1271 case e: IOException => 1272 logError("Exception encountered", e) 1273 throw e 1274 case NonFatal(e) => 1275 logError("Exception encountered", e) 1276 throw new IOException(e) 1277 } 1278 } 1279 1280 /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */ 1281 def tryLogNonFatalError(block: => Unit) { 1282 try { 1283 block 1284 } catch { 1285 case NonFatal(t) => 1286 logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) 1287 } 1288 } 1289 1290 /** 1291 * Execute a block of code, then a finally block, but if exceptions happen in 1292 * the finally block, do not suppress the original exception. 1293 * 1294 * This is primarily an issue with `finally { out.close() }` blocks, where 1295 * close needs to be called to clean up `out`, but if an exception happened 1296 * in `out.write`, it's likely `out` may be corrupted and `out.close` will 1297 * fail as well. This would then suppress the original/likely more meaningful 1298 * exception from the original `out.write` call. 
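   *
   * A typical usage sketch (assuming `payload: Array[Byte]` and the path are illustrative):
   * {{{
   *   val out = new FileOutputStream("/tmp/output.bin")
   *   Utils.tryWithSafeFinally {
   *     out.write(payload)
   *   } {
   *     out.close()
   *   }
   * }}}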
1299 */ 1300 def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { 1301 var originalThrowable: Throwable = null 1302 try { 1303 block 1304 } catch { 1305 case t: Throwable => 1306 // Purposefully not using NonFatal, because even fatal exceptions 1307 // we don't want to have our finallyBlock suppress 1308 originalThrowable = t 1309 throw originalThrowable 1310 } finally { 1311 try { 1312 finallyBlock 1313 } catch { 1314 case t: Throwable => 1315 if (originalThrowable != null) { 1316 originalThrowable.addSuppressed(t) 1317 logWarning(s"Suppressing exception in finally: " + t.getMessage, t) 1318 throw originalThrowable 1319 } else { 1320 throw t 1321 } 1322 } 1323 } 1324 } 1325 1326 /** 1327 * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur 1328 * in either the catch or the finally block, they are appended to the list of suppressed 1329 * exceptions in original exception which is then rethrown. 1330 * 1331 * This is primarily an issue with `catch { abort() }` or `finally { out.close() }` blocks, 1332 * where the abort/close needs to be called to clean up `out`, but if an exception happened 1333 * in `out.write`, it's likely `out` may be corrupted and `abort` or `out.close` will 1334 * fail as well. This would then suppress the original/likely more meaningful 1335 * exception from the original `out.write` call. 1336 */ 1337 def tryWithSafeFinallyAndFailureCallbacks[T](block: => T) 1338 (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = { 1339 var originalThrowable: Throwable = null 1340 try { 1341 block 1342 } catch { 1343 case cause: Throwable => 1344 // Purposefully not using NonFatal, because even fatal exceptions 1345 // we don't want to have our finallyBlock suppress 1346 originalThrowable = cause 1347 try { 1348 logError("Aborting task", originalThrowable) 1349 TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable) 1350 catchBlock 1351 } catch { 1352 case t: Throwable => 1353 originalThrowable.addSuppressed(t) 1354 logWarning(s"Suppressing exception in catch: " + t.getMessage, t) 1355 } 1356 throw originalThrowable 1357 } finally { 1358 try { 1359 finallyBlock 1360 } catch { 1361 case t: Throwable => 1362 if (originalThrowable != null) { 1363 originalThrowable.addSuppressed(t) 1364 logWarning(s"Suppressing exception in finally: " + t.getMessage, t) 1365 throw originalThrowable 1366 } else { 1367 throw t 1368 } 1369 } 1370 } 1371 } 1372 1373 /** Default filtering function for finding call sites using `getCallSite`. */ 1374 private def sparkInternalExclusionFunction(className: String): Boolean = { 1375 // A regular expression to match classes of the internal Spark API's 1376 // that we want to skip when finding the call site of a method. 1377 val SPARK_CORE_CLASS_REGEX = 1378 """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r 1379 val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r 1380 val SCALA_CORE_CLASS_PREFIX = "scala" 1381 val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined || 1382 SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined 1383 val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX) 1384 // If the class is a Spark internal class or a Scala class, then exclude. 
1385 isSparkClass || isScalaClass 1386 } 1387 1388 /** 1389 * When called inside a class in the spark package, returns the name of the user code class 1390 * (outside the spark package) that called into Spark, as well as which Spark method they called. 1391 * This is used, for example, to tell users where in their code each RDD got created. 1392 * 1393 * @param skipClass Function that is used to exclude non-user-code classes. 1394 */ 1395 def getCallSite(skipClass: String => Boolean = sparkInternalExclusionFunction): CallSite = { 1396 // Keep crawling up the stack trace until we find the first function not inside of the spark 1397 // package. We track the last (shallowest) contiguous Spark method. This might be an RDD 1398 // transformation, a SparkContext function (such as parallelize), or anything else that leads 1399 // to instantiation of an RDD. We also track the first (deepest) user method, file, and line. 1400 var lastSparkMethod = "<unknown>" 1401 var firstUserFile = "<unknown>" 1402 var firstUserLine = 0 1403 var insideSpark = true 1404 var callStack = new ArrayBuffer[String]() :+ "<unknown>" 1405 1406 Thread.currentThread.getStackTrace().foreach { ste: StackTraceElement => 1407 // When running under some profilers, the current stack trace might contain some bogus 1408 // frames. This is intended to ensure that we don't crash in these situations by 1409 // ignoring any frames that we can't examine. 1410 if (ste != null && ste.getMethodName != null 1411 && !ste.getMethodName.contains("getStackTrace")) { 1412 if (insideSpark) { 1413 if (skipClass(ste.getClassName)) { 1414 lastSparkMethod = if (ste.getMethodName == "<init>") { 1415 // Spark method is a constructor; get its class name 1416 ste.getClassName.substring(ste.getClassName.lastIndexOf('.') + 1) 1417 } else { 1418 ste.getMethodName 1419 } 1420 callStack(0) = ste.toString // Put last Spark method on top of the stack trace. 1421 } else { 1422 if (ste.getFileName != null) { 1423 firstUserFile = ste.getFileName 1424 if (ste.getLineNumber >= 0) { 1425 firstUserLine = ste.getLineNumber 1426 } 1427 } 1428 callStack += ste.toString 1429 insideSpark = false 1430 } 1431 } else { 1432 callStack += ste.toString 1433 } 1434 } 1435 } 1436 1437 val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt 1438 val shortForm = 1439 if (firstUserFile == "HiveSessionImpl.java") { 1440 // To be more user friendly, show a nicer string for queries submitted from the JDBC 1441 // server. 
1442 "Spark JDBC Server Query" 1443 } else { 1444 s"$lastSparkMethod at $firstUserFile:$firstUserLine" 1445 } 1446 val longForm = callStack.take(callStackDepth).mkString("\n") 1447 1448 CallSite(shortForm, longForm) 1449 } 1450 1451 private val UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF = 1452 "spark.worker.ui.compressedLogFileLengthCacheSize" 1453 private val DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE = 100 1454 private var compressedLogFileLengthCache: LoadingCache[String, java.lang.Long] = null 1455 private def getCompressedLogFileLengthCache( 1456 sparkConf: SparkConf): LoadingCache[String, java.lang.Long] = this.synchronized { 1457 if (compressedLogFileLengthCache == null) { 1458 val compressedLogFileLengthCacheSize = sparkConf.getInt( 1459 UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF, 1460 DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE) 1461 compressedLogFileLengthCache = CacheBuilder.newBuilder() 1462 .maximumSize(compressedLogFileLengthCacheSize) 1463 .build[String, java.lang.Long](new CacheLoader[String, java.lang.Long]() { 1464 override def load(path: String): java.lang.Long = { 1465 Utils.getCompressedFileLength(new File(path)) 1466 } 1467 }) 1468 } 1469 compressedLogFileLengthCache 1470 } 1471 1472 /** 1473 * Return the file length, if the file is compressed it returns the uncompressed file length. 1474 * It also caches the uncompressed file size to avoid repeated decompression. The cache size is 1475 * read from workerConf. 1476 */ 1477 def getFileLength(file: File, workConf: SparkConf): Long = { 1478 if (file.getName.endsWith(".gz")) { 1479 getCompressedLogFileLengthCache(workConf).get(file.getAbsolutePath) 1480 } else { 1481 file.length 1482 } 1483 } 1484 1485 /** Return uncompressed file length of a compressed file. */ 1486 private def getCompressedFileLength(file: File): Long = { 1487 try { 1488 // Uncompress .gz file to determine file size. 1489 var fileSize = 0L 1490 val gzInputStream = new GZIPInputStream(new FileInputStream(file)) 1491 val bufSize = 1024 1492 val buf = new Array[Byte](bufSize) 1493 var numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize) 1494 while (numBytes > 0) { 1495 fileSize += numBytes 1496 numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize) 1497 } 1498 fileSize 1499 } catch { 1500 case e: Throwable => 1501 logError(s"Cannot get file length of ${file}", e) 1502 throw e 1503 } 1504 } 1505 1506 /** Return a string containing part of a file from byte 'start' to 'end'. */ 1507 def offsetBytes(path: String, length: Long, start: Long, end: Long): String = { 1508 val file = new File(path) 1509 val effectiveEnd = math.min(length, end) 1510 val effectiveStart = math.max(0, start) 1511 val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) 1512 val stream = if (path.endsWith(".gz")) { 1513 new GZIPInputStream(new FileInputStream(file)) 1514 } else { 1515 new FileInputStream(file) 1516 } 1517 1518 try { 1519 ByteStreams.skipFully(stream, effectiveStart) 1520 ByteStreams.readFully(stream, buff) 1521 } finally { 1522 stream.close() 1523 } 1524 Source.fromBytes(buff).mkString 1525 } 1526 1527 /** 1528 * Return a string containing data across a set of files. The `startIndex` 1529 * and `endIndex` is based on the cumulative size of all the files take in 1530 * the given order. See figure below for more details. 
1531 */ 1532 def offsetBytes(files: Seq[File], fileLengths: Seq[Long], start: Long, end: Long): String = { 1533 assert(files.length == fileLengths.length) 1534 val startIndex = math.max(start, 0) 1535 val endIndex = math.min(end, fileLengths.sum) 1536 val fileToLength = files.zip(fileLengths).toMap 1537 logDebug("Log files: \n" + fileToLength.mkString("\n")) 1538 1539 val stringBuffer = new StringBuffer((endIndex - startIndex).toInt) 1540 var sum = 0L 1541 files.zip(fileLengths).foreach { case (file, fileLength) => 1542 val startIndexOfFile = sum 1543 val endIndexOfFile = sum + fileToLength(file) 1544 logDebug(s"Processing file $file, " + 1545 s"with start index = $startIndexOfFile, end index = $endIndex") 1546 1547 /* 1548 ____________ 1549 range 1: | | 1550 | case A | 1551 1552 files: |==== file 1 ====|====== file 2 ======|===== file 3 =====| 1553 1554 | case B . case C . case D | 1555 range 2: |___________.____________________.______________| 1556 */ 1557 1558 if (startIndex <= startIndexOfFile && endIndex >= endIndexOfFile) { 1559 // Case C: read the whole file 1560 stringBuffer.append(offsetBytes(file.getAbsolutePath, fileLength, 0, fileToLength(file))) 1561 } else if (startIndex > startIndexOfFile && startIndex < endIndexOfFile) { 1562 // Case A and B: read from [start of required range] to [end of file / end of range] 1563 val effectiveStartIndex = startIndex - startIndexOfFile 1564 val effectiveEndIndex = math.min(endIndex - startIndexOfFile, fileToLength(file)) 1565 stringBuffer.append(Utils.offsetBytes( 1566 file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex)) 1567 } else if (endIndex > startIndexOfFile && endIndex < endIndexOfFile) { 1568 // Case D: read from [start of file] to [end of require range] 1569 val effectiveStartIndex = math.max(startIndex - startIndexOfFile, 0) 1570 val effectiveEndIndex = endIndex - startIndexOfFile 1571 stringBuffer.append(Utils.offsetBytes( 1572 file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex)) 1573 } 1574 sum += fileToLength(file) 1575 logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") 1576 } 1577 stringBuffer.toString 1578 } 1579 1580 /** 1581 * Clone an object using a Spark serializer. 1582 */ 1583 def clone[T: ClassTag](value: T, serializer: SerializerInstance): T = { 1584 serializer.deserialize[T](serializer.serialize(value)) 1585 } 1586 1587 private def isSpace(c: Char): Boolean = { 1588 " \t\r\n".indexOf(c) != -1 1589 } 1590 1591 /** 1592 * Split a string of potentially quoted arguments from the command line the way that a shell 1593 * would do it to determine arguments to a command. For example, if the string is 'a "b c" d', 1594 * then it would be parsed as three arguments: 'a', 'b c' and 'd'. 
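   *
   * For example (mirroring the description above):
   * {{{
   *   Utils.splitCommandString("""a "b c" d""")   // Seq("a", "b c", "d")
   * }}}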
   */
  def splitCommandString(s: String): Seq[String] = {
    val buf = new ArrayBuffer[String]
    var inWord = false
    var inSingleQuote = false
    var inDoubleQuote = false
    val curWord = new StringBuilder
    def endWord() {
      buf += curWord.toString
      curWord.clear()
    }
    var i = 0
    while (i < s.length) {
      val nextChar = s.charAt(i)
      if (inDoubleQuote) {
        if (nextChar == '"') {
          inDoubleQuote = false
        } else if (nextChar == '\\') {
          if (i < s.length - 1) {
            // Append the next character directly, because only " and \ may be escaped in
            // double quotes after the shell's own expansion
            curWord.append(s.charAt(i + 1))
            i += 1
          }
        } else {
          curWord.append(nextChar)
        }
      } else if (inSingleQuote) {
        if (nextChar == '\'') {
          inSingleQuote = false
        } else {
          curWord.append(nextChar)
        }
        // Backslashes are not treated specially in single quotes
      } else if (nextChar == '"') {
        inWord = true
        inDoubleQuote = true
      } else if (nextChar == '\'') {
        inWord = true
        inSingleQuote = true
      } else if (!isSpace(nextChar)) {
        curWord.append(nextChar)
        inWord = true
      } else if (inWord && isSpace(nextChar)) {
        endWord()
        inWord = false
      }
      i += 1
    }
    if (inWord || inDoubleQuote || inSingleQuote) {
      endWord()
    }
    buf
  }

  /* Calculates 'x' modulo 'mod', taking the sign of x into consideration,
   * i.e. if 'x' is negative then 'x' % 'mod' is negative too,
   * so in that case the function returns (x % mod) + mod.
   */
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  // Handles idiosyncrasies with hash (add more as required)
  // This method should be kept in sync with
  // org.apache.spark.network.util.JavaUtils#nonNegativeHash().
  def nonNegativeHash(obj: AnyRef): Int = {

    // Required ?
    if (obj eq null) return 0

    val hash = obj.hashCode
    // math.abs fails for Int.MinValue
    val hashAbs = if (Int.MinValue != hash) math.abs(hash) else 0

    // Nothing else to guard against ?
    hashAbs
  }

  /**
   * NaN-safe version of `java.lang.Double.compare()` which allows NaN values to be compared
   * according to semantics where NaN == NaN and NaN is greater than any non-NaN double.
   */
  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
    val xIsNan: Boolean = java.lang.Double.isNaN(x)
    val yIsNan: Boolean = java.lang.Double.isNaN(y)
    if ((xIsNan && yIsNan) || (x == y)) 0
    else if (xIsNan) 1
    else if (yIsNan) -1
    else if (x > y) 1
    else -1
  }

  /**
   * NaN-safe version of `java.lang.Float.compare()` which allows NaN values to be compared
   * according to semantics where NaN == NaN and NaN is greater than any non-NaN float.
   */
  def nanSafeCompareFloats(x: Float, y: Float): Int = {
    val xIsNan: Boolean = java.lang.Float.isNaN(x)
    val yIsNan: Boolean = java.lang.Float.isNaN(y)
    if ((xIsNan && yIsNan) || (x == y)) 0
    else if (xIsNan) 1
    else if (yIsNan) -1
    else if (x > y) 1
    else -1
  }

  /**
   * Returns the system properties map that is thread-safe to iterate over. It gets the
   * properties which have been set explicitly, as well as those for which only a default value
   * has been defined.
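   *
   * For example (illustrative), the returned map can be iterated safely even if another
   * thread mutates the JVM system properties concurrently:
   * {{{
   *   Utils.getSystemProperties.foreach { case (k, v) => println(s"$k=$v") }
   * }}}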
1707 */ 1708 def getSystemProperties: Map[String, String] = { 1709 System.getProperties.stringPropertyNames().asScala 1710 .map(key => (key, System.getProperty(key))).toMap 1711 } 1712 1713 /** 1714 * Method executed for repeating a task for side effects. 1715 * Unlike a for comprehension, it permits JVM JIT optimization 1716 */ 1717 def times(numIters: Int)(f: => Unit): Unit = { 1718 var i = 0 1719 while (i < numIters) { 1720 f 1721 i += 1 1722 } 1723 } 1724 1725 /** 1726 * Timing method based on iterations that permit JVM JIT optimization. 1727 * 1728 * @param numIters number of iterations 1729 * @param f function to be executed. If prepare is not None, the running time of each call to f 1730 * must be an order of magnitude longer than one millisecond for accurate timing. 1731 * @param prepare function to be executed before each call to f. Its running time doesn't count. 1732 * @return the total time across all iterations (not counting preparation time) 1733 */ 1734 def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = { 1735 if (prepare.isEmpty) { 1736 val start = System.currentTimeMillis 1737 times(numIters)(f) 1738 System.currentTimeMillis - start 1739 } else { 1740 var i = 0 1741 var sum = 0L 1742 while (i < numIters) { 1743 prepare.get.apply() 1744 val start = System.currentTimeMillis 1745 f 1746 sum += System.currentTimeMillis - start 1747 i += 1 1748 } 1749 sum 1750 } 1751 } 1752 1753 /** 1754 * Counts the number of elements of an iterator using a while loop rather than calling 1755 * [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower 1756 * in the current version of Scala. 1757 */ 1758 def getIteratorSize[T](iterator: Iterator[T]): Long = { 1759 var count = 0L 1760 while (iterator.hasNext) { 1761 count += 1L 1762 iterator.next() 1763 } 1764 count 1765 } 1766 1767 /** 1768 * Generate a zipWithIndex iterator, avoid index value overflowing problem 1769 * in scala's zipWithIndex 1770 */ 1771 def getIteratorZipWithIndex[T](iterator: Iterator[T], startIndex: Long): Iterator[(T, Long)] = { 1772 new Iterator[(T, Long)] { 1773 require(startIndex >= 0, "startIndex should be >= 0.") 1774 var index: Long = startIndex - 1L 1775 def hasNext: Boolean = iterator.hasNext 1776 def next(): (T, Long) = { 1777 index += 1L 1778 (iterator.next(), index) 1779 } 1780 } 1781 } 1782 1783 /** 1784 * Creates a symlink. 1785 * 1786 * @param src absolute path to the source 1787 * @param dst relative path for the destination 1788 */ 1789 def symlink(src: File, dst: File): Unit = { 1790 if (!src.isAbsolute()) { 1791 throw new IOException("Source must be absolute") 1792 } 1793 if (dst.isAbsolute()) { 1794 throw new IOException("Destination must be relative") 1795 } 1796 Files.createSymbolicLink(dst.toPath, src.toPath) 1797 } 1798 1799 1800 /** Return the class name of the given object, removing all dollar signs */ 1801 def getFormattedClassName(obj: AnyRef): String = { 1802 obj.getClass.getSimpleName.replace("$", "") 1803 } 1804 1805 /** Return an option that translates JNothing to None */ 1806 def jsonOption(json: JValue): Option[JValue] = { 1807 json match { 1808 case JNothing => None 1809 case value: JValue => Some(value) 1810 } 1811 } 1812 1813 /** Return an empty JSON object */ 1814 def emptyJson: JsonAST.JObject = JObject(List[JField]()) 1815 1816 /** 1817 * Return a Hadoop FileSystem with the scheme encoded in the given path. 
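   *
   * For example (illustrative; the HDFS URI and the Hadoop `Configuration` value `hadoopConf`
   * are assumptions):
   * {{{
   *   val fs = Utils.getHadoopFileSystem(new URI("hdfs://namenode:8020/user/spark"), hadoopConf)
   * }}}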
1818 */ 1819 def getHadoopFileSystem(path: URI, conf: Configuration): FileSystem = { 1820 FileSystem.get(path, conf) 1821 } 1822 1823 /** 1824 * Return a Hadoop FileSystem with the scheme encoded in the given path. 1825 */ 1826 def getHadoopFileSystem(path: String, conf: Configuration): FileSystem = { 1827 getHadoopFileSystem(new URI(path), conf) 1828 } 1829 1830 /** 1831 * Return the absolute path of a file in the given directory. 1832 */ 1833 def getFilePath(dir: File, fileName: String): Path = { 1834 assert(dir.isDirectory) 1835 val path = new File(dir, fileName).getAbsolutePath 1836 new Path(path) 1837 } 1838 1839 /** 1840 * Whether the underlying operating system is Windows. 1841 */ 1842 val isWindows = SystemUtils.IS_OS_WINDOWS 1843 1844 /** 1845 * Whether the underlying operating system is Mac OS X. 1846 */ 1847 val isMac = SystemUtils.IS_OS_MAC_OSX 1848 1849 /** 1850 * Pattern for matching a Windows drive, which contains only a single alphabet character. 1851 */ 1852 val windowsDrive = "([a-zA-Z])".r 1853 1854 /** 1855 * Indicates whether Spark is currently running unit tests. 1856 */ 1857 def isTesting: Boolean = { 1858 sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") 1859 } 1860 1861 /** 1862 * Strip the directory from a path name 1863 */ 1864 def stripDirectory(path: String): String = { 1865 new File(path).getName 1866 } 1867 1868 /** 1869 * Terminates a process waiting for at most the specified duration. 1870 * 1871 * @return the process exit value if it was successfully terminated, else None 1872 */ 1873 def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = { 1874 // Politely destroy first 1875 process.destroy() 1876 1877 if (waitForProcess(process, timeoutMs)) { 1878 // Successful exit 1879 Option(process.exitValue()) 1880 } else { 1881 // Java 8 added a new API which will more forcibly kill the process. Use that if available. 1882 try { 1883 classOf[Process].getMethod("destroyForcibly").invoke(process) 1884 } catch { 1885 case _: NoSuchMethodException => return None // Not available; give up 1886 case NonFatal(e) => logWarning("Exception when attempting to kill process", e) 1887 } 1888 // Wait, again, although this really should return almost immediately 1889 if (waitForProcess(process, timeoutMs)) { 1890 Option(process.exitValue()) 1891 } else { 1892 logWarning("Timed out waiting to forcibly kill process") 1893 None 1894 } 1895 } 1896 } 1897 1898 /** 1899 * Wait for a process to terminate for at most the specified duration. 1900 * 1901 * @return whether the process actually terminated before the given timeout. 1902 */ 1903 def waitForProcess(process: Process, timeoutMs: Long): Boolean = { 1904 try { 1905 // Use Java 8 method if available 1906 classOf[Process].getMethod("waitFor", java.lang.Long.TYPE, classOf[TimeUnit]) 1907 .invoke(process, timeoutMs.asInstanceOf[java.lang.Long], TimeUnit.MILLISECONDS) 1908 .asInstanceOf[Boolean] 1909 } catch { 1910 case _: NoSuchMethodException => 1911 // Otherwise implement it manually 1912 var terminated = false 1913 val startTime = System.currentTimeMillis 1914 while (!terminated) { 1915 try { 1916 process.exitValue() 1917 terminated = true 1918 } catch { 1919 case e: IllegalThreadStateException => 1920 // Process not terminated yet 1921 if (System.currentTimeMillis - startTime > timeoutMs) { 1922 return false 1923 } 1924 Thread.sleep(100) 1925 } 1926 } 1927 true 1928 } 1929 } 1930 1931 /** 1932 * Return the stderr of a process after waiting for the process to terminate. 
1933 * If the process does not terminate within the specified timeout, return None. 1934 */ 1935 def getStderr(process: Process, timeoutMs: Long): Option[String] = { 1936 val terminated = Utils.waitForProcess(process, timeoutMs) 1937 if (terminated) { 1938 Some(Source.fromInputStream(process.getErrorStream).getLines().mkString("\n")) 1939 } else { 1940 None 1941 } 1942 } 1943 1944 /** 1945 * Execute the given block, logging and re-throwing any uncaught exception. 1946 * This is particularly useful for wrapping code that runs in a thread, to ensure 1947 * that exceptions are printed, and to avoid having to catch Throwable. 1948 */ 1949 def logUncaughtExceptions[T](f: => T): T = { 1950 try { 1951 f 1952 } catch { 1953 case ct: ControlThrowable => 1954 throw ct 1955 case t: Throwable => 1956 logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) 1957 throw t 1958 } 1959 } 1960 1961 /** Executes the given block in a Try, logging any uncaught exceptions. */ 1962 def tryLog[T](f: => T): Try[T] = { 1963 try { 1964 val res = f 1965 scala.util.Success(res) 1966 } catch { 1967 case ct: ControlThrowable => 1968 throw ct 1969 case t: Throwable => 1970 logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) 1971 scala.util.Failure(t) 1972 } 1973 } 1974 1975 /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */ 1976 def isFatalError(e: Throwable): Boolean = { 1977 e match { 1978 case NonFatal(_) | 1979 _: InterruptedException | 1980 _: NotImplementedError | 1981 _: ControlThrowable | 1982 _: LinkageError => 1983 false 1984 case _ => 1985 true 1986 } 1987 } 1988 1989 /** 1990 * Return a well-formed URI for the file described by a user input string. 1991 * 1992 * If the supplied path does not contain a scheme, or is a relative path, it will be 1993 * converted into an absolute path with a file:// scheme. 1994 */ 1995 def resolveURI(path: String): URI = { 1996 try { 1997 val uri = new URI(path) 1998 if (uri.getScheme() != null) { 1999 return uri 2000 } 2001 // make sure to handle if the path has a fragment (applies to yarn 2002 // distributed cache) 2003 if (uri.getFragment() != null) { 2004 val absoluteURI = new File(uri.getPath()).getAbsoluteFile().toURI() 2005 return new URI(absoluteURI.getScheme(), absoluteURI.getHost(), absoluteURI.getPath(), 2006 uri.getFragment()) 2007 } 2008 } catch { 2009 case e: URISyntaxException => 2010 } 2011 new File(path).getAbsoluteFile().toURI() 2012 } 2013 2014 /** Resolve a comma-separated list of paths. */ 2015 def resolveURIs(paths: String): String = { 2016 if (paths == null || paths.trim.isEmpty) { 2017 "" 2018 } else { 2019 paths.split(",").filter(_.trim.nonEmpty).map { p => Utils.resolveURI(p) }.mkString(",") 2020 } 2021 } 2022 2023 /** Return all non-local paths from a comma-separated list of paths. */ 2024 def nonLocalPaths(paths: String, testWindows: Boolean = false): Array[String] = { 2025 val windows = isWindows || testWindows 2026 if (paths == null || paths.trim.isEmpty) { 2027 Array.empty 2028 } else { 2029 paths.split(",").filter { p => 2030 val uri = resolveURI(p) 2031 Option(uri.getScheme).getOrElse("file") match { 2032 case windowsDrive(d) if windows => false 2033 case "local" | "file" => false 2034 case _ => true 2035 } 2036 } 2037 } 2038 } 2039 2040 /** 2041 * Load default Spark properties from the given file. If no file is provided, 2042 * use the common defaults file. 
This mutates state in the given SparkConf and
   * in this JVM's system properties if the config specified in the file is not
   * already set. Return the path of the properties file used.
   */
  def loadDefaultSparkProperties(conf: SparkConf, filePath: String = null): String = {
    val path = Option(filePath).getOrElse(getDefaultPropertiesFile())
    Option(path).foreach { confFile =>
      getPropertiesFromFile(confFile).filter { case (k, v) =>
        k.startsWith("spark.")
      }.foreach { case (k, v) =>
        conf.setIfMissing(k, v)
        sys.props.getOrElseUpdate(k, v)
      }
    }
    path
  }

  /** Load properties present in the given file. */
  def getPropertiesFromFile(filename: String): Map[String, String] = {
    val file = new File(filename)
    require(file.exists(), s"Properties file $file does not exist")
    require(file.isFile(), s"Properties file $file is not a normal file")

    val inReader = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8)
    try {
      val properties = new Properties()
      properties.load(inReader)
      properties.stringPropertyNames().asScala.map(
        k => (k, properties.getProperty(k).trim)).toMap
    } catch {
      case e: IOException =>
        throw new SparkException(s"Failed when loading Spark properties from $filename", e)
    } finally {
      inReader.close()
    }
  }

  /** Return the path of the default Spark properties file. */
  def getDefaultPropertiesFile(env: Map[String, String] = sys.env): String = {
    env.get("SPARK_CONF_DIR")
      .orElse(env.get("SPARK_HOME").map { t => s"$t${File.separator}conf" })
      .map { t => new File(s"$t${File.separator}spark-defaults.conf") }
      .filter(_.isFile)
      .map(_.getAbsolutePath)
      .orNull
  }

  /**
   * Return a nice string representation of the exception. It will call "printStackTrace" to
   * recursively generate the stack trace including the exception and its causes.
   */
  def exceptionString(e: Throwable): String = {
    if (e == null) {
      ""
    } else {
      // Use e.printStackTrace here because e.getStackTrace doesn't include the cause
      val stringWriter = new StringWriter()
      e.printStackTrace(new PrintWriter(stringWriter))
      stringWriter.toString
    }
  }

  private implicit class Lock(lock: LockInfo) {
    def lockString: String = {
      lock match {
        case monitor: MonitorInfo =>
          s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode})"
        case _ =>
          s"Lock(${lock.getClassName}@${lock.getIdentityHashCode})"
      }
    }
  }

  /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */
  def getThreadDump(): Array[ThreadStackTrace] = {
    // We need to filter out null values here because dumpAllThreads() may return null array
    // elements for threads that are dead / don't exist.
2119 val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) 2120 threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) 2121 } 2122 2123 def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { 2124 if (threadId <= 0) { 2125 None 2126 } else { 2127 // The Int.MaxValue here requests the entire untruncated stack trace of the thread: 2128 val threadInfo = 2129 Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue)) 2130 threadInfo.map(threadInfoToThreadStackTrace) 2131 } 2132 } 2133 2134 private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { 2135 val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap 2136 val stackTrace = threadInfo.getStackTrace.map { frame => 2137 monitors.get(frame) match { 2138 case Some(monitor) => 2139 monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" 2140 case None => 2141 frame.toString 2142 } 2143 }.mkString("\n") 2144 2145 // use a set to dedup re-entrant locks that are held at multiple places 2146 val heldLocks = 2147 (threadInfo.getLockedSynchronizers ++ threadInfo.getLockedMonitors).map(_.lockString).toSet 2148 2149 ThreadStackTrace( 2150 threadId = threadInfo.getThreadId, 2151 threadName = threadInfo.getThreadName, 2152 threadState = threadInfo.getThreadState, 2153 stackTrace = stackTrace, 2154 blockedByThreadId = 2155 if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), 2156 blockedByLock = Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), 2157 holdingLocks = heldLocks.toSeq) 2158 } 2159 2160 /** 2161 * Convert all spark properties set in the given SparkConf to a sequence of java options. 2162 */ 2163 def sparkJavaOpts(conf: SparkConf, filterKey: (String => Boolean) = _ => true): Seq[String] = { 2164 conf.getAll 2165 .filter { case (k, _) => filterKey(k) } 2166 .map { case (k, v) => s"-D$k=$v" } 2167 } 2168 2169 /** 2170 * Maximum number of retries when binding to a port before giving up. 2171 */ 2172 def portMaxRetries(conf: SparkConf): Int = { 2173 val maxRetries = conf.getOption("spark.port.maxRetries").map(_.toInt) 2174 if (conf.contains("spark.testing")) { 2175 // Set a higher number of retries for tests... 2176 maxRetries.getOrElse(100) 2177 } else { 2178 maxRetries.getOrElse(16) 2179 } 2180 } 2181 2182 /** 2183 * Attempt to start a service on the given port, or fail after a number of attempts. 2184 * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0). 2185 * 2186 * @param startPort The initial port to start the service on. 2187 * @param startService Function to start service on a given port. 2188 * This is expected to throw java.net.BindException on port collision. 2189 * @param conf A SparkConf used to get the maximum number of retries when binding to a port. 2190 * @param serviceName Name of the service. 
2191 * @return (service: T, port: Int) 2192 */ 2193 def startServiceOnPort[T]( 2194 startPort: Int, 2195 startService: Int => (T, Int), 2196 conf: SparkConf, 2197 serviceName: String = ""): (T, Int) = { 2198 2199 require(startPort == 0 || (1024 <= startPort && startPort < 65536), 2200 "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.") 2201 2202 val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" 2203 val maxRetries = portMaxRetries(conf) 2204 for (offset <- 0 to maxRetries) { 2205 // Do not increment port if startPort is 0, which is treated as a special port 2206 val tryPort = if (startPort == 0) { 2207 startPort 2208 } else { 2209 // If the new port wraps around, do not try a privilege port 2210 ((startPort + offset - 1024) % (65536 - 1024)) + 1024 2211 } 2212 try { 2213 val (service, port) = startService(tryPort) 2214 logInfo(s"Successfully started service$serviceString on port $port.") 2215 return (service, port) 2216 } catch { 2217 case e: Exception if isBindCollision(e) => 2218 if (offset >= maxRetries) { 2219 val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " + 2220 s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " + 2221 s"the appropriate port for the service$serviceString (for example spark.ui.port " + 2222 s"for SparkUI) to an available port or increasing spark.port.maxRetries." 2223 val exception = new BindException(exceptionMessage) 2224 // restore original stack trace 2225 exception.setStackTrace(e.getStackTrace) 2226 throw exception 2227 } 2228 logWarning(s"Service$serviceString could not bind on port $tryPort. " + 2229 s"Attempting port ${tryPort + 1}.") 2230 } 2231 } 2232 // Should never happen 2233 throw new SparkException(s"Failed to start service$serviceString on port $startPort") 2234 } 2235 2236 /** 2237 * Return whether the exception is caused by an address-port collision when binding. 2238 */ 2239 def isBindCollision(exception: Throwable): Boolean = { 2240 exception match { 2241 case e: BindException => 2242 if (e.getMessage != null) { 2243 return true 2244 } 2245 isBindCollision(e.getCause) 2246 case e: MultiException => 2247 e.getThrowables.asScala.exists(isBindCollision) 2248 case e: NativeIoException => 2249 (e.getMessage != null && e.getMessage.startsWith("bind() failed: ")) || 2250 isBindCollision(e.getCause) 2251 case e: Exception => isBindCollision(e.getCause) 2252 case _ => false 2253 } 2254 } 2255 2256 /** 2257 * configure a new log4j level 2258 */ 2259 def setLogLevel(l: org.apache.log4j.Level) { 2260 org.apache.log4j.Logger.getRootLogger().setLevel(l) 2261 } 2262 2263 /** 2264 * config a log4j properties used for testsuite 2265 */ 2266 def configTestLog4j(level: String): Unit = { 2267 val pro = new Properties() 2268 pro.put("log4j.rootLogger", s"$level, console") 2269 pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") 2270 pro.put("log4j.appender.console.target", "System.err") 2271 pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") 2272 pro.put("log4j.appender.console.layout.ConversionPattern", 2273 "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") 2274 PropertyConfigurator.configure(pro) 2275 } 2276 2277 /** 2278 * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and 2279 * the host verifier from the given security manager. 
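   *
   * A minimal sketch (the URL and the `SecurityManager` instance `securityMgr` are assumptions):
   * {{{
   *   val conn = Utils.setupSecureURLConnection(
   *     new URL("https://example.com/file").openConnection(), securityMgr)
   * }}}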
2280 */ 2281 def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = { 2282 urlConnection match { 2283 case https: HttpsURLConnection => 2284 sm.sslSocketFactory.foreach(https.setSSLSocketFactory) 2285 sm.hostnameVerifier.foreach(https.setHostnameVerifier) 2286 https 2287 case connection => connection 2288 } 2289 } 2290 2291 def invoke( 2292 clazz: Class[_], 2293 obj: AnyRef, 2294 methodName: String, 2295 args: (Class[_], AnyRef)*): AnyRef = { 2296 val (types, values) = args.unzip 2297 val method = clazz.getDeclaredMethod(methodName, types: _*) 2298 method.setAccessible(true) 2299 method.invoke(obj, values.toSeq: _*) 2300 } 2301 2302 // Limit of bytes for total size of results (default is 1GB) 2303 def getMaxResultSize(conf: SparkConf): Long = { 2304 memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 2305 } 2306 2307 /** 2308 * Return the current system LD_LIBRARY_PATH name 2309 */ 2310 def libraryPathEnvName: String = { 2311 if (isWindows) { 2312 "PATH" 2313 } else if (isMac) { 2314 "DYLD_LIBRARY_PATH" 2315 } else { 2316 "LD_LIBRARY_PATH" 2317 } 2318 } 2319 2320 /** 2321 * Return the prefix of a command that appends the given library paths to the 2322 * system-specific library path environment variable. On Unix, for instance, 2323 * this returns the string LD_LIBRARY_PATH="path1:path2:$LD_LIBRARY_PATH". 2324 */ 2325 def libraryPathEnvPrefix(libraryPaths: Seq[String]): String = { 2326 val libraryPathScriptVar = if (isWindows) { 2327 s"%${libraryPathEnvName}%" 2328 } else { 2329 "$" + libraryPathEnvName 2330 } 2331 val libraryPath = (libraryPaths :+ libraryPathScriptVar).mkString("\"", 2332 File.pathSeparator, "\"") 2333 val ampersand = if (Utils.isWindows) { 2334 " &" 2335 } else { 2336 "" 2337 } 2338 s"$libraryPathEnvName=$libraryPath$ampersand" 2339 } 2340 2341 /** 2342 * Return the value of a config either through the SparkConf or the Hadoop configuration 2343 * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf 2344 * if the key is not set in the Hadoop configuration. 2345 */ 2346 def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { 2347 val sparkValue = conf.get(key, default) 2348 if (SparkHadoopUtil.get.isYarnMode) { 2349 SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue) 2350 } else { 2351 sparkValue 2352 } 2353 } 2354 2355 /** 2356 * Return a pair of host and port extracted from the `sparkUrl`. 2357 * 2358 * A spark url (`spark://host:port`) is a special URI that its scheme is `spark` and only contains 2359 * host and port. 2360 * 2361 * @throws org.apache.spark.SparkException if sparkUrl is invalid. 2362 */ 2363 @throws(classOf[SparkException]) 2364 def extractHostPortFromSparkUrl(sparkUrl: String): (String, Int) = { 2365 try { 2366 val uri = new java.net.URI(sparkUrl) 2367 val host = uri.getHost 2368 val port = uri.getPort 2369 if (uri.getScheme != "spark" || 2370 host == null || 2371 port < 0 || 2372 (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null 2373 uri.getFragment != null || 2374 uri.getQuery != null || 2375 uri.getUserInfo != null) { 2376 throw new SparkException("Invalid master URL: " + sparkUrl) 2377 } 2378 (host, port) 2379 } catch { 2380 case e: java.net.URISyntaxException => 2381 throw new SparkException("Invalid master URL: " + sparkUrl, e) 2382 } 2383 } 2384 2385 /** 2386 * Returns the current user name. 
This is the currently logged in user, unless that's been 2387 * overridden by the `SPARK_USER` environment variable. 2388 */ 2389 def getCurrentUserName(): String = { 2390 Option(System.getenv("SPARK_USER")) 2391 .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) 2392 } 2393 2394 val EMPTY_USER_GROUPS = Set[String]() 2395 2396 // Returns the groups to which the current user belongs. 2397 def getCurrentUserGroups(sparkConf: SparkConf, username: String): Set[String] = { 2398 val groupProviderClassName = sparkConf.get("spark.user.groups.mapping", 2399 "org.apache.spark.security.ShellBasedGroupsMappingProvider") 2400 if (groupProviderClassName != "") { 2401 try { 2402 val groupMappingServiceProvider = classForName(groupProviderClassName).newInstance. 2403 asInstanceOf[org.apache.spark.security.GroupMappingServiceProvider] 2404 val currentUserGroups = groupMappingServiceProvider.getGroups(username) 2405 return currentUserGroups 2406 } catch { 2407 case e: Exception => logError(s"Error getting groups for user=$username", e) 2408 } 2409 } 2410 EMPTY_USER_GROUPS 2411 } 2412 2413 /** 2414 * Split the comma delimited string of master URLs into a list. 2415 * For instance, "spark://abc,def" becomes [spark://abc, spark://def]. 2416 */ 2417 def parseStandaloneMasterUrls(masterUrls: String): Array[String] = { 2418 masterUrls.stripPrefix("spark://").split(",").map("spark://" + _) 2419 } 2420 2421 /** An identifier that backup masters use in their responses. */ 2422 val BACKUP_STANDALONE_MASTER_PREFIX = "Current state is not alive" 2423 2424 /** Return true if the response message is sent from a backup Master on standby. */ 2425 def responseFromBackup(msg: String): Boolean = { 2426 msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) 2427 } 2428 2429 /** 2430 * To avoid calling `Utils.getCallSite` for every single RDD we create in the body, 2431 * set a dummy call site that RDDs use instead. This is for performance optimization. 2432 */ 2433 def withDummyCallSite[T](sc: SparkContext)(body: => T): T = { 2434 val oldShortCallSite = sc.getLocalProperty(CallSite.SHORT_FORM) 2435 val oldLongCallSite = sc.getLocalProperty(CallSite.LONG_FORM) 2436 try { 2437 sc.setLocalProperty(CallSite.SHORT_FORM, "") 2438 sc.setLocalProperty(CallSite.LONG_FORM, "") 2439 body 2440 } finally { 2441 // Restore the old ones here 2442 sc.setLocalProperty(CallSite.SHORT_FORM, oldShortCallSite) 2443 sc.setLocalProperty(CallSite.LONG_FORM, oldLongCallSite) 2444 } 2445 } 2446 2447 /** 2448 * Return whether the specified file is a parent directory of the child file. 2449 */ 2450 @tailrec 2451 def isInDirectory(parent: File, child: File): Boolean = { 2452 if (child == null || parent == null) { 2453 return false 2454 } 2455 if (!child.exists() || !parent.exists() || !parent.isDirectory()) { 2456 return false 2457 } 2458 if (parent.equals(child)) { 2459 return true 2460 } 2461 isInDirectory(parent, child.getParentFile) 2462 } 2463 2464 2465 /** 2466 * 2467 * @return whether it is local mode 2468 */ 2469 def isLocalMaster(conf: SparkConf): Boolean = { 2470 val master = conf.get("spark.master", "") 2471 master == "local" || master.startsWith("local[") 2472 } 2473 2474 /** 2475 * Return whether dynamic allocation is enabled in the given conf. 
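   *
   * For example (illustrative):
   * {{{
   *   val conf = new SparkConf()
   *     .set("spark.master", "yarn")
   *     .set("spark.dynamicAllocation.enabled", "true")
   *   Utils.isDynamicAllocationEnabled(conf)   // true, since the master is not local
   * }}}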
2476 */ 2477 def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { 2478 val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) 2479 dynamicAllocationEnabled && 2480 (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false)) 2481 } 2482 2483 /** 2484 * Return the initial number of executors for dynamic allocation. 2485 */ 2486 def getDynamicAllocationInitialExecutors(conf: SparkConf): Int = { 2487 if (conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS) < conf.get(DYN_ALLOCATION_MIN_EXECUTORS)) { 2488 logWarning(s"${DYN_ALLOCATION_INITIAL_EXECUTORS.key} less than " + 2489 s"${DYN_ALLOCATION_MIN_EXECUTORS.key} is invalid, ignoring its setting, " + 2490 "please update your configs.") 2491 } 2492 2493 if (conf.get(EXECUTOR_INSTANCES).getOrElse(0) < conf.get(DYN_ALLOCATION_MIN_EXECUTORS)) { 2494 logWarning(s"${EXECUTOR_INSTANCES.key} less than " + 2495 s"${DYN_ALLOCATION_MIN_EXECUTORS.key} is invalid, ignoring its setting, " + 2496 "please update your configs.") 2497 } 2498 2499 val initialExecutors = Seq( 2500 conf.get(DYN_ALLOCATION_MIN_EXECUTORS), 2501 conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS), 2502 conf.get(EXECUTOR_INSTANCES).getOrElse(0)).max 2503 2504 logInfo(s"Using initial executors = $initialExecutors, max of " + 2505 s"${DYN_ALLOCATION_INITIAL_EXECUTORS.key}, ${DYN_ALLOCATION_MIN_EXECUTORS.key} and " + 2506 s"${EXECUTOR_INSTANCES.key}") 2507 initialExecutors 2508 } 2509 2510 def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { 2511 val resource = createResource 2512 try f.apply(resource) finally resource.close() 2513 } 2514 2515 /** 2516 * Returns a path of temporary file which is in the same directory with `path`. 2517 */ 2518 def tempFileWith(path: File): File = { 2519 new File(path.getAbsolutePath + "." + UUID.randomUUID()) 2520 } 2521 2522 /** 2523 * Returns the name of this JVM process. This is OS dependent but typically (OSX, Linux, Windows), 2524 * this is formatted as PID@hostname. 2525 */ 2526 def getProcessName(): String = { 2527 ManagementFactory.getRuntimeMXBean().getName() 2528 } 2529 2530 /** 2531 * Utility function that should be called early in `main()` for daemons to set up some common 2532 * diagnostic state. 2533 */ 2534 def initDaemon(log: Logger): Unit = { 2535 log.info(s"Started daemon with process name: ${Utils.getProcessName()}") 2536 SignalUtils.registerLogger(log) 2537 } 2538 2539 /** 2540 * Unions two comma-separated lists of files and filters out empty strings. 2541 */ 2542 def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = { 2543 var allFiles = Set[String]() 2544 leftList.foreach { value => allFiles ++= value.split(",") } 2545 rightList.foreach { value => allFiles ++= value.split(",") } 2546 allFiles.filter { _.nonEmpty } 2547 } 2548 2549 /** 2550 * In YARN mode this method returns a union of the jar files pointed by "spark.jars" and the 2551 * "spark.yarn.dist.jars" properties, while in other modes it returns the jar files pointed by 2552 * only the "spark.jars" property. 
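   *
   * For example (illustrative), outside of YARN only "spark.jars" is consulted and empty
   * entries are dropped:
   * {{{
   *   val conf = new SparkConf().set("spark.master", "local[*]").set("spark.jars", "a.jar,,b.jar")
   *   Utils.getUserJars(conf)   // Seq("a.jar", "b.jar")
   * }}}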
   */
  def getUserJars(conf: SparkConf, isShell: Boolean = false): Seq[String] = {
    val sparkJars = conf.getOption("spark.jars")
    if (conf.get("spark.master") == "yarn" && isShell) {
      val yarnJars = conf.getOption("spark.yarn.dist.jars")
      unionFileLists(sparkJars, yarnJars).toSeq
    } else {
      sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten
    }
  }
}

private[util] object CallerContext extends Logging {
  val callerContextSupported: Boolean = {
    SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
      try {
        Utils.classForName("org.apache.hadoop.ipc.CallerContext")
        Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
        true
      } catch {
        case _: ClassNotFoundException =>
          false
        case NonFatal(e) =>
          logWarning("Failed to load the CallerContext class", e)
          false
      }
    }
  }
}

/**
 * A utility class used to set up Spark caller contexts to HDFS and Yarn. The `context` will be
 * constructed from the parameters passed in.
 * When Spark applications run on Yarn and HDFS, their caller contexts will be written into the
 * Yarn RM audit log and hdfs-audit.log. That can help users better diagnose and understand how
 * specific applications are impacting parts of the Hadoop system and what problems they may be
 * creating (e.g. overloading the NN). As HDFS mentioned in HDFS-9184, for a given HDFS operation,
 * it's very helpful to track which upper-level job issues it.
 *
 * @param from who sets up the caller context (TASK, CLIENT, APPMASTER)
 *
 * The parameters below are optional:
 * @param appId id of the app this task belongs to
 * @param appAttemptId attempt id of the app this task belongs to
 * @param jobId id of the job this task belongs to
 * @param stageId id of the stage this task belongs to
 * @param stageAttemptId attempt id of the stage this task belongs to
 * @param taskId task id
 * @param taskAttemptNumber task attempt id
 */
private[spark] class CallerContext(
    from: String,
    appId: Option[String] = None,
    appAttemptId: Option[String] = None,
    jobId: Option[Int] = None,
    stageId: Option[Int] = None,
    stageAttemptId: Option[Int] = None,
    taskId: Option[Long] = None,
    taskAttemptNumber: Option[Int] = None) extends Logging {

  val appIdStr = if (appId.isDefined) s"_${appId.get}" else ""
  val appAttemptIdStr = if (appAttemptId.isDefined) s"_${appAttemptId.get}" else ""
  val jobIdStr = if (jobId.isDefined) s"_JId_${jobId.get}" else ""
  val stageIdStr = if (stageId.isDefined) s"_SId_${stageId.get}" else ""
  val stageAttemptIdStr = if (stageAttemptId.isDefined) s"_${stageAttemptId.get}" else ""
  val taskIdStr = if (taskId.isDefined) s"_TId_${taskId.get}" else ""
  val taskAttemptNumberStr =
    if (taskAttemptNumber.isDefined) s"_${taskAttemptNumber.get}" else ""

  val context = "SPARK_" + from + appIdStr + appAttemptIdStr +
    jobIdStr + stageIdStr + stageAttemptIdStr + taskIdStr + taskAttemptNumberStr

  /**
   * Set up the caller context [[context]] by invoking the Hadoop CallerContext API of
   * [[org.apache.hadoop.ipc.CallerContext]], which was added in Hadoop 2.8.
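   *
   * For example (the identifiers below are illustrative, not real application or task ids):
   * {{{
   *   new CallerContext("TASK", appId = Some("application_123_0001"),
   *     jobId = Some(1), stageId = Some(0), taskId = Some(42L)).setCurrentContext()
   * }}}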
2628 */ 2629 def setCurrentContext(): Unit = { 2630 if (CallerContext.callerContextSupported) { 2631 try { 2632 val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") 2633 val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") 2634 val builderInst = builder.getConstructor(classOf[String]).newInstance(context) 2635 val hdfsContext = builder.getMethod("build").invoke(builderInst) 2636 callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) 2637 } catch { 2638 case NonFatal(e) => 2639 logWarning("Fail to set Spark caller context", e) 2640 } 2641 } 2642 } 2643} 2644 2645/** 2646 * A utility class to redirect the child process's stdout or stderr. 2647 */ 2648private[spark] class RedirectThread( 2649 in: InputStream, 2650 out: OutputStream, 2651 name: String, 2652 propagateEof: Boolean = false) 2653 extends Thread(name) { 2654 2655 setDaemon(true) 2656 override def run() { 2657 scala.util.control.Exception.ignoring(classOf[IOException]) { 2658 // FIXME: We copy the stream on the level of bytes to avoid encoding problems. 2659 Utils.tryWithSafeFinally { 2660 val buf = new Array[Byte](1024) 2661 var len = in.read(buf) 2662 while (len != -1) { 2663 out.write(buf, 0, len) 2664 out.flush() 2665 len = in.read(buf) 2666 } 2667 } { 2668 if (propagateEof) { 2669 out.close() 2670 } 2671 } 2672 } 2673 } 2674} 2675 2676/** 2677 * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it 2678 * in a circular buffer. The current contents of the buffer can be accessed using 2679 * the toString method. 2680 */ 2681private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { 2682 private var pos: Int = 0 2683 private var isBufferFull = false 2684 private val buffer = new Array[Byte](sizeInBytes) 2685 2686 def write(input: Int): Unit = { 2687 buffer(pos) = input.toByte 2688 pos = (pos + 1) % buffer.length 2689 isBufferFull = isBufferFull || (pos == 0) 2690 } 2691 2692 override def toString: String = { 2693 if (!isBufferFull) { 2694 return new String(buffer, 0, pos, StandardCharsets.UTF_8) 2695 } 2696 2697 val nonCircularBuffer = new Array[Byte](sizeInBytes) 2698 System.arraycopy(buffer, pos, nonCircularBuffer, 0, buffer.length - pos) 2699 System.arraycopy(buffer, 0, nonCircularBuffer, buffer.length - pos, pos) 2700 new String(nonCircularBuffer, StandardCharsets.UTF_8) 2701 } 2702} 2703
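
// A brief usage sketch of CircularBuffer (illustrative only):
//   val buf = new CircularBuffer(sizeInBytes = 4)
//   "abcde".foreach(c => buf.write(c))
//   buf.toString  // "bcde": once the buffer is full, the oldest bytes are overwritten first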