1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17package org.apache.spark.security
18
19import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
20import java.nio.charset.StandardCharsets.UTF_8
21import java.util.UUID
22
23import com.google.common.io.ByteStreams
24
25import org.apache.spark._
26import org.apache.spark.internal.config._
27import org.apache.spark.security.CryptoStreamUtils._
28import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
29import org.apache.spark.storage.TempShuffleBlockId
30
/**
 * Tests for [[CryptoStreamUtils]]: conversion of Spark I/O-encryption settings into
 * commons-crypto properties, encryption key generation, and encrypt/decrypt round trips
 * both through a [[SerializerManager]] and across the driver/executor boundary.
 */
class CryptoStreamUtilsSuite extends SparkFunSuite {

  test("crypto configuration conversion") {
    // A key under the Spark commons-crypto prefix is re-keyed under the commons-crypto prefix.
    val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c"
    val sparkVal1 = "val1"
    val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c"

    // A key that shares the prefix text but lacks its trailing "." must NOT be converted.
    val sparkKey2 = SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.stripSuffix(".") + "A.b.c"
    val sparkVal2 = "val2"
    val cryptoKey2 = s"${COMMONS_CRYPTO_CONF_PREFIX}A.b.c"

    val conf = new SparkConf()
    conf.set(sparkKey1, sparkVal1)
    conf.set(sparkKey2, sparkVal2)
    val props = CryptoStreamUtils.toCryptoConf(conf)
    assert(props.getProperty(cryptoKey1) === sparkVal1)
    assert(!props.containsKey(cryptoKey2))
  }

  test("shuffle encryption key length should be 128 by default") {
    val conf = createConf()
    val key = CryptoStreamUtils.createKey(conf)
    val actual = key.length * java.lang.Byte.SIZE
    assert(actual === 128)
  }

  test("create 256-bit key") {
    val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "256")
    val key = CryptoStreamUtils.createKey(conf)
    val actual = key.length * java.lang.Byte.SIZE
    assert(actual === 256)
  }

  test("create key with invalid length") {
    // Build the conf outside `intercept` so only the call under test can satisfy it.
    val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "328")
    intercept[IllegalArgumentException] {
      CryptoStreamUtils.createKey(conf)
    }
  }

  test("serializer manager integration") {
    val conf = createConf()
      .set("spark.shuffle.compress", "true")
      .set("spark.shuffle.spill.compress", "true")

    val plainStr = "hello world"
    val blockId = new TempShuffleBlockId(UUID.randomUUID())
    val key = Some(CryptoStreamUtils.createKey(conf))
    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf,
      encryptionKey = key)

    val outputStream = new ByteArrayOutputStream()
    val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
    wrappedOutputStream.write(plainStr.getBytes(UTF_8))
    wrappedOutputStream.close()

    val encryptedBytes = outputStream.toByteArray
    val encryptedStr = new String(encryptedBytes, UTF_8)
    assert(plainStr !== encryptedStr)

    val inputStream = new ByteArrayInputStream(encryptedBytes)
    val wrappedInputStream = serializerManager.wrapStream(blockId, inputStream)
    // Close the decrypting stream so any underlying cipher resources are released.
    val decryptedBytes =
      try ByteStreams.toByteArray(wrappedInputStream) finally wrappedInputStream.close()
    val decryptedStr = new String(decryptedBytes, UTF_8)
    assert(decryptedStr === plainStr)
  }

  test("encryption key propagation to executors") {
    val conf = createConf().setAppName("Crypto Test").setMaster("local-cluster[1,1,1024]")
    val sc = new SparkContext(conf)
    try {
      val content = "This is the content to be encrypted."
      // Encrypt on an executor using the executor's view of the conf and I/O encryption key.
      val encrypted = sc.parallelize(Seq(1))
        .map { _ =>
          val bytes = new ByteArrayOutputStream()
          val out = CryptoStreamUtils.createCryptoOutputStream(bytes, SparkEnv.get.conf,
            SparkEnv.get.securityManager.getIOEncryptionKey().get)
          out.write(content.getBytes(UTF_8))
          out.close()
          bytes.toByteArray()
        }.collect()(0)

      // `encrypted` is an Array[Byte]. Comparing it directly to the String would always be
      // unequal, so compare it against the plaintext bytes instead.
      assert(!encrypted.sameElements(content.getBytes(UTF_8)))

      // Decrypt on the driver. This only succeeds if the executor used the same key.
      val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted),
        sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get)
      val decrypted =
        try new String(ByteStreams.toByteArray(in), UTF_8) finally in.close()
      assert(content === decrypted)
    } finally {
      sc.stop()
    }
  }

  /**
   * Builds a [[SparkConf]] with I/O encryption enabled.
   *
   * @param extra additional key/value settings applied before encryption is enabled
   * @return the new configuration
   */
  private def createConf(extra: (String, String)*): SparkConf = {
    val conf = new SparkConf()
    extra.foreach { case (k, v) => conf.set(k, v) }
    conf.set(IO_ENCRYPTION_ENABLED, true)
    conf
  }

}
131