1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql
19
20import org.apache.spark.sql.functions._
21import org.apache.spark.sql.test.SharedSQLContext
22import org.apache.spark.storage.StorageLevel
23
24
/**
 * Tests for [[Dataset]] caching: `persist` / `cache` / `unpersist`, the reported `storageLevel`,
 * and whether queries built on top of persisted Datasets read from the cached data.
 *
 * Tests in this suite share one session through [[SharedSQLContext]], so every test unpersists
 * what it caches; otherwise leftover cache entries make later tests depend on execution order.
 */
class DatasetCacheSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("get storage level") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    // default storage level
    ds1.persist()
    ds2.cache()
    assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK)
    assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK)
    // unpersist
    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE)
    // non-default storage level
    ds1.persist(StorageLevel.MEMORY_ONLY_2)
    assert(ds1.storageLevel == StorageLevel.MEMORY_ONLY_2)
    // joined Dataset should not be persisted
    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    assert(joined.storageLevel == StorageLevel.NONE)

    // Release both cache entries. A later test persists Datasets built from exactly the same data,
    // so leaving these cached (ds1 at MEMORY_ONLY_2) would leak state into it.
    ds1.unpersist()
    ds2.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE)
    assert(ds2.storageLevel == StorageLevel.NONE)
  }

  test("persist and unpersist") {
    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
    val cached = ds.cache()
    // count triggers the caching action. It should not throw.
    cached.count()
    // Make sure, the Dataset is indeed cached.
    assertCached(cached)
    // Check result.
    checkDataset(
      cached,
      2, 3, 4)
    // Drop the cache.
    cached.unpersist()
    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.")
  }

  test("persist and then rebind right encoder when join 2 datasets") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    ds1.persist()
    assertCached(ds1)
    ds2.persist()
    assertCached(ds2)

    // Joining two persisted Datasets must still decode the right side correctly, and the joined
    // plan is expected to use both cached relations.
    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    checkDataset(joined, ("2", 2))
    assertCached(joined, 2)

    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.")
    ds2.unpersist()
    assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.")
  }

  test("persist and then groupBy columns asKey, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupByKey(_._1)
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    agged.persist()

    // A query derived from the persisted aggregate should return the right result and read
    // from the cache.
    checkDataset(
      agged.filter(_._1 == "b"),
      ("b", 3))
    assertCached(agged.filter(_._1 == "b"))

    // Only `agged` was persisted, so its source `ds` must not be cached. This is checked without
    // calling ds.unpersist() first; that call would make the assertion trivially true.
    assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be cached.")
    agged.unpersist()
    assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
  }
}
100