1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.expressions
19
20import java.sql.Timestamp
21
22import org.apache.spark.SparkFunSuite
23import org.apache.spark.metrics.source.CodegenMetrics
24import org.apache.spark.sql.Row
25import org.apache.spark.sql.catalyst.InternalRow
26import org.apache.spark.sql.catalyst.dsl.expressions._
27import org.apache.spark.sql.catalyst.expressions.codegen._
28import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
29import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
30import org.apache.spark.sql.types._
31import org.apache.spark.unsafe.types.UTF8String
32import org.apache.spark.util.ThreadUtils
33
/**
 * Additional tests for code generation.
 *
 * Covers thread safety of the code generators, compilation metrics, splitting of very wide or
 * deeply nested generated code so that it stays within the JVM code size limit, safe/unsafe
 * projection round-trips, and escaping of special character sequences that appear in literal
 * data and would otherwise break the generated Java source.
 */
class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {

  /**
   * Fails the test unless `actual` matches `expected` according to `checkResult`.
   *
   * Shared by the "split into blocks" tests below, which all evaluate a generated projection
   * and compare its output values against an expected sequence.
   *
   * @param expressions the expressions that were compiled; only used in the failure message
   * @param actual the values produced by the generated code
   * @param expected the values the generated code should have produced
   */
  private def assertGeneratedOutput(
      expressions: Seq[Expression],
      actual: Seq[Any],
      expected: Seq[Any]): Unit = {
    if (!checkResult(actual, expected)) {
      fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
    }
  }

