/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy

import java.io.IOException
import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import java.text.DateFormat
import java.util.{Arrays, Comparator, Date, Locale}

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import com.google.common.primitives.Longs
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
import org.apache.hadoop.fs.FileSystem.Statistics
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.hadoop.security.token.{Token, TokenIdentifier}
import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
 * :: DeveloperApi ::
 * Contains util methods to interact with Hadoop from Spark.
 */
@DeveloperApi
class SparkHadoopUtil extends Logging {
  private val sparkConf = new SparkConf(false).loadFromSystemProperties(true)
  val conf: Configuration = newConfiguration(sparkConf)
  UserGroupInformation.setConfiguration(conf)

  /**
   * Runs the given function with a Hadoop UserGroupInformation as a thread local variable
   * (distributed to child threads), used for authenticating HDFS and YARN calls.
   *
   * IMPORTANT NOTE: If this function is going to be called repeatedly in the same process,
   * you need to look at https://issues.apache.org/jira/browse/HDFS-3545 and possibly
   * call FileSystem.closeAllForUGI in order to avoid leaking FileSystems.
   */
  def runAsSparkUser(func: () => Unit) {
    val user = Utils.getCurrentUserName()
    logDebug("running as user: " + user)
    val ugi = UserGroupInformation.createRemoteUser(user)
    transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
    ugi.doAs(new PrivilegedExceptionAction[Unit] {
      def run: Unit = func()
    })
  }
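
  // Usage sketch (hypothetical caller, not part of this class): wrap filesystem access so it
  // runs under the Spark user's UGI with its tokens attached, e.g.
  //
  //   SparkHadoopUtil.get.runAsSparkUser { () =>
  //     val fs = FileSystem.get(new Configuration())
  //     fs.exists(new Path("/tmp"))   // executed as the Spark user
  //   }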

  def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
    for (token <- source.getTokens.asScala) {
      dest.addToken(token)
    }
  }


  /**
   * Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop
   * configuration.
   */
  def appendS3AndSparkHadoopConfigurations(conf: SparkConf, hadoopConf: Configuration): Unit = {
    // Note: this null check is around more than just access to the "conf" object to maintain
    // the behavior of the old implementation of this code, for backwards compatibility.
    if (conf != null) {
      // Explicitly check for S3 environment variables
      if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
          System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
        val keyId = System.getenv("AWS_ACCESS_KEY_ID")
        val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY")

        hadoopConf.set("fs.s3.awsAccessKeyId", keyId)
        hadoopConf.set("fs.s3n.awsAccessKeyId", keyId)
        hadoopConf.set("fs.s3a.access.key", keyId)
        hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey)
        hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey)
        hadoopConf.set("fs.s3a.secret.key", accessKey)
      }
      // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
      conf.getAll.foreach { case (key, value) =>
        if (key.startsWith("spark.hadoop.")) {
          hadoopConf.set(key.substring("spark.hadoop.".length), value)
        }
      }
      val bufferSize = conf.get("spark.buffer.size", "65536")
      hadoopConf.set("io.file.buffer.size", bufferSize)
    }
  }
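
  // Illustrative sketch (hypothetical key/value): any "spark.hadoop.*" entry in the SparkConf is
  // copied into the Hadoop configuration with the "spark.hadoop." prefix stripped, e.g.
  //
  //   val sc = new SparkConf().set("spark.hadoop.fs.defaultFS", "hdfs://nn:8020")
  //   val hc = new Configuration()
  //   SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc, hc)
  //   hc.get("fs.defaultFS")          // "hdfs://nn:8020"
  //   hc.get("io.file.buffer.size")   // "65536" unless spark.buffer.size overrides it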

  /**
   * Return an appropriate (subclass of) Configuration. Creating a config can initialize some
   * Hadoop subsystems.
   */
  def newConfiguration(conf: SparkConf): Configuration = {
    val hadoopConf = new Configuration()
    appendS3AndSparkHadoopConfigurations(conf, hadoopConf)
    hadoopConf
  }

  /**
   * Add any user credentials to the job conf which are necessary for running on a secure Hadoop
   * cluster.
   */
  def addCredentials(conf: JobConf) {}

  def isYarnMode(): Boolean = { false }

  def getCurrentUserCredentials(): Credentials = { null }

  def addCurrentUserCredentials(creds: Credentials) {}

  def addSecretKeyToUserCredentials(key: String, secret: String) {}

  def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }

  def loginUserFromKeytab(principalName: String, keytabFilename: String) {
    UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
  }

  /**
   * Returns a function that can be called to find Hadoop FileSystem bytes read. If
   * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
   * return the bytes read on r since t.  Reflection is required because thread-level FileSystem
   * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
   * Returns None if the required method can't be found.
   */
  private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = {
    try {
      val threadStats = getFileSystemThreadStatistics()
      val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
      val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
      val baselineBytesRead = f()
      Some(() => f() - baselineBytesRead)
    } catch {
      case e @ (_: NoSuchMethodException | _: ClassNotFoundException) =>
        logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e)
        None
    }
  }
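
  // Usage sketch (hypothetical caller, mirroring how input metrics consume this): capture the
  // callback before reading, then invoke it afterwards to get the bytes read on this thread
  // since the capture point:
  //
  //   val bytesReadCallback = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
  //   // ... perform Hadoop filesystem reads on this thread ...
  //   val bytesRead: Long = bytesReadCallback.map(f => f()).getOrElse(0L)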

  /**
   * Returns a function that can be called to find Hadoop FileSystem bytes written. If
   * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will
   * return the bytes written on r since t.  Reflection is required because thread-level FileSystem
   * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
   * Returns None if the required method can't be found.
   */
  private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = {
    try {
      val threadStats = getFileSystemThreadStatistics()
      val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
      val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
      val baselineBytesWritten = f()
      Some(() => f() - baselineBytesWritten)
    } catch {
      case e @ (_: NoSuchMethodException | _: ClassNotFoundException) =>
        logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e)
        None
    }
  }

  private def getFileSystemThreadStatistics(): Seq[AnyRef] = {
    FileSystem.getAllStatistics.asScala.map(
      Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
  }

  private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
    val statisticsDataClass =
      Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
    statisticsDataClass.getDeclaredMethod(methodName)
  }

  /**
   * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the
   * given path points to a file, return a single-element collection containing [[FileStatus]] of
   * that file.
   */
  def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = {
    listLeafStatuses(fs, fs.getFileStatus(basePath))
  }

  /**
   * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the
   * given path points to a file, return a single-element collection containing [[FileStatus]] of
   * that file.
   */
  def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = {
    def recurse(status: FileStatus): Seq[FileStatus] = {
      val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDirectory)
      leaves ++ directories.flatMap(f => listLeafStatuses(fs, f))
    }

    if (baseStatus.isDirectory) recurse(baseStatus) else Seq(baseStatus)
  }
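
  // Illustrative sketch (hypothetical layout): for /data/a/part-0 and /data/b/part-1,
  // listLeafStatuses(fs, new Path("/data")) returns the statuses of both part files but not of
  // the intermediate directories; for a plain file it returns just that file's status.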

  def listLeafDirStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = {
    listLeafDirStatuses(fs, fs.getFileStatus(basePath))
  }

  def listLeafDirStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = {
    def recurse(status: FileStatus): Seq[FileStatus] = {
      val (directories, files) = fs.listStatus(status.getPath).partition(_.isDirectory)
      val leaves = if (directories.isEmpty) Seq(status) else Seq.empty[FileStatus]
      leaves ++ directories.flatMap(dir => listLeafDirStatuses(fs, dir))
    }

    assert(baseStatus.isDirectory)
    recurse(baseStatus)
  }

  def isGlobPath(pattern: Path): Boolean = {
    pattern.toString.exists("{}[]*?\\".toSet.contains)
  }

  def globPath(pattern: Path): Seq[Path] = {
    val fs = pattern.getFileSystem(conf)
    Option(fs.globStatus(pattern)).map { statuses =>
      statuses.map(_.getPath.makeQualified(fs.getUri, fs.getWorkingDirectory)).toSeq
    }.getOrElse(Seq.empty[Path])
  }

  def globPathIfNecessary(pattern: Path): Seq[Path] = {
    if (isGlobPath(pattern)) globPath(pattern) else Seq(pattern)
  }
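
  // Illustrative sketch (hypothetical paths): glob patterns are expanded against the path's
  // filesystem, while plain paths pass through unchanged:
  //
  //   globPathIfNecessary(new Path("/logs/2016-*"))   // Seq of matching, fully qualified paths
  //   globPathIfNecessary(new Path("/logs/2016-01"))  // Seq(new Path("/logs/2016-01"))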

  /**
   * Lists all the files in a directory that start with the specified prefix and do not end with
   * the given suffix. The returned [[FileStatus]] instances are sorted by the modification times
   * of the respective files.
   */
  def listFilesSorted(
      remoteFs: FileSystem,
      dir: Path,
      prefix: String,
      exclusionSuffix: String): Array[FileStatus] = {
    try {
      val fileStatuses = remoteFs.listStatus(dir,
        new PathFilter {
          override def accept(path: Path): Boolean = {
            val name = path.getName
            name.startsWith(prefix) && !name.endsWith(exclusionSuffix)
          }
        })
      Arrays.sort(fileStatuses, new Comparator[FileStatus] {
        override def compare(o1: FileStatus, o2: FileStatus): Int = {
          Longs.compare(o1.getModificationTime, o2.getModificationTime)
        }
      })
      fileStatuses
    } catch {
      case NonFatal(e) =>
        logWarning("Error while attempting to list files from application staging dir", e)
        Array.empty
    }
  }

  private[spark] def getSuffixForCredentialsPath(credentialsPath: Path): Int = {
    val fileName = credentialsPath.getName
    fileName.substring(
      fileName.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) + 1).toInt
  }
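
  // Illustrative sketch (hypothetical file name): for a credentials path ending in
  // "credentials-18-5", the digits after the last SPARK_YARN_CREDS_COUNTER_DELIM ("-") are
  // parsed, yielding 5.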


  private val HADOOP_CONF_PATTERN = "(\\$\\{hadoopconf-[^\\}\\$\\s]+\\})".r.unanchored

  /**
   * Substitute variables by looking them up in Hadoop configs. Only variables that match the
   * ${hadoopconf- .. } pattern are substituted.
   */
  def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = {
    text match {
      case HADOOP_CONF_PATTERN(matched) =>
        logDebug(text + " matched " + HADOOP_CONF_PATTERN)
        val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. }
        val eval = Option[String](hadoopConf.get(key))
          .map { value =>
            logDebug("Substituted " + matched + " with " + value)
            text.replace(matched, value)
          }
        if (eval.isEmpty) {
          // The variable was not found in Hadoop configs, so return text as is.
          text
        } else {
          // Continue to substitute more variables.
          substituteHadoopVariables(eval.get, hadoopConf)
        }
      case _ =>
        logDebug(text + " didn't match " + HADOOP_CONF_PATTERN)
        text
    }
  }
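
  // Illustrative sketch (assumed config entry): with "yarn.resourcemanager.webapp.address" set
  // to "rm:8088" in hadoopConf,
  //
  //   substituteHadoopVariables("http://${hadoopconf-yarn.resourcemanager.webapp.address}/", hadoopConf)
  //   // returns "http://rm:8088/"
  //
  // A ${hadoopconf-...} variable with no matching Hadoop config entry is left as is.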

  /**
   * Start a thread to periodically update the current user's credentials with new credentials so
   * that access to secured services does not fail.
   */
  private[spark] def startCredentialUpdater(conf: SparkConf) {}

  /**
   * Stop the thread that does the credential updates.
   */
  private[spark] def stopCredentialUpdater() {}

  /**
   * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism.
   * This is to prevent the DFSClient from using an old cached token to connect to the NameNode.
   */
  private[spark] def getConfBypassingFSCache(
      hadoopConf: Configuration,
      scheme: String): Configuration = {
    val newConf = new Configuration(hadoopConf)
    val confKey = s"fs.${scheme}.impl.disable.cache"
    newConf.setBoolean(confKey, true)
    newConf
  }
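
  // Usage sketch (hypothetical caller): disabling the cache for the "hdfs" scheme makes a later
  // FileSystem.get create a fresh client instead of reusing one with stale tokens:
  //
  //   val freshConf = getConfBypassingFSCache(hadoopConf, "hdfs")
  //   // freshConf now has "fs.hdfs.impl.disable.cache" set to true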

  /**
   * Dump the credentials' tokens to string values.
   *
   * @param credentials credentials
   * @return an iterable over the string values; if no credentials are passed in, an empty list
   */
  private[spark] def dumpTokens(credentials: Credentials): Iterable[String] = {
    if (credentials != null) {
      credentials.getAllTokens.asScala.map(tokenToString)
    } else {
      Seq()
    }
  }

  /**
   * Convert a token to a string for logging.
   * If it's an abstract delegation token, attempt to unmarshal it and then
   * print more details, including timestamps in human-readable form.
   *
   * @param token token to convert to a string
   * @return a printable string value.
   */
  private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = {
    val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US)
    val buffer = new StringBuilder(128)
    buffer.append(token.toString)
    try {
      val ti = token.decodeIdentifier
      buffer.append("; ").append(ti)
      ti match {
        case dt: AbstractDelegationTokenIdentifier =>
          // include human times and the renewer, which the HDFS tokens toString omits
          buffer.append("; Renewer: ").append(dt.getRenewer)
          buffer.append("; Issued: ").append(df.format(new Date(dt.getIssueDate)))
          buffer.append("; Max Date: ").append(df.format(new Date(dt.getMaxDate)))
        case _ =>
      }
    } catch {
      case e: IOException =>
        logDebug(s"Failed to decode $token: $e", e)
    }
    buffer.toString
  }
}

object SparkHadoopUtil {

  private lazy val hadoop = new SparkHadoopUtil
  private lazy val yarn = try {
    Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
      .newInstance()
      .asInstanceOf[SparkHadoopUtil]
  } catch {
    case e: Exception => throw new SparkException("Unable to load YARN support", e)
  }

  val SPARK_YARN_CREDS_TEMP_EXTENSION = ".tmp"

  val SPARK_YARN_CREDS_COUNTER_DELIM = "-"

  /**
   * Number of records between updates of the input metrics when reading from HadoopRDDs.
   *
   * Each update is potentially expensive because we need to use reflection to access the
   * Hadoop FileSystem API of interest (only available in 2.5), so we should do this sparingly.
   */
  private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000

  def get: SparkHadoopUtil = {
    // Check each time to support changing to/from YARN
    val yarnMode = java.lang.Boolean.parseBoolean(
        System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
    if (yarnMode) {
      yarn
    } else {
      hadoop
    }
  }
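
  // Usage sketch: callers obtain the environment-appropriate instance through this accessor,
  // e.g. SparkHadoopUtil.get.conf or SparkHadoopUtil.get.newConfiguration(sparkConf); when
  // SPARK_YARN_MODE is "true" (as a system property or environment variable), the YARN subclass
  // is returned instead of the base implementation.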
}