1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.deploy.yarn.security
19
20import org.apache.hadoop.conf.Configuration
21import org.apache.hadoop.io.Text
22import org.apache.hadoop.security.Credentials
23import org.apache.hadoop.security.token.Token
24import org.scalatest.{BeforeAndAfter, Matchers}
25
26import org.apache.spark.{SparkConf, SparkFunSuite}
27import org.apache.spark.deploy.yarn.config._
28
/**
 * Tests for [[ConfigurableCredentialManager]]: which service credential providers are loaded
 * under various configurations, and what obtaining credentials through the manager yields.
 *
 * A fresh [[SparkConf]] and Hadoop [[Configuration]] are created before every test, because
 * several tests mutate the configuration (e.g. disabling a provider). Sharing one instance across
 * tests made later tests order-dependent and could let a setting left behind by an earlier test
 * satisfy a later test's assertions instead of the setting that test is meant to exercise.
 */
class ConfigurableCredentialManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
  private var sparkConf: SparkConf = null
  private var hadoopConf: Configuration = null

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Set for the whole suite; cleared again in afterAll.
    System.setProperty("SPARK_YARN_MODE", "true")
  }

  override def afterAll(): Unit = {
    System.clearProperty("SPARK_YARN_MODE")

    super.afterAll()
  }

  before {
    // Per-test configuration so that settings made by one test never leak into another.
    sparkConf = new SparkConf()
    hadoopConf = new Configuration()
  }

  test("Correctly load default credential providers") {
    val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)

    credentialManager.getServiceCredentialProvider("hdfs") should not be (None)
    credentialManager.getServiceCredentialProvider("hbase") should not be (None)
    credentialManager.getServiceCredentialProvider("hive") should not be (None)
  }

  test("disable hive credential provider") {
    sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false")
    val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)

    credentialManager.getServiceCredentialProvider("hdfs") should not be (None)
    credentialManager.getServiceCredentialProvider("hbase") should not be (None)
    credentialManager.getServiceCredentialProvider("hive") should be (None)
  }

  test("using deprecated configurations") {
    // Only the deprecated keys are set here, so the assertions below verify that the deprecated
    // keys alone are honored.
    sparkConf.set("spark.yarn.security.tokens.hdfs.enabled", "false")
    sparkConf.set("spark.yarn.security.tokens.hive.enabled", "false")
    val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)

    credentialManager.getServiceCredentialProvider("hdfs") should be (None)
    credentialManager.getServiceCredentialProvider("hive") should be (None)
    credentialManager.getServiceCredentialProvider("test") should not be (None)
    credentialManager.getServiceCredentialProvider("hbase") should not be (None)
  }

  test("verify obtaining credentials from provider") {
    val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)
    val creds = new Credentials()

    // Tokens can only be obtained from TestCredentialProvider; the hdfs, hbase and hive
    // providers are expected not to produce tokens in this test environment.
    credentialManager.obtainCredentials(hadoopConf, creds)
    val tokens = creds.getAllTokens
    tokens.size() should be (1)
    tokens.iterator().next().getService should be (new Text("test"))
  }

  test("verify getting credential renewal info") {
    val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf)
    val creds = new Credentials()

    val testCredentialProvider = credentialManager.getServiceCredentialProvider("test") match {
      case Some(provider: TestCredentialProvider) => provider
      case other => fail(s"Expected a TestCredentialProvider for service 'test', got $other")
    }
    // Only TestCredentialProvider reports the time of the next token renewal.
    val nextRenewal = credentialManager.obtainCredentials(hadoopConf, creds)
    nextRenewal should be (testCredentialProvider.timeOfNextTokenRenewal)
  }

  test("obtain tokens For HiveMetastore") {
    hadoopConf.set("hive.metastore.kerberos.principal", "bob")
    // thrift picks up on port 0 and bails out, without trying to talk to endpoint
    hadoopConf.set("hive.metastore.uris", "http://localhost:0")

    val hiveCredentialProvider = new HiveCredentialProvider()
    val credentials = new Credentials()
    hiveCredentialProvider.obtainCredentials(hadoopConf, sparkConf, credentials)

    credentials.getAllTokens.size() should be (0)
  }

  test("Obtain tokens For HBase") {
    hadoopConf.set("hbase.security.authentication", "kerberos")

    val hbaseTokenProvider = new HBaseCredentialProvider()
    val creds = new Credentials()
    hbaseTokenProvider.obtainCredentials(hadoopConf, sparkConf, creds)

    creds.getAllTokens.size should be (0)
  }
}
123
/**
 * Test-only [[ServiceCredentialProvider]] registered under the service name "test".
 *
 * It always reports that credentials are required. When given a non-null [[Credentials]], it adds
 * one empty token whose service is "test". It reports as the next renewal time the first multiple
 * of `tokenRenewalInterval` after the current time, and remembers that value in
 * `timeOfNextTokenRenewal` so tests can compare against it.
 */
class TestCredentialProvider extends ServiceCredentialProvider {
  /** Renewal period: 86400 seconds (one day), expressed in milliseconds. */
  val tokenRenewalInterval = 86400 * 1000L

  /** Renewal time computed by the most recent successful `obtainCredentials` call (0 before). */
  var timeOfNextTokenRenewal = 0L

  override def serviceName: String = "test"

  override def credentialsRequired(conf: Configuration): Boolean = true

  /**
   * Adds an empty "test" token to `creds` and returns the next renewal time.
   *
   * @return `None` when `creds` is null (nothing is added and no renewal time is recorded),
   *         otherwise `Some` of the newly computed renewal time.
   */
  override def obtainCredentials(
      hadoopConf: Configuration,
      sparkConf: SparkConf,
      creds: Credentials): Option[Long] = {
    // A null credentials object yields None rather than failing, so callers that pass null are
    // not broken by this provider.
    Option(creds).map { target =>
      val emptyToken = new Token()
      emptyToken.setService(new Text("test"))
      target.addToken(emptyToken.getService, emptyToken)

      val now = System.currentTimeMillis()
      // Round up to the start of the next renewal interval.
      timeOfNextTokenRenewal = (now / tokenRenewalInterval + 1) * tokenRenewalInterval
      timeOfNextTokenRenewal
    }
  }
}
151