  test("multithreaded eval") {
    import scala.concurrent._
    import ExecutionContext.Implicits.global
    import scala.concurrent.duration._

    // Invoke the predicate, mutable-projection and ordering generators concurrently from many
    // futures; each future is then awaited for up to 10 seconds.
    val futures = (1 to 20).map { _ =>
      Future {
        GeneratePredicate.generate(EqualTo(Literal(1), Literal(1)))
        GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)
        GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil)
      }
    }

    futures.foreach(ThreadUtils.awaitResult(_, 10.seconds))
  }

  test("metrics are recorded on compile") {
    val compilationTimeCount = CodegenMetrics.METRIC_COMPILATION_TIME.getCount()
    val sourceCodeSizeCount = CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount()
    val classBytecodeSizeCount = CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount()
    val methodBytecodeSizeCount = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount()

    // The literal 123 presumably keeps this expression distinct from anything another test has
    // already compiled, so that a real compilation happens -- NOTE(review): confirm.
    GenerateOrdering.generate(Add(Literal(123), Literal(1)).asc :: Nil)

    // Compilation time and source size are expected to record exactly one new sample, while
    // the bytecode-size metrics are only required to grow.
    assert(CodegenMetrics.METRIC_COMPILATION_TIME.getCount() == compilationTimeCount + 1)
    assert(CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount() == sourceCodeSizeCount + 1)
    assert(
      CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount() > classBytecodeSizeCount)
    assert(
      CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount() > methodBytecodeSizeCount)
  }

  test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
    val length = 5000
    val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq.fill(length)(true)

    assertGeneratedOutput(expressions, actual, expected)
  }

  test("SPARK-13242: case-when expression with large number of branches (or cases)") {
    val cases = 50
    val clauses = 20

    // Case n matches when column 0 equals any of "1:n", "2:n", ..., "<clauses>:n".
    def generateCase(n: Int): (Expression, Expression) = {
      val condition = (1 to clauses)
        .map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n")))
        .reduceLeft[Expression]((l, r) => Or(l, r))
      (condition, Literal(n))
    }

    val expression = CaseWhen((1 to cases).map(generateCase(_)))

    val plan = GenerateMutableProjection.generate(Seq(expression))
    // "<clauses>:<cases>" only satisfies the last branch, so all earlier conditions are false.
    val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
    val actual = plan(input).toSeq(Seq(expression.dataType))

    assert(actual(0) == cases)
  }

  test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") {
    // 150 nested Decode(Encode(_, "utf-8"), "utf-8") layers around "abc": the value is still
    // "abc", but the expression tree is very deep.
    val strExpr = (1 to 150).foldLeft[Expression](Literal("abc")) { (expr, _) =>
      Decode(Encode(expr, "utf-8"), "utf-8")
    }

    val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
    val plan = GenerateMutableProjection.generate(expressions)
    // The expressions reference no input columns, so a null input row is passed.
    val actual = plan(null).toSeq(expressions.map(_.dataType))
    val expected = Seq(UTF8String.fromString("abc"))

    assertGeneratedOutput(expressions, actual, expected)
  }

  test("SPARK-14793: split wide array creation into blocks due to JVM code size limit") {
    val length = 5000
    val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq(new GenericArrayData(Seq.fill(length)(true)))

    assertGeneratedOutput(expressions, actual, expected)
  }

  test("SPARK-14793: split wide map creation into blocks due to JVM code size limit") {
    val length = 5000
    // Keys 0 until length, each mapped to an EqualTo(1, 1) value.
    val expressions = Seq(CreateMap(
      List.fill(length)(EqualTo(Literal(1), Literal(1))).zipWithIndex.flatMap {
        case (expr, i) => Seq(Literal(i), expr)
      }))
    val plan = GenerateMutableProjection.generate(expressions)
    // Convert the map data to a Scala Map so it can be compared with `expected`.
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map {
      case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m)
      case other => fail(s"Expected ArrayBasedMapData but got $other")
    }
    val expected = (0 until length).map((_, true)).toMap :: Nil

    assertGeneratedOutput(expressions, actual, expected)
  }

  test("SPARK-14793: split wide struct creation into blocks due to JVM code size limit") {
    val length = 5000
    val expressions = Seq(CreateStruct(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq(InternalRow(Seq.fill(length)(true): _*))

    assertGeneratedOutput(expressions, actual, expected)
  }

  test("SPARK-14793: split wide named struct creation into blocks due to JVM code size limit") {
    val length = 5000
    // Each field is named after the string form of its own value expression.
    val expressions = Seq(CreateNamedStruct(
      List.fill(length)(EqualTo(Literal(1), Literal(1))).flatMap {
        expr => Seq(Literal(expr.toString), expr)
      }))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq(InternalRow(Seq.fill(length)(true): _*))

    assertGeneratedOutput(expressions, actual, expected)
  }

  test("SPARK-14224: split wide external row creation into blocks due to JVM code size limit") {
    val length = 5000
    val schema = StructType(Seq.fill(length)(StructField("int", IntegerType)))
    val expressions = Seq(CreateExternalRow(Seq.fill(length)(Literal(1)), schema))
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    val expected = Seq(Row.fromSeq(Seq.fill(length)(1)))

    assertGeneratedOutput(expressions, actual, expected)
  }

  test("SPARK-17702: split wide constructor into blocks due to JVM code size limit") {
    val length = 5000
    val expressions = Seq.fill(length) {
      ToUTCTimestamp(
        Literal.create(Timestamp.valueOf("2015-07-24 00:00:00"), TimestampType),
        Literal.create("PST", StringType))
    }
    val plan = GenerateMutableProjection.generate(expressions)
    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
    // Midnight in "PST" on 2015-07-24 is 07:00 UTC (daylight saving time is in effect then).
    val expected = Seq.fill(length)(
      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00")))

    assertGeneratedOutput(expressions, actual, expected)
  }

  test("test generated safe and unsafe projection") {
    // Top-level fields plus a nested struct (c) and a doubly nested struct (d.a); one of the
    // innermost fields has an empty name.
    val schema = new StructType(Array(
      StructField("a", StringType, true),
      StructField("b", IntegerType, true),
      StructField("c", new StructType(Array(
        StructField("aa", StringType, true),
        StructField("bb", IntegerType, true)
      )), true),
      StructField("d", new StructType(Array(
        StructField("a", new StructType(Array(
          StructField("b", StringType, true),
          StructField("", IntegerType, true)
        )), true)
      )), true)
    ))
    val row = Row("a", 1, Row("b", 2), Row(Row("c", 3)))
    val lit = Literal.create(row, schema)
    val internalRow = lit.value.asInstanceOf[InternalRow]

    // Convert to an UnsafeRow and read every field back, including the nested structs.
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow: UnsafeRow = unsafeProj(internalRow)
    assert(unsafeRow.getUTF8String(0) === UTF8String.fromString("a"))
    assert(unsafeRow.getInt(1) === 1)
    assert(unsafeRow.getStruct(2, 2).getUTF8String(0) === UTF8String.fromString("b"))
    assert(unsafeRow.getStruct(2, 2).getInt(1) === 2)
    assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getUTF8String(0) ===
      UTF8String.fromString("c"))
    assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3)

    // Converting back to a safe row must reproduce the original row.
    val fromUnsafe = FromUnsafeProjection(schema)
    val internalRow2 = fromUnsafe(unsafeRow)
    assert(internalRow === internalRow2)

    // update unsafeRow should not affect internalRow2, i.e. the safe row must hold its own
    // copy of the data rather than a view over the unsafe buffer.
    unsafeRow.setInt(1, 10)
    unsafeRow.getStruct(2, 2).setInt(1, 10)
    unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4)
    assert(internalRow === internalRow2)
  }

  test("*/ in the data") {
    // When */ appears in a comment block (i.e. in /**/), code gen will break.
    // So, in Expression and CodegenFallback, we escape */ to \*\/.
    checkEvaluation(
      EqualTo(BoundReference(0, StringType, false), Literal.create("*/", StringType)),
      true,
      InternalRow(UTF8String.fromString("*/")))
  }

  test("\\u in the data") {
    // When \ u appears in a comment block (i.e. in /**/), code gen will break.
    // So, in Expression and CodegenFallback, we escape \ u to \\u.
    checkEvaluation(
      EqualTo(BoundReference(0, StringType, false), Literal.create("\\u", StringType)),
      true,
      InternalRow(UTF8String.fromString("\\u")))
  }

  test("check compilation error doesn't occur caused by specific literal") {
    // Each literal below must not break compilation of the generated code; the test passes
    // as long as every generate() call succeeds.

    // The end of comment (*/) should be escaped.
    GenerateUnsafeProjection.generate(
      Literal.create("*/Compilation error occurs/*", StringType) :: Nil)

    // `\u002A` is `*` and `\u002F` is `/`
    // so if the end of comment consists of those characters in queries, we need to escape them.
    GenerateUnsafeProjection.generate(
      Literal.create("\\u002A/Compilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("\\\\u002A/Compilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("\\u002a/Compilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("\\\\u002a/Compilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("*\\u002FCompilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("*\\\\u002FCompilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("*\\002fCompilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("*\\\\002fCompilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("\\002A\\002FCompilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("\\\\002A\\002FCompilation error occurs/*", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("\\002A\\\\002FCompilation error occurs/*", StringType) :: Nil)

    // \ u002X is an invalid unicode literal so it should be escaped.
    GenerateUnsafeProjection.generate(
      Literal.create("\\u002X/Compilation error occurs", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("\\\\u002X/Compilation error occurs", StringType) :: Nil)

    // \ u001 is an invalid unicode literal so it should be escaped.
    GenerateUnsafeProjection.generate(
      Literal.create("\\u001/Compilation error occurs", StringType) :: Nil)
    GenerateUnsafeProjection.generate(
      Literal.create("\\\\u001/Compilation error occurs", StringType) :: Nil)
  }

  test("SPARK-17160: field names are properly escaped by GetExternalRowField") {
    // The field name contains a double quote, which must not break the generated code.
    val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
    GenerateUnsafeProjection.generate(
      ValidateExternalType(
        GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), IntegerType) :: Nil)
  }

  test("SPARK-17160: field names are properly escaped by AssertTrue") {
    // The asserted value's string form contains a double quote, which must not break codegen.
    GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
  }
}
315