1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.execution
19
20import org.apache.spark.sql.Row
21import org.apache.spark.sql.execution.aggregate.HashAggregateExec
22import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
23import org.apache.spark.sql.expressions.scalalang.typed
24import org.apache.spark.sql.functions.{avg, broadcast, col, max}
25import org.apache.spark.sql.test.SharedSQLContext
26import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
27
/**
 * Checks that operators supporting whole-stage code generation are placed directly under a
 * [[WholeStageCodegenExec]] node in the executed plan, and that the queries built on them
 * still return the expected rows.
 */
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  /**
   * Returns true if `plan` contains a [[WholeStageCodegenExec]] whose immediate child is an
   * instance of `T`.
   *
   * @param plan the executed physical plan to search
   * @tparam T the physical operator expected directly beneath the codegen node
   */
  private def hasCodegenChild[T <: SparkPlan](
      plan: SparkPlan)(implicit tag: scala.reflect.ClassTag[T]): Boolean = {
    plan.find {
      case codegen: WholeStageCodegenExec => tag.runtimeClass.isInstance(codegen.child)
      case _ => false
    }.isDefined
  }

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(hasCodegenChild[HashAggregateExec](plan))
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(hasCodegenChild[HashAggregateExec](plan))
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(hasCodegenChild[BroadcastHashJoinExec](df.queryExecution.executedPlan))
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(hasCodegenChild[SortExec](plan))
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(hasCodegenChild[SerializeFromObjectExec](plan))
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(hasCodegenChild[FilterExec](plan))
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(hasCodegenChild[FilterExec](plan))
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(hasCodegenChild[HashAggregateExec](plan))
    // The aggregation goes through a shuffle, so the output order is not guaranteed;
    // sort by key before comparing against the expected rows.
    assert(ds.collect().sortBy(_._1) === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }

  test("SPARK-19512 codegen for comparing structs is incorrect") {
    // This would raise a CompileException before the fix. The structs compare `id` with
    // `id + 2`, so no row can satisfy the predicate.
    val singleFieldCount = spark.range(10)
      .selectExpr("named_struct('a', id) as col1", "named_struct('a', id+2) as col2")
      .filter("col1 = col2").count()
    assert(singleFieldCount === 0)

    // This would raise java.lang.IndexOutOfBoundsException before the fix.
    val multiFieldCount = spark.range(10)
      .selectExpr("named_struct('a', id, 'b', id) as col1",
        "named_struct('a',id+2, 'b',id+2) as col2")
      .filter("col1 = col2").count()
    assert(multiFieldCount === 0)
  }
}
129