/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.analysis

import org.scalatest.ShouldMatchers

import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._


class AnalysisSuite extends AnalysisTest with ShouldMatchers {
  import org.apache.spark.sql.catalyst.analysis.TestRelations._

  test("union project *") {
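    // Build a deeply nested chain of unions over star projections and check that the analyzer
    // can still expand and resolve the stars at every level.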
    val plan = (1 to 120)
      .map(_ => testRelation)
      .fold[LogicalPlan](testRelation) { (a, b) =>
        a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None)))
      }

    assertAnalysisSuccess(plan)
  }

  test("check project's resolved") {
    assert(Project(testRelation.output, testRelation).resolved)

    assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)

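    // A generator (Explode) or an aggregate function (count) inside a plain Project keeps the
    // Project unresolved.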
    val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
    assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)

    assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved)
  }

  test("analyze project") {
    checkAnalysis(
      Project(Seq(UnresolvedAttribute("a")), testRelation),
      Project(testRelation.output, testRelation))

    checkAnalysis(
      Project(Seq(UnresolvedAttribute("TbL.a")),
        UnresolvedRelation(TableIdentifier("TaBlE"), Some("TbL"))),
      Project(testRelation.output, testRelation))

    assertAnalysisError(
      Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(
        TableIdentifier("TaBlE"), Some("TbL"))),
      Seq("cannot resolve"))

    checkAnalysis(
      Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(
        TableIdentifier("TaBlE"), Some("TbL"))),
      Project(testRelation.output, testRelation),
      caseSensitive = false)

    checkAnalysis(
      Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(
        TableIdentifier("TaBlE"), Some("TbL"))),
      Project(testRelation.output, testRelation),
      caseSensitive = false)
  }

  test("resolve sort references - filter/limit") {
    val a = testRelation2.output(0)
    val b = testRelation2.output(1)
    val c = testRelation2.output(2)

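    // Attributes referenced only by the sort are added to the intermediate projections and
    // stripped again by a trailing Project, so the plan's final output is unchanged.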
    // Case 1: one missing attribute is in the leaf node and another is in the unary node
    val plan1 = testRelation2
      .where('a > "str").select('a, 'b)
      .where('b > "str").select('a)
      .sortBy('b.asc, 'c.desc)
    val expected1 = testRelation2
      .where(a > "str").select(a, b, c)
      .where(b > "str").select(a, b, c)
      .sortBy(b.asc, c.desc)
      .select(a)
    checkAnalysis(plan1, expected1)

    // Case 2: all the missing attributes are in the leaf node
    val plan2 = testRelation2
      .where('a > "str").select('a)
      .where('a > "str").select('a)
      .sortBy('b.asc, 'c.desc)
    val expected2 = testRelation2
      .where(a > "str").select(a, b, c)
      .where(a > "str").select(a, b, c)
      .sortBy(b.asc, c.desc)
      .select(a)
    checkAnalysis(plan2, expected2)
  }

  test("resolve sort references - join") {
    val a = testRelation2.output(0)
    val b = testRelation2.output(1)
    val c = testRelation2.output(2)
    val h = testRelation3.output(3)

    // Case: the join itself can resolve all the missing attributes
    val plan = testRelation2.join(testRelation3)
      .where('a > "str").select('a, 'b)
      .sortBy('c.desc, 'h.asc)
    val expected = testRelation2.join(testRelation3)
      .where(a > "str").select(a, b, c, h)
      .sortBy(c.desc, h.asc)
      .select(a, b)
    checkAnalysis(plan, expected)
  }

  test("resolve sort references - aggregate") {
    val a = testRelation2.output(0)
    val b = testRelation2.output(1)
    val c = testRelation2.output(2)
    val alias_a3 = count(a).as("a3")
    val alias_b = b.as("aggOrder")

    // Case 1: when the child of Sort is not Aggregate,
    //   the sort reference is handled by the rule ResolveSortReferences
    val plan1 = testRelation2
      .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
      .select('a, 'c, 'a3)
      .orderBy('b.asc)

    val expected1 = testRelation2
      .groupBy(a, c, b)(a, c, alias_a3, b)
      .select(a, c, alias_a3.toAttribute, b)
      .orderBy(b.asc)
      .select(a, c, alias_a3.toAttribute)

    checkAnalysis(plan1, expected1)

    // Case 2: when the child of Sort is Aggregate,
    //   the sort reference is handled by the rule ResolveAggregateFunctions
    val plan2 = testRelation2
      .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
      .orderBy('b.asc)

    val expected2 = testRelation2
      .groupBy(a, c, b)(a, c, alias_a3, alias_b)
      .orderBy(alias_b.toAttribute.asc)
      .select(a, c, alias_a3.toAttribute)

    checkAnalysis(plan2, expected2)
  }

  test("resolve relations") {
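    // With the default case-sensitive resolution, "tAbLe" does not match the registered "TaBlE",
    // so the first lookup fails; with caseSensitive = false both spellings resolve.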
    assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq())
    checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation)
    checkAnalysis(
      UnresolvedRelation(TableIdentifier("tAbLe"), None), testRelation, caseSensitive = false)
    checkAnalysis(
      UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation, caseSensitive = false)
  }

  test("divide should be cast to fractional types") {
    val plan = caseInsensitiveAnalyzer.execute(
      testRelation2.select(
        'a / Literal(2) as 'div1,
        'a / 'b as 'div2,
        'a / 'c as 'div3,
        'a / 'd as 'div4,
        'e / 'e as 'div5))
    val pl = plan.asInstanceOf[Project].projectList

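    // Every division in the project list should be widened to DoubleType by the analyzer.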
    assert(pl(0).dataType == DoubleType)
    assert(pl(1).dataType == DoubleType)
    assert(pl(2).dataType == DoubleType)
    assert(pl(3).dataType == DoubleType)
    assert(pl(4).dataType == DoubleType)
  }

  test("pull out nondeterministic expressions from RepartitionByExpression") {
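    // The nondeterministic Rand is evaluated in a Project beneath the repartition, referenced by
    // attribute, and the helper column is dropped again by the outer Project.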
    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
    val projected = Alias(Rand(33), "_nondeterministic")()
    val expected =
      Project(testRelation.output,
        RepartitionByExpression(Seq(projected.toAttribute),
          Project(testRelation.output :+ projected, testRelation)))
    checkAnalysis(plan, expected)
  }

  test("pull out nondeterministic expressions from Sort") {
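    // Same pattern as above, but for nondeterministic sort ordering expressions.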
    val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation)
    val projected = Alias(Rand(33), "_nondeterministic")()
    val expected =
      Project(testRelation.output,
        Sort(Seq(SortOrder(projected.toAttribute, Ascending)), false,
          Project(testRelation.output :+ projected, testRelation)))
    checkAnalysis(plan, expected)
  }

  test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") {
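    // Only the outermost alias of each expression should survive analysis; intermediate aliases
    // such as "a+1" and "a1"/"a2" are removed.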
    val a = testRelation.output.head
    var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
    var expected = testRelation.select((a + 1 + 2).as("col"))
    checkAnalysis(plan, expected)

    plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col"))
    expected = testRelation.groupBy(a)((min(a) + 1).as("col"))
    checkAnalysis(plan, expected)

    // CreateStruct is a special case: its Aliases are not trimmed, since they become field names.
    plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
    expected = testRelation.select(CreateNamedStruct(Seq(
      Literal(a.name), a,
      Literal("a+1"), (a + 1))).as("col"))
    checkAnalysis(plan, expected)
  }

  test("Analysis may leave unnecessary aliases") {
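    // Analyze a plan once, build a second query on top of the analyzed result, and check that
    // aliases introduced in the second round are still cleaned up by the second analysis.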
    val att1 = testRelation.output.head
    var plan = testRelation.select(
      CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"),
      att1
    )
    val prevPlan = getAnalyzer(true).execute(plan)
    plan = prevPlan.select(CreateArray(Seq(
      CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"),
      /** alias should be eliminated by [[CleanupAliases]] */
      "col".attr.as("col2")
    )).as("arr"))
    plan = getAnalyzer(true).execute(plan)

    val expectedPlan = prevPlan.select(
      CreateArray(Seq(
        CreateNamedStruct(Seq(
          Literal(att1.name), att1,
          Literal("a_plus_1"), (att1 + 1))),
        'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull
      )).as("arr")
    )

    checkAnalysis(plan, expectedPlan)
  }

  test("SPARK-10534: resolve attribute references in order by clause") {
    val a = testRelation2.output(0)
    val c = testRelation2.output(2)

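    // 'a is not in the select list, so the analyzer adds it for the sort, casts it to DoubleType
    // for Floor, and drops it again afterwards.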
    val plan = testRelation2.select('c).orderBy(Floor('a).asc)
    val expected = testRelation2.select(c, a).orderBy(Floor(a.cast(DoubleType)).asc).select(c)

    checkAnalysis(plan, expected)
  }

  test("self intersect should resolve duplicate expression IDs") {
    val plan = testRelation.intersect(testRelation)
    assertAnalysisSuccess(plan)
  }

  test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
    val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
      LocalRelation()
    )
    assertAnalysisSuccess(plan)
  }

  test("SPARK-8654: different types in inlist but can be converted to a common type") {
    val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
      LocalRelation()
    )
    assertAnalysisSuccess(plan)
  }

  test("SPARK-8654: check type compatibility error") {
    val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil,
      LocalRelation()
    )
    assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
  }

  test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
    val string = testRelation2.output(0)
    val double = testRelation2.output(2)
    val short = testRelation2.output(4)
    val nullResult = Literal.create(null, StringType)

    def checkUDF(udf: Expression, transformed: Expression): Unit = {
      checkAnalysis(
        Project(Alias(udf, "")() :: Nil, testRelation2),
        Project(Alias(transformed, "")() :: Nil, testRelation2)
      )
    }

    // non-primitive parameters do not need special null handling
    val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
    val expected1 = udf1
    checkUDF(udf1, expected1)

    // only the primitive parameter needs special null handling
    val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
    val expected2 = If(IsNull(double), nullResult, udf2)
    checkUDF(udf2, expected2)

    // special null handling should apply to all primitive parameters
    val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
    val expected3 = If(
      IsNull(short) || IsNull(double),
      nullResult,
      udf3)
    checkUDF(udf3, expected3)

    // we can skip special null handling for primitive parameters that are not nullable
    // TODO: this is disabled for now as we cannot completely trust `nullable`.
    val udf4 = ScalaUDF(
      (s: Short, d: Double) => "x",
      StringType,
      short :: double.withNullability(false) :: Nil)
    val expected4 = If(
      IsNull(short),
      nullResult,
      udf4)
    // checkUDF(udf4, expected4)
  }

  test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") {
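    // The ORDER BY mixes an alias defined in the aggregate list ('a1) with a real grouping
    // column ('c); both must resolve against the aggregate.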
    val a = testRelation2.output(0)
    val c = testRelation2.output(2)
    val alias1 = a.as("a1")
    val alias2 = c.as("a2")
    val alias3 = count(a).as("a3")

    val plan = testRelation2
      .groupBy('a, 'c)('a.as("a1"), 'c.as("a2"), count('a).as("a3"))
      .orderBy('a1.asc, 'c.asc)

    val expected = testRelation2
      .groupBy(a, c)(alias1, alias2, alias3)
      .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc)
      .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute)
    checkAnalysis(plan, expected)
  }

  test("Eliminate the unnecessary union") {
    val plan = Union(testRelation :: Nil)
    val expected = testRelation
    checkAnalysis(plan, expected)
  }

  test("SPARK-12102: Ignore nullability when comparing two sides of case") {
    val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
    val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
    assertAnalysisSuccess(plan)
  }

  test("Keep attribute qualifiers after dedup") {
    val input = LocalRelation('key.int, 'value.string)

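    // Both join branches read the same relation, so the analyzer has to deduplicate expression
    // IDs on one side; the "x"/"y" qualifiers must survive for the outer $"x.key" / $"y.key"
    // references to resolve.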
    val query =
      Project(Seq($"x.key", $"y.key"),
        Join(
          Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
          Project(Seq($"y.key"), SubqueryAlias("y", input, None)),
          Cross, None))

    assertAnalysisSuccess(query)
  }

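  /**
   * Analyzes `expression` in a single-column projection over OneRowRelation and fails if the
   * resolved data type differs from `expectedDataType`.
   */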
  private def assertExpressionType(
      expression: Expression,
      expectedDataType: DataType): Unit = {
    val afterAnalyze =
      Project(Seq(Alias(expression, "a")()), OneRowRelation).analyze.expressions.head
    if (!afterAnalyze.dataType.equals(expectedDataType)) {
      fail(
        s"""
           |data type of expression $expression doesn't match expected:
           |Actual data type:
           |${afterAnalyze.dataType}
           |
           |Expected data type:
           |${expectedDataType}
         """.stripMargin)
    }
  }

  test("SPARK-15776: test whether Divide expression's data type can be deduced correctly by " +
    "analyzer") {
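    // Only divisions where a decimal operand is paired with an integral one keep a DecimalType
    // result; everything else is widened to DoubleType.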
    assertExpressionType(sum(Divide(1, 2)), DoubleType)
    assertExpressionType(sum(Divide(1.0, 2)), DoubleType)
    assertExpressionType(sum(Divide(1, 2.0)), DoubleType)
    assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
    assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
    assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
    assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
    assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
    assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
    assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
  }

  test("SPARK-18058: union and set operations shall not care about the nullability" +
    " when comparing column types") {
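    // The two relations differ only in the nullability of the nested struct field "a".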
    val firstTable = LocalRelation(
      AttributeReference("a",
        StructType(Seq(StructField("a", IntegerType, nullable = true))), nullable = false)())
    val secondTable = LocalRelation(
      AttributeReference("a",
        StructType(Seq(StructField("a", IntegerType, nullable = false))), nullable = false)())

    val unionPlan = Union(firstTable, secondTable)
    assertAnalysisSuccess(unionPlan)

    val r1 = Except(firstTable, secondTable)
    val r2 = Intersect(firstTable, secondTable)

    assertAnalysisSuccess(r1)
    assertAnalysisSuccess(r2)
  }
}