/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.SparkFunSuite
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ThreadUtils

/**
 * Additional tests for code generation.
 */
class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {

  /**
   * Fails the running test when `actual` does not match `expected` according to `checkResult`.
   *
   * @param expressions the expressions that were evaluated; only used in the failure message
   * @param actual the value(s) produced by the generated code
   * @param expected the value(s) the generated code should have produced
   */
  private def assertCorrectEvaluation(
      expressions: Seq[Expression],
      actual: Any,
      expected: Any): Unit = {
    if (!checkResult(actual, expected)) {
      fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
    }
  }

  test("multithreaded eval") {
    import scala.concurrent._
    import ExecutionContext.Implicits.global
    import scala.concurrent.duration._

    // Run the three generators from 20 futures on the global execution context.
    val futures = (1 to 20).map { _ =>
      Future {
        GeneratePredicate.generate(EqualTo(Literal(1), Literal(1)))
        GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)
        GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil)
      }
    }

    // Each future gets up to 10 seconds to complete.
    futures.foreach(ThreadUtils.awaitResult(_, 10.seconds))
  }

  test("metrics are recorded on compile") {
    // Snapshot the counters first so the assertions only look at the delta of this test.
    val startCount1 = CodegenMetrics.METRIC_COMPILATION_TIME.getCount()
    val startCount2 = CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount()
    val startCount3 = CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount()
    val startCount4 = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount()
    GenerateOrdering.generate(Add(Literal(123), Literal(1)).asc :: Nil)
    assert(CodegenMetrics.METRIC_COMPILATION_TIME.getCount() == startCount1 + 1)
    assert(CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount() == startCount2 + 1)
    assert(CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount() > startCount3)
    assert(CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount() > startCount4)
  }

  test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
    val length = 5000
    val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq.fill(length)(true)

    assertCorrectEvaluation(expressions, actual, expected)
  }

  test("SPARK-13242: case-when expression with large number of branches (or cases)") {
    val cases = 50
    val clauses = 20

    // Generate an individual case
    def generateCase(n: Int): (Expression, Expression) = {
      val condition = (1 to clauses)
        .map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n")))
        .reduceLeft[Expression]((l, r) => Or(l, r))
      (condition, Literal(n))
    }

    val expression = CaseWhen((1 to cases).map(generateCase(_)))

    val plan = GenerateMutableProjection.generate(Seq(expression))
    // The input matches the last clause of the last case.
    val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
    val actual = plan(input).toSeq(Seq(expression.dataType))

    assert(actual(0) === cases)
  }

  test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") {
    // Wrap the literal in 150 nested Encode/Decode round trips to build a deep expression.
    var strExpr: Expression = Literal("abc")
    for (_ <- 1 to 150) {
      strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8")
    }

    val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
    val plan = GenerateMutableProjection.generate(expressions)
    // No BoundReference is involved, so no input row is supplied.
    val actual = plan(null).toSeq(expressions.map(_.dataType))
    val expected = Seq(UTF8String.fromString("abc"))

    assertCorrectEvaluation(expressions, actual, expected)
  }

  test("SPARK-14793: split wide array creation into blocks due to JVM code size limit") {
    val length = 5000
    val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq(new GenericArrayData(Seq.fill(length)(true)))

    assertCorrectEvaluation(expressions, actual, expected)
  }

  test("SPARK-14793: split wide map creation into blocks due to JVM code size limit") {
    val length = 5000
    // Keys are the indices 0 until length; every value is the same `1 = 1` comparison.
    val expressions = Seq(CreateMap(
      List.fill(length)(EqualTo(Literal(1), Literal(1))).zipWithIndex.flatMap {
        case (expr, i) => Seq(Literal(i), expr)
      }))
    val plan = GenerateMutableProjection.generate(expressions)
    // Convert the resulting map data to a Scala map so it can be compared with `expected`.
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map {
      case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m)
    }
    val expected = (0 until length).map((_, true)).toMap :: Nil

    assertCorrectEvaluation(expressions, actual, expected)
  }

  test("SPARK-14793: split wide struct creation into blocks due to JVM code size limit") {
    val length = 5000
    val expressions = Seq(CreateStruct(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq(InternalRow(Seq.fill(length)(true): _*))

    assertCorrectEvaluation(expressions, actual, expected)
  }

  test("SPARK-14793: split wide named struct creation into blocks due to JVM code size limit") {
    val length = 5000
    // Each field is named after the string form of its own expression.
    val expressions = Seq(CreateNamedStruct(
      List.fill(length)(EqualTo(Literal(1), Literal(1))).flatMap {
        expr => Seq(Literal(expr.toString), expr)
      }))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq(InternalRow(Seq.fill(length)(true): _*))

    assertCorrectEvaluation(expressions, actual, expected)
  }

  test("SPARK-14224: split wide external row creation into blocks due to JVM code size limit") {
    val length = 5000
    val schema = StructType(Seq.fill(length)(StructField("int", IntegerType)))
    val expressions = Seq(CreateExternalRow(Seq.fill(length)(Literal(1)), schema))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq(Row.fromSeq(Seq.fill(length)(1)))

    assertCorrectEvaluation(expressions, actual, expected)
  }

  test("SPARK-17702: split wide constructor into blocks due to JVM code size limit") {
    val length = 5000
    val expressions = Seq.fill(length) {
      ToUTCTimestamp(
        Literal.create(Timestamp.valueOf("2015-07-24 00:00:00"), TimestampType),
        Literal.create("PST", StringType))
    }
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq.fill(length)(
      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00")))

    assertCorrectEvaluation(expressions, actual, expected)
  }

  test("test generated safe and unsafe projection") {
    val schema = new StructType(Array(
      StructField("a", StringType, true),
      StructField("b", IntegerType, true),
      StructField("c", new StructType(Array(
        StructField("aa", StringType, true),
        StructField("bb", IntegerType, true)
      )), true),
      StructField("d", new StructType(Array(
        StructField("a", new StructType(Array(
          StructField("b", StringType, true),
          StructField("", IntegerType, true)
        )), true)
      )), true)
    ))
    val row = Row("a", 1, Row("b", 2), Row(Row("c", 3)))
    val lit = Literal.create(row, schema)
    val internalRow = lit.value.asInstanceOf[InternalRow]

    // Safe -> unsafe: every (nested) field must be readable from the UnsafeRow.
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow: UnsafeRow = unsafeProj(internalRow)
    assert(unsafeRow.getUTF8String(0) === UTF8String.fromString("a"))
    assert(unsafeRow.getInt(1) === 1)
    assert(unsafeRow.getStruct(2, 2).getUTF8String(0) === UTF8String.fromString("b"))
    assert(unsafeRow.getStruct(2, 2).getInt(1) === 2)
    assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getUTF8String(0) ===
      UTF8String.fromString("c"))
    assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3)

    // Unsafe -> safe: the round trip must give back an equal row.
    val fromUnsafe = FromUnsafeProjection(schema)
    val internalRow2 = fromUnsafe(unsafeRow)
    assert(internalRow === internalRow2)

    // update unsafeRow should not affect internalRow2
    unsafeRow.setInt(1, 10)
    unsafeRow.getStruct(2, 2).setInt(1, 10)
    unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4)
    assert(internalRow === internalRow2)
  }

  test("*/ in the data") {
    // When */ appears in a comment block (i.e. in /**/), code gen will break.
    // So, in Expression and CodegenFallback, we escape */ to \*\/.
    checkEvaluation(
      EqualTo(BoundReference(0, StringType, false), Literal.create("*/", StringType)),
      true,
      InternalRow(UTF8String.fromString("*/")))
  }

  test("\\u in the data") {
    // When \ u appears in a comment block (i.e. in /**/), code gen will break.
    // So, in Expression and CodegenFallback, we escape \ u to \\u.
    checkEvaluation(
      EqualTo(BoundReference(0, StringType, false), Literal.create("\\u", StringType)),
      true,
      InternalRow(UTF8String.fromString("\\u")))
  }

  test("check compilation error doesn't occur caused by specific literal") {
    // Generates an unsafe projection over a single string literal holding `text`.
    def assertCompiles(text: String): Unit = {
      GenerateUnsafeProjection.generate(Literal.create(text, StringType) :: Nil)
    }

    // The end of comment (*/) should be escaped.
    assertCompiles("*/Compilation error occurs/*")

    // `\u002A` is `*` and `\u002F` is `/`
    // so if the end of comment consists of those characters in queries, we need to escape them.
    // Upper- and lower-case hex digits are both covered, each with one and two backslashes.
    assertCompiles("\\u002A/Compilation error occurs/*")
    assertCompiles("\\\\u002A/Compilation error occurs/*")
    assertCompiles("\\u002a/Compilation error occurs/*")
    assertCompiles("\\\\u002a/Compilation error occurs/*")
    assertCompiles("*\\u002FCompilation error occurs/*")
    assertCompiles("*\\\\u002FCompilation error occurs/*")
    assertCompiles("*\\u002fCompilation error occurs/*")
    assertCompiles("*\\\\u002fCompilation error occurs/*")
    assertCompiles("\\u002A\\u002FCompilation error occurs/*")
    assertCompiles("\\\\u002A\\u002FCompilation error occurs/*")
    assertCompiles("\\u002A\\\\u002FCompilation error occurs/*")

    // \ u002X is an invalid unicode literal so it should be escaped.
    assertCompiles("\\u002X/Compilation error occurs")
    assertCompiles("\\\\u002X/Compilation error occurs")

    // \ u001 is an invalid unicode literal so it should be escaped.
    assertCompiles("\\u001/Compilation error occurs")
    assertCompiles("\\\\u001/Compilation error occurs")
  }

  test("SPARK-17160: field names are properly escaped by GetExternalRowField") {
    // The field name contains a double quote.
    val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
    GenerateUnsafeProjection.generate(
      ValidateExternalType(
        GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), IntegerType) :: Nil)
  }

  test("SPARK-17160: field names are properly escaped by AssertTrue") {
    // The asserted child expression contains a double-quote literal.
    GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
  }
}