/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import org.scalatest.Assertions

import org.apache.spark.storage.StorageLevel

/**
 * Tests for the RDD-introspection APIs on [[SparkContext]]:
 * `getPersistentRDDs`, `getRDDStorageInfo`, and call-site tracking.
 */
class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty)

    // Merely creating an RDD must not register it as persistent.
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    // Persistent RDDs are keyed by RDD id; the first RDD created gets id 0.
    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    // The previously-returned snapshot must be unaffected by the new cache() call.
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = cc(sc)
    // cache() only marks the RDD; no blocks exist until an action materializes it.
    assert(sc.getRDDStorageInfo.isEmpty)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  // Helper keeps the RDD construction for the storage-info test in one place.
  private def cc(sc: SparkContext) = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}

/** Call site must be outside of usual org.apache.spark packages (see Utils#SPARK_CLASS_REGEX). */
package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  /**
   * Creates an RDD and checks that both its creation site and the current
   * call site point back into this file with the expected line offset.
   *
   * NOTE: the physical spacing below is load-bearing — `getCallSite()` is
   * asserted to be exactly 2 source lines after the `makeRDD` call.
   */
  def runCallSiteTest(sc: SparkContext): Unit = {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        // rddCreationLine is already an Int; compare directly.
        assert(line.toInt === rddCreationLine + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
}