/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.StorageLevel


/**
 * Tests for [[Dataset]] caching: `persist`/`cache`/`unpersist` and the level reported by
 * `Dataset.storageLevel`.
 *
 * Every test releases the caches it creates in a `finally` block, so cached data does not
 * outlive the test even when one of its assertions fails.
 */
class DatasetCacheSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("get storage level") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    try {
      // default storage level
      ds1.persist()
      ds2.cache()
      assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK)
      assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK)
      // unpersist
      ds1.unpersist()
      assert(ds1.storageLevel == StorageLevel.NONE)
      // non-default storage level
      ds1.persist(StorageLevel.MEMORY_ONLY_2)
      assert(ds1.storageLevel == StorageLevel.MEMORY_ONLY_2)
      // joined Dataset should not be persisted
      val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
      assert(joined.storageLevel == StorageLevel.NONE)
    } finally {
      // Previously neither Dataset was released here; drop both caches unconditionally.
      ds1.unpersist()
      ds2.unpersist()
    }
  }

  test("persist and unpersist") {
    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
    val cached = ds.cache()
    try {
      // count triggers the caching action. It should not throw.
      cached.count()
      // Make sure the Dataset is indeed cached.
      assertCached(cached)
      // Check result.
      checkDataset(
        cached,
        2, 3, 4)
    } finally {
      // Drop the cache even if one of the checks above failed.
      cached.unpersist()
    }
    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.")
  }

  test("persist and then rebind right encoder when join 2 datasets") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    try {
      ds1.persist()
      assertCached(ds1)
      ds2.persist()
      assertCached(ds2)

      val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
      checkDataset(joined, ("2", 2))
      // Both join inputs were persisted above; presumably `2` is the expected number of
      // cached relations in the joined plan.
      assertCached(joined, 2)
    } finally {
      ds1.unpersist()
      ds2.unpersist()
    }
    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.")
    assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.")
  }

  test("persist and then groupBy columns asKey, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupByKey(_._1)
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    agged.persist()

    try {
      checkDataset(
        agged.filter(_._1 == "b"),
        ("b", 3))
      assertCached(agged.filter(_._1 == "b"))
      // Only the aggregated result was persisted; its source Dataset must not report
      // itself as cached. (The source was never persisted, so there is nothing to
      // unpersist for it.)
      assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be cached.")
    } finally {
      agged.unpersist()
    }
    assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
  }
}