/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import scala.reflect.{classTag, ClassTag}

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

/**
 * Checks that the executed plans of several query shapes contain a [[WholeStageCodegenExec]]
 * node directly wrapping the expected operator, and that the queries still return the
 * expected rows.
 */
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  /**
   * Tells whether `plan` contains a [[WholeStageCodegenExec]] node whose direct child is an
   * instance of `T`.
   *
   * @tparam T the operator type expected directly below the codegen stage
   * @param plan the executed plan to search
   * @return true if at least one such node is found
   */
  private def hasCodegenStageOver[T <: SparkPlan : ClassTag](plan: SparkPlan): Boolean = {
    val expectedChild = classTag[T].runtimeClass
    plan.find {
      case stage: WholeStageCodegenExec => expectedChild.isInstance(stage.child)
      case _ => false
    }.isDefined
  }

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    assert(hasCodegenStageOver[HashAggregateExec](df.queryExecution.executedPlan))
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    assert(hasCodegenStageOver[HashAggregateExec](df.queryExecution.executedPlan))
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(hasCodegenStageOver[BroadcastHashJoinExec](df.queryExecution.executedPlan))
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    assert(hasCodegenStageOver[SortExec](df.queryExecution.executedPlan))
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    assert(hasCodegenStageOver[SerializeFromObjectExec](ds.queryExecution.executedPlan))
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    assert(hasCodegenStageOver[FilterExec](ds.queryExecution.executedPlan))
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    assert(hasCodegenStageOver[FilterExec](ds.queryExecution.executedPlan))
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    assert(hasCodegenStageOver[HashAggregateExec](ds.queryExecution.executedPlan))
    // The query has no ORDER BY, so sort the collected rows by key on the driver before
    // comparing. Sorting here rather than in the query keeps the inspected plan unchanged.
    assert(ds.collect().sortBy(_._1) === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }

  test("SPARK-19512 codegen for comparing structs is incorrect") {
    // this would raise CompileException before the fix
    val singleFieldMatches = spark.range(10)
      .selectExpr("named_struct('a', id) as col1", "named_struct('a', id+2) as col2")
      .filter("col1 = col2").count()
    // id never equals id + 2, so no row can satisfy col1 = col2
    assert(singleFieldMatches === 0)

    // this would raise java.lang.IndexOutOfBoundsException before the fix
    val twoFieldMatches = spark.range(10)
      .selectExpr("named_struct('a', id, 'b', id) as col1",
        "named_struct('a',id+2, 'b',id+2) as col2")
      .filter("col1 = col2").count()
    assert(twoFieldMatches === 0)
  }
}