/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.security

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets.UTF_8
import java.util.UUID

import com.google.common.io.ByteStreams

import org.apache.spark._
import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils._
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.storage.TempShuffleBlockId

/**
 * Tests for [[CryptoStreamUtils]]: translation of Spark settings into commons-crypto
 * properties, I/O encryption key generation, encrypted round trips through a
 * [[SerializerManager]], and use of the I/O encryption key on executors.
 */
class CryptoStreamUtilsSuite extends SparkFunSuite {

  test("crypto configuration conversion") {
    val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c"
    val sparkVal1 = "val1"
    val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c"

    // sparkKey2 does not start with the full prefix (any trailing "." is stripped before
    // appending), so it must not be forwarded to the commons-crypto properties.
    val sparkKey2 = SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.stripSuffix(".") + "A.b.c"
    val sparkVal2 = "val2"
    val cryptoKey2 = s"${COMMONS_CRYPTO_CONF_PREFIX}A.b.c"

    val conf = new SparkConf()
    conf.set(sparkKey1, sparkVal1)
    conf.set(sparkKey2, sparkVal2)
    val props = CryptoStreamUtils.toCryptoConf(conf)
    assert(props.getProperty(cryptoKey1) === sparkVal1)
    assert(!props.containsKey(cryptoKey2))
  }

  test("shuffle encryption key length should be 128 by default") {
    val conf = createConf()
    assert(keySizeInBits(conf) === 128)
  }

  test("create 256-bit key") {
    val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "256")
    assert(keySizeInBits(conf) === 256)
  }

  test("create key with invalid length") {
    val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "328")
    // Only key creation is expected to reject the unsupported size; keeping the conf
    // construction outside `intercept` ensures the exception really comes from createKey.
    intercept[IllegalArgumentException] {
      CryptoStreamUtils.createKey(conf)
    }
  }

  test("serializer manager integration") {
    val conf = createConf()
      .set("spark.shuffle.compress", "true")
      .set("spark.shuffle.spill.compress", "true")

    val plainStr = "hello world"
    val blockId = new TempShuffleBlockId(UUID.randomUUID())
    val key = Some(CryptoStreamUtils.createKey(conf))
    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf,
      encryptionKey = key)

    val outputStream = new ByteArrayOutputStream()
    val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream)
    wrappedOutputStream.write(plainStr.getBytes(UTF_8))
    // Close before reading the buffer: the wrapping streams may hold output until closed.
    wrappedOutputStream.close()

    val encryptedBytes = outputStream.toByteArray
    val encryptedStr = new String(encryptedBytes, UTF_8)
    assert(plainStr !== encryptedStr)

    val wrappedInputStream =
      serializerManager.wrapStream(blockId, new ByteArrayInputStream(encryptedBytes))
    val decryptedBytes =
      try ByteStreams.toByteArray(wrappedInputStream) finally wrappedInputStream.close()
    val decryptedStr = new String(decryptedBytes, UTF_8)
    assert(decryptedStr === plainStr)
  }

  test("encryption key propagation to executors") {
    val conf = createConf().setAppName("Crypto Test").setMaster("local-cluster[1,1,1024]")
    val sc = new SparkContext(conf)
    try {
      val content = "This is the content to be encrypted."
      // Encrypt inside a task, using the SparkEnv (conf and key) visible to the executor.
      val encrypted = sc.parallelize(Seq(1))
        .map { _ =>
          val bytes = new ByteArrayOutputStream()
          val out = CryptoStreamUtils.createCryptoOutputStream(bytes, SparkEnv.get.conf,
            SparkEnv.get.securityManager.getIOEncryptionKey().get)
          out.write(content.getBytes(UTF_8))
          out.close()
          bytes.toByteArray()
        }.collect()(0)

      // Compare bytes with bytes: `String != Array[Byte]` is always true and checks nothing.
      assert(!java.util.Arrays.equals(content.getBytes(UTF_8), encrypted))

      // Decrypt on the driver with the driver's key; recovering the plaintext shows the
      // executor encrypted with the same key.
      val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted),
        sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get)
      val decrypted = try new String(ByteStreams.toByteArray(in), UTF_8) finally in.close()
      assert(content === decrypted)
    } finally {
      sc.stop()
    }
  }

  /** Returns the size, in bits, of a key generated by [[CryptoStreamUtils.createKey]]. */
  private def keySizeInBits(conf: SparkConf): Int =
    CryptoStreamUtils.createKey(conf).length * java.lang.Byte.SIZE

  /**
   * Builds a SparkConf that applies the given `extra` key/value settings and then enables
   * I/O encryption. Because encryption is enabled last, `extra` cannot switch it off.
   */
  private def createConf(extra: (String, String)*): SparkConf = {
    val conf = new SparkConf()
    extra.foreach { case (k, v) => conf.set(k, v) }
    conf.set(IO_ENCRYPTION_ENABLED, true)
    conf
  }

